diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 4bc557dd0..74fe9ae8b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -38,6 +38,7 @@ from pyannote.core import SlidingWindow from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary from torch.utils.data import DataLoader +from torch_audiomentations.core.composition import Compose from pyannote.audio import __version__ from pyannote.audio.core.io import Audio @@ -256,6 +257,8 @@ def on_save_checkpoint(self, checkpoint): "specifications": self.specifications, } + self.task.on_save_checkpoint(checkpoint) + def on_load_checkpoint(self, checkpoint: Dict[str, Any]): check_version( "pyannote.audio", @@ -525,6 +528,7 @@ def from_pretrained( subfolder: Optional[str] = None, token: Union[str, bool, None] = None, cache_dir: Union[Path, str, None] = None, + protocol: Union[Protocol, None] = None, **kwargs, ) -> Optional["Model"]: """Load pretrained model @@ -548,6 +552,8 @@ def from_pretrained( Token to be used for the download. cache_dir: Path or str, optional Path to the folder where cached files are stored. + protocol: Protocol, optional + Protocol used to train the model. Needed to continue training. kwargs: optional Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. 
@@ -627,4 +633,56 @@ def default_map_location(storage, loc):
 
             raise e
 
+        # init task from the checkpoint, if any
+        if protocol and "task" in loaded_checkpoint["pyannote.audio"]:
+            # function-scope import: only needed to decode JSON-encoded metric entries
+            import json
+
+            task_data = loaded_checkpoint["pyannote.audio"]["task"]
+            task_module = import_module(task_data["module"])
+            TaskClass = getattr(task_module, task_data["class"])
+            task_hparams = task_data["hyper_parameters"]
+
+            def instantiate_transform(transform_data):
+                """Build an object from its {"module", "class", "kwargs"} description."""
+                # NOTE(review): module and class names are read from the checkpoint
+                # and imported as-is -- only load checkpoints from trusted sources.
+                transform_module = import_module(transform_data["module"])
+                TransformClass = getattr(transform_module, transform_data["class"])
+                return TransformClass(**transform_data["kwargs"])
+
+            # instantiate task augmentation
+            augmentation_data = task_data["augmentation"]
+            # BaseWaveformTransform case
+            if isinstance(augmentation_data, dict):
+                task_hparams["augmentation"] = instantiate_transform(augmentation_data)
+
+            # Compose transform case
+            elif isinstance(augmentation_data, list):
+                transforms = [
+                    instantiate_transform(transform_data)
+                    for transform_data in augmentation_data
+                ]
+                task_hparams["augmentation"] = Compose(
+                    transforms=transforms, output_type="dict"
+                )
+
+            # instantiate task metrics
+            metrics = task_data["metrics"]
+            if metrics:
+                metric = {}
+                for metadata in metrics:
+                    # an entry may be JSON text instead of a mapping; decode it so
+                    # that its "module", "class" and "kwargs" fields can be read
+                    if isinstance(metadata, str):
+                        metadata = json.loads(metadata)
+                    metric[metadata["class"]] = instantiate_transform(metadata)
+            else:
+                metric = None
+
+            task_hparams["metric"] = metric
+
+            # instantiate training task
+            model.task = TaskClass(protocol, **task_hparams)
+
         return model
diff --git a/pyannote/audio/core/task.py
b/pyannote/audio/core/task.py index 7b35adce1..eeaff89ca 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,7 +23,9 @@ from __future__ import annotations +import inspect import itertools +import json import multiprocessing import sys import warnings @@ -45,6 +47,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from torch_audiomentations import Identity from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations.core.composition import BaseCompose from torchmetrics import Metric, MetricCollection from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss @@ -327,6 +330,7 @@ def prepare_data(self): 'metadata-values': dict of lists of values for subset, scope and database 'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has it own array. 'metadata-labels': array of global scope labels + 'task-parameters': hyper-parameters used for the task } """ @@ -595,6 +599,20 @@ def prepare_data(self): prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) unique_labels.clear() + # keep track of task hyperparameters + parameters = [] + dtype = [] + for param_name, param_value in self.hparams.items(): + if isinstance(param_value, (bool, float, int, str, type(None))): + parameters.append(param_value) + dtype.append((param_name, type(param_value))) + + prepared_data["task-parameters"] = np.array( + tuple(parameters), dtype=np.dtype(dtype) + ) + parameters.clear() + dtype.clear() + if self.has_validation: self.prepare_validation(prepared_data) @@ -646,6 +664,18 @@ def setup(self, stage=None): f"does not correspond to the cached one ({self.prepared_data['protocol']})" ) + # checks that the task current hyperparameters matches the cached ones + for param_name, param_value in self.hparams.items(): + if param_name not in self.prepared_data["task-parameters"].dtype.names: + continue + cached_value = 
self.prepared_data["task-parameters"][param_name]
+            if param_value != cached_value:
+                warnings.warn(
+                    f"Value specified for the task hyperparameter {param_name} differs from the one in the cached data. "
+                    f"Current value = {param_value}, cached value = {cached_value}. "
+                    "You may need to create a new cache with the new value for this hyperparameter.",
+                )
+
     @property
     def automatic_optimization(self) -> bool:
         return self.model.automatic_optimization
@@ -878,3 +908,83 @@ def val_monitor(self):
 
         name, metric = next(iter(self.metric.items()))
         return name, "max" if metric.higher_is_better else "min"
+
+    def on_save_checkpoint(self, checkpoint):
+        """Record in `checkpoint` what is needed to re-create this task.
+
+        Fills `checkpoint["pyannote.audio"]["task"]` with the task module and
+        class names, its hyper-parameters, and descriptions of its augmentation
+        and metric(s). Each description is a {"module", "class", "kwargs"} dict.
+
+        Parameters
+        ----------
+        checkpoint : dict
+            Checkpoint being saved; its "pyannote.audio" entry must already exist.
+        """
+
+        checkpoint["pyannote.audio"]["task"] = {
+            "module": self.__class__.__module__,
+            "class": self.__class__.__name__,
+            "hyper_parameters": self.hparams,
+        }
+
+        def serialize_object(obj: Any) -> Dict:
+            """Describe `obj` by its module, class and recoverable init arguments."""
+            serialized_obj = {
+                "module": obj.__class__.__module__,
+                "class": obj.__class__.__name__,
+                "kwargs": {},
+            }
+
+            # sentinel: tells an attribute that is absent from one that is None
+            missing = object()
+
+            for name, param in inspect.signature(obj.__init__).parameters.items():
+                # *args / **kwargs catch-alls cannot be passed back by name
+                if param.kind in (
+                    inspect.Parameter.VAR_POSITIONAL,
+                    inspect.Parameter.VAR_KEYWORD,
+                ):
+                    continue
+
+                # an argument is recoverable only when kept as a same-named attribute
+                param_value = getattr(obj, name, missing)
+                if isinstance(
+                    param_value, (bool, float, int, list, dict, str, type(None))
+                ):
+                    serialized_obj["kwargs"][name] = param_value
+                else:
+                    msg = f"Cannot serialize {obj.__class__.__name__}.{name}. This parameter will not be saved in model checkpoint."
+                    warnings.warn(msg, RuntimeWarning)
+
+            return serialized_obj
+
+        # save augmentation:
+        if not self.augmentation:
+            checkpoint["pyannote.audio"]["task"]["augmentation"] = None
+        elif isinstance(self.augmentation, BaseWaveformTransform):
+            checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_object(
+                self.augmentation
+            )
+        elif isinstance(self.augmentation, BaseCompose):
+            checkpoint["pyannote.audio"]["task"]["augmentation"] = []
+            for augmentation in self.augmentation.transforms:
+                checkpoint["pyannote.audio"]["task"]["augmentation"].append(
+                    serialize_object(augmentation)
+                )
+
+        # save metrics:
+        if isinstance(self.metric, Metric):
+            # same mapping format as for collections, so that every stored entry
+            # exposes "module", "class" and "kwargs" keys
+            checkpoint["pyannote.audio"]["task"]["metrics"] = [
+                serialize_object(self.metric)
+            ]
+        elif isinstance(self.metric, MetricCollection):
+            checkpoint["pyannote.audio"]["task"]["metrics"] = []
+            for metric in self.metric.values():
+                checkpoint["pyannote.audio"]["task"]["metrics"].append(
+                    serialize_object(metric)
+                )
+        else:
+            checkpoint["pyannote.audio"]["task"]["metrics"] = None
diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py
index cb6401e2b..5bc2f9f06 100644
--- a/pyannote/audio/tasks/embedding/arcface.py
+++ b/pyannote/audio/tasks/embedding/arcface.py
@@ -111,6 +111,8 @@ def __init__(
             metric=metric,
         )
 
+        self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])
+
     def setup_loss_func(self):
         _, embedding_size = self.model(self.model.example_input_array).shape
 
diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py
index 9184121c4..7112c2c23 100644
--- a/pyannote/audio/tasks/segmentation/multilabel.py
+++ b/pyannote/audio/tasks/segmentation/multilabel.py
@@ -125,6 +125,8 @@ def __init__(
         self.weight = weight
         self.classes = classes
 
+        self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])
+
         # task specification depends on the data: we do not know in advance
which # classes should be detected. therefore, we postpone the definition of # specifications to setup() diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 89d299a8d..0503fc33a 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -142,6 +142,9 @@ def __init__( self.balance = balance self.weight = weight + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk for overlapped speech detection diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index cff3d93fd..cb2a05156 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -165,6 +165,8 @@ def __init__( self.balance = balance self.weight = weight + self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"]) + def setup(self, stage=None): super().setup(stage) diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index e52613aeb..57defa88c 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -123,6 +123,8 @@ def __init__( ], ) + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk for voice activity detection diff --git a/pyannote/audio/tasks/separation/PixIT.py b/pyannote/audio/tasks/separation/PixIT.py index 88c3495ee..362e146e9 100644 --- a/pyannote/audio/tasks/separation/PixIT.py +++ b/pyannote/audio/tasks/separation/PixIT.py @@ -221,6 +221,8 @@ 
def __init__( self.separation_loss_weight = separation_loss_weight self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"]) + def setup(self, stage=None): super().setup(stage)