Add manual optimization to core task #1796

Open · wants to merge 14 commits into develop · showing changes from 12 commits
51 changes: 51 additions & 0 deletions pyannote/audio/core/task.py
@@ -48,6 +48,7 @@
from torchmetrics import Metric, MetricCollection

from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss
from pyannote.audio.utils.params import merge_dict
from pyannote.audio.utils.protocol import check_protocol

Subsets = list(Subset.__args__)
@@ -231,6 +232,9 @@ class Task(pl.LightningDataModule):
If True, data loaders will copy tensors into CUDA pinned
memory before returning them. See pytorch documentation
for more details. Defaults to False.
gradient: dict, optional
Keyword arguments controlling gradient clipping and accumulation.
Defaults to {"clip_val": 5.0, "clip_algorithm": "norm", "accumulate_batches": 1}
augmentation : BaseWaveformTransform, optional
torch_audiomentations waveform transform, used by dataloader
during training.
@@ -245,6 +249,12 @@ class Task(pl.LightningDataModule):

"""

GRADIENT_DEFAULTS = {
"clip_val": 5.0,
"clip_algorithm": "norm",
"accumulate_batches": 1,
Comment on lines +253 to +255

@hbredin (Member) · Nov 26, 2024
Shouldn't we instead try to grab these values from trainer options directly? Maybe trainer is exposed as an attribute of model?

@hbredin (Member) · Nov 26, 2024
Yep, looks like you could use something like

self.model.trainer.{accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm}

(Collaborator, Author)
Passing accumulate_grad_batches, gradient_clip_val or gradient_clip_algorithm to the Trainer will raise a MisconfigurationException if we set automatic_optimization=False:

lightning_fabric.utilities.exceptions.MisconfigurationException: Automatic gradient clipping is not supported for manual optimization. Remove `Trainer(gradient_clip_val=5.0)` or switch to automatic optimization.

}

def __init__(
self,
protocol: Protocol,
@@ -255,6 +265,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -302,6 +313,7 @@ def __init__(

self.num_workers = num_workers
self.pin_memory = pin_memory
self.gradient = merge_dict(self.GRADIENT_DEFAULTS, gradient)
self.augmentation = augmentation or Identity(output_type="dict")
self._metric = metric

@@ -810,6 +822,45 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]):
# can obviously be overridden for each task
def training_step(self, batch, batch_idx: int):
return self.common_step(batch, batch_idx, "train")

def manual_optimization(self, loss: torch.Tensor, batch_idx: int) -> torch.Tensor:
"""Process manual optimization for each optimizer

Parameters
----------
loss: torch.Tensor
Computed loss for current training step.
batch_idx: int
Batch index.

Returns
-------
scaled_loss: torch.Tensor
Loss scaled by `1 / Task.gradient["accumulate_batches"]`.
"""
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]

num_accumulate_batches = self.gradient["accumulate_batches"]
if batch_idx % num_accumulate_batches == 0:
for optimizer in optimizers:
optimizer.zero_grad()

# scale loss to keep the gradient magnitude as it would be using batches
# with size = batch_size * num_accumulate_batches
scaled_loss = loss / num_accumulate_batches
self.model.manual_backward(scaled_loss)

if (batch_idx + 1) % num_accumulate_batches == 0:
for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=self.gradient["clip_val"],
gradient_clip_algorithm=self.gradient["clip_algorithm"],
)
optimizer.step()

return scaled_loss

def val__getitem__(self, idx):
# will become val_dataset.__getitem__ method
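For context, here is a self-contained toy LightningModule (not part of the PR; the model, data, and hyper-parameters are made up) that exercises the same accumulate / clip / step pattern as the manual_optimization helper added above:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyManualOptimization(pl.LightningModule):
    """Toy module reproducing the accumulate / clip / step pattern above."""

    def __init__(self, accumulate_batches: int = 4, clip_val: float = 5.0):
        super().__init__()
        self.automatic_optimization = False  # required for manual optimization
        self.layer = torch.nn.Linear(10, 1)
        self.accumulate_batches = accumulate_batches
        self.clip_val = clip_val

    def training_step(self, batch, batch_idx):
        optimizers = self.optimizers()
        optimizers = optimizers if isinstance(optimizers, list) else [optimizers]

        if batch_idx % self.accumulate_batches == 0:
            for optimizer in optimizers:
                optimizer.zero_grad()

        x, y = batch
        # scale the loss so that accumulated gradients match a batch of
        # size batch_size * accumulate_batches
        loss = torch.nn.functional.mse_loss(self.layer(x), y) / self.accumulate_batches
        self.manual_backward(loss)

        if (batch_idx + 1) % self.accumulate_batches == 0:
            for optimizer in optimizers:
                self.clip_gradients(
                    optimizer,
                    gradient_clip_val=self.clip_val,
                    gradient_clip_algorithm="norm",
                )
                optimizer.step()

        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(ToyManualOptimization(), DataLoader(dataset, batch_size=8))

Note that automatic_optimization must be set to False for manual_backward and clip_gradients to be legal inside training_step, which is exactly the constraint discussed in the review thread above.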
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/embedding/arcface.py
@@ -90,6 +90,7 @@ def __init__(
scale: float = 64.0,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -106,6 +107,7 @@ def __init__(
min_duration=min_duration,
batch_size=self.batch_size,
num_workers=num_workers,
gradient=gradient,
pin_memory=pin_memory,
augmentation=augmentation,
metric=metric,
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/multilabel.py
@@ -101,6 +101,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -116,6 +117,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
@@ -111,6 +111,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
cache: Optional[Union[str, None]] = None,
@@ -122,6 +123,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
17 changes: 3 additions & 14 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -119,6 +119,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
max_num_speakers: Optional[
@@ -132,6 +133,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
@@ -466,20 +468,7 @@ def training_step(self, batch, batch_idx: int):
)

if not self.automatic_optimization:
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
for optimizer in optimizers:
optimizer.zero_grad()

self.model.manual_backward(loss)

for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=5.0,
gradient_clip_algorithm="norm",
)
optimizer.step()
loss = self.manual_optimization(loss, batch_idx)

return {"loss": loss}

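Following the review thread above (Trainer-level clipping raises a MisconfigurationException under manual optimization), here is a hedged sketch of how clipping and accumulation could instead be passed through the task. The protocol name, model choice, and trainer settings below are illustrative assumptions, not taken from this PR:

# Illustrative usage sketch — the protocol name, model, and trainer settings
# are placeholders, not part of this PR.
from pyannote.database import registry
from pyannote.audio.tasks import SpeakerDiarization
from pyannote.audio.models.segmentation import PyanNet
from pytorch_lightning import Trainer

# placeholder protocol: assumes a pyannote.database registry has been configured
protocol = registry.get_protocol("MyDatabase.SpeakerDiarization.MyProtocol")

task = SpeakerDiarization(
    protocol,
    batch_size=32,
    # clipping and accumulation go through the task, not the Trainer
    gradient={"clip_val": 5.0, "clip_algorithm": "norm", "accumulate_batches": 2},
)

model = PyanNet(task=task)

# no gradient_clip_val / accumulate_grad_batches here: with manual optimization,
# passing them to the Trainer raises a MisconfigurationException
trainer = Trainer(max_epochs=1)
trainer.fit(model)

Whether these settings actually take effect still depends on the task running with automatic_optimization disabled, as in the training_step shown above.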
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/voice_activity_detection.py
@@ -94,6 +94,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -104,6 +105,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
18 changes: 3 additions & 15 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -165,6 +165,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
max_num_speakers: Optional[
@@ -185,6 +186,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
@@ -1009,22 +1011,8 @@ def training_step(self, batch, batch_idx: int):
logger=True,
)

# using multiple optimizers requires manual optimization
if not self.automatic_optimization:
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
for optimizer in optimizers:
optimizer.zero_grad()

self.model.manual_backward(loss)

for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=self.model.gradient_clip_val,
gradient_clip_algorithm="norm",
)
optimizer.step()
loss = self.manual_optimization(loss, batch_idx)

return {"loss": loss}

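As a quick sanity check of the loss scaling in manual_optimization ("keep the gradient magnitude as it would be using batches with size = batch_size * num_accumulate_batches"), here is a short standalone script (not part of the PR) comparing accumulated scaled-loss gradients against a single full-batch gradient:

# Numeric check: accumulating gradients of loss / N over N micro-batches
# matches the gradient of one pass over the concatenated batch.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)
num_accumulate_batches = 4

# accumulated, scaled micro-batch gradients
model.zero_grad()
for xb, yb in zip(x.chunk(num_accumulate_batches), y.chunk(num_accumulate_batches)):
    loss = torch.nn.functional.mse_loss(model(xb), yb) / num_accumulate_batches
    loss.backward()
accumulated = model.weight.grad.clone()

# single full-batch gradient
model.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
full_batch = model.weight.grad.clone()

print(torch.allclose(accumulated, full_batch, atol=1e-5))  # True

The check passes because mse_loss averages within each micro-batch, so dividing by num_accumulate_batches turns the sum of micro-batch means into the full-batch mean.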