From 329abd934a781dd9a6014fa3eb077e19e7f02cec Mon Sep 17 00:00:00 2001 From: Simon Larsen Date: Mon, 16 Dec 2024 15:44:56 +0100 Subject: [PATCH] Replace Ignite with custom event system (#5) * Remove Torch Ignite, replace with own engine/event mechanism * Add GAN pipeline and example project * Update README.md --- .gitignore | 25 +- MANIFEST.in | 4 +- README.md | 7 +- examples/gan/configs/example.json | 11 +- examples/gan/datasets/example.py | 8 +- examples/gan/models/blocks.py | 49 --- examples/gan/models/discriminator.py | 12 +- examples/gan/models/upscaler.py | 57 ++- examples/gan/train.py | 24 +- examples/supervised/configs/example.json | 10 +- examples/supervised/datasets/example.py | 8 +- .../models/{example.py => upscaler.py} | 46 +-- examples/supervised/service.py | 43 --- examples/supervised/train.py | 29 +- frogbox/__init__.py | 7 +- frogbox/callbacks/__init__.py | 9 +- frogbox/callbacks/callback.py | 7 + frogbox/callbacks/ema.py | 77 ---- frogbox/callbacks/image_logger.py | 330 ++++++++---------- frogbox/cli.py | 266 +------------- frogbox/config.py | 174 +++++---- frogbox/data/Dockerfile | 28 -- frogbox/data/service.py | 18 - frogbox/data/train_gan.py | 42 --- frogbox/engines/engine.py | 154 ++++++++ frogbox/engines/events.py | 64 ++++ frogbox/engines/gan.py | 231 ++++++------ frogbox/engines/supervised.py | 237 +++++-------- frogbox/handlers/__init__.py | 0 frogbox/handlers/checkpoint.py | 165 +++++++++ frogbox/handlers/composite_loss_logger.py | 30 ++ frogbox/handlers/metric_logger.py | 43 +++ frogbox/handlers/output_logger.py | 21 ++ frogbox/pipelines/__init__.py | 1 + frogbox/pipelines/gan.py | 310 ++++++---------- frogbox/pipelines/logger.py | 102 ------ .../pipelines/{common.py => lr_scheduler.py} | 29 +- frogbox/pipelines/pipeline.py | 237 ++++++------- frogbox/pipelines/save_handler.py | 77 ---- frogbox/pipelines/supervised.py | 206 +++++------ frogbox/service.py | 68 ---- frogbox/utils.py | 41 +-- mypy.ini | 3 - requirements-dev.txt | 10 +- requirements.txt | 19 +- 45 files changed, 1390 insertions(+), 1949 deletions(-) delete mode 100644 examples/gan/models/blocks.py rename examples/supervised/models/{example.py => upscaler.py} (63%) delete mode 100644 examples/supervised/service.py create mode 100644 frogbox/callbacks/callback.py delete mode 100644 frogbox/callbacks/ema.py delete mode 100644 frogbox/data/Dockerfile delete mode 100644 frogbox/data/service.py delete mode 100644 frogbox/data/train_gan.py create mode 100644 frogbox/engines/engine.py create mode 100644 frogbox/engines/events.py create mode 100644 frogbox/handlers/__init__.py create mode 100644 frogbox/handlers/checkpoint.py create mode 100644 frogbox/handlers/composite_loss_logger.py create mode 100644 frogbox/handlers/metric_logger.py create mode 100644 frogbox/handlers/output_logger.py delete mode 100644 frogbox/pipelines/logger.py rename frogbox/pipelines/{common.py => lr_scheduler.py} (65%) delete mode 100644 frogbox/pipelines/save_handler.py delete mode 100644 frogbox/service.py diff --git a/.gitignore b/.gitignore index 6fbec81..0035888 100644 --- a/.gitignore +++ b/.gitignore @@ -106,8 +106,10 @@ ipython_config.py #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. -# https://pdm.fming.dev/#use-with-ide +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml +.pdm-python +.pdm-build/ # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ @@ -159,22 +161,13 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -# pdoc -docs/*.html -docs/*.js -docs/frogbox/ - # Virtual environment symlink venv -# Training data -wandb/ -checkpoints/ - # Example projects -examples/supervised/wandb/ -examples/supervised/checkpoints/ -examples/supervised/data/ -examples/gan/wandb/ -examples/gan/checkpoints/ -examples/gan/data/ \ No newline at end of file +examples/supervised/data +examples/supervised/checkpoints +examples/supervised/wandb +examples/gan/data +examples/gan/checkpoints +examples/gan/wandb diff --git a/MANIFEST.in b/MANIFEST.in index ff3c43f..7c4eff4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1 @@ -include frogbox/data/*.json -include frogbox/data/*.py -include frogbox/data/Dockerfile +include frogbox/data/*.py \ No newline at end of file diff --git a/README.md b/README.md index d9194bd..c598edb 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,20 @@ frogbox -Frogbox is an opinionated machine learning framework for PyTorch built for rapid prototyping and research. +Frogbox is an opinionated PyTorch machine learning framework built for rapid prototyping and research. ## Features -* Built around [Torch Ignite](https://pytorch-ignite.ai) and [Accelerate](https://huggingface.co/docs/accelerate/index) to support automatic mixed precision (AMP) and distributed training. * Experiments are defined using JSON files and support [jinja2](https://jinja.palletsprojects.com) templates. +* Flexible event system inspired by [Ignite](https://pytorch.org/ignite). * Automatic experiment tracking. Currently only [Weights & Biases](https://wandb.ai/) is supported with other platforms planned. * CLI tool for easy project management. Just type `frogbox project new -t supervised` to get started. +* Integrates [Accelerate](https://huggingface.co/docs/accelerate/index) to support automatic mixed precision (AMP) and distributed training. 
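A minimal sketch of how the event system in this change can be used to schedule work: the `Event` type accepts `every`/`first`/`last` filters and events can be combined with `|`. The callback name `my_callback` below is hypothetical; `Event` and `install_callback` come from this patch.

```python
from frogbox import Event

# trigger every 500 training iterations
every_500 = Event("iteration_completed", every=500)

# trigger at the end of the first five epochs only
early_epochs = Event("epoch_completed", first=0, last=4)

# events compose into an EventList with `|`
either = every_500 | early_epochs

# installed on a pipeline, e.g.:
# pipeline.install_callback(either, my_callback)  # my_callback is hypothetical
```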
## Installation ```sh -pip install git+https://SimonLarsen@github.com/SimonLarsen/frogbox.git@v0.3.3 +pip install git+https://SimonLarsen@github.com/SimonLarsen/frogbox.git@v0.5.0 ``` ## Getting started diff --git a/examples/gan/configs/example.json b/examples/gan/configs/example.json index de8b9a8..653176b 100644 --- a/examples/gan/configs/example.json +++ b/examples/gan/configs/example.json @@ -1,19 +1,17 @@ { "type": "gan", "project": "frogbox-example", - "amp": true, "batch_size": 32, "loader_workers": 4, "max_epochs": 16, + "log_interval": "epoch_completed", "checkpoints": [ { "metric": "PSNR", "mode": "max" } ], - "log_interval": "epoch_completed", "datasets": {% include 'datasets.json' %}, - "disc_update_interval": 2, "model": { "class_name": "models.upscaler.Upscaler", "params": { @@ -30,8 +28,7 @@ "params": { "in_channels": 3, "hidden_channels": 32, - "num_blocks": 8, - "activation": "silu" + "num_blocks": 8 } }, "losses": { @@ -52,11 +49,11 @@ }, "metrics": { "SSIM": { - "class_name": "ignite.metrics.SSIM", + "class_name": "torchmetrics.image.StructuralSimilarityIndexMeasure", "params": {"data_range": 1.0} }, "PSNR": { - "class_name": "ignite.metrics.PSNR", + "class_name": "torchmetrics.image.PeakSignalNoiseRatio", "params": {"data_range": 1.0} } }, diff --git a/examples/gan/datasets/example.py b/examples/gan/datasets/example.py index c16b5b8..1b373d0 100644 --- a/examples/gan/datasets/example.py +++ b/examples/gan/datasets/example.py @@ -65,13 +65,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: image, _ = self.data[idx] image = pil_to_tensor(image) / 255 - X = resize(image, size=16, antialias=True) + x = resize(image, size=16, antialias=True) y = image if self.do_augment: - X = self.augment(X) + x = self.augment(x) if self.do_normalize: - X = self.normalize(X) + x = self.normalize(x) - return X, y + return x, y diff --git a/examples/gan/models/blocks.py b/examples/gan/models/blocks.py deleted file mode 100644 index 1b7ed47..0000000 --- a/examples/gan/models/blocks.py +++ /dev/null @@ -1,49 +0,0 @@ -from torch import nn - - -def get_activation(name: str) -> nn.Module: - name = name.lower() - if name == "relu": - return nn.ReLU() - elif name == "gelu": - return nn.GELU() - elif name == "sigmoid": - return nn.Sigmoid() - elif name in ("swish", "silu"): - return nn.SiLU() - elif name == "mish": - return nn.Mish() - else: - raise ValueError(f"Unsupported activation function '{name}'.") - - -class ResidualBlock(nn.Module): - def __init__( - self, - channels: int, - norm_groups: int = 4, - norm_eps: float = 1e-5, - activation: str = "relu", - ): - super().__init__() - - self.act = get_activation(activation) - - self.norm1 = nn.GroupNorm(norm_groups, channels, norm_eps) - self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1) - - self.norm2 = nn.GroupNorm(norm_groups, channels, norm_eps) - self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1) - - def forward(self, x): - identity = x - - x = self.norm1(x) - x = self.act(x) - x = self.conv1(x) - - x = self.norm2(x) - x = self.act(x) - x = self.conv2(x) - - return x + identity diff --git a/examples/gan/models/discriminator.py b/examples/gan/models/discriminator.py index 206219e..5adde8a 100644 --- a/examples/gan/models/discriminator.py +++ b/examples/gan/models/discriminator.py @@ -1,17 +1,12 @@ from torch import nn from torch.nn.utils import spectral_norm -from .blocks import get_activation class SNResidualBlock(nn.Module): - def __init__( - self, - channels: int, - activation: str = "relu", - ): + def 
__init__(self, channels: int): super().__init__() - self.act = get_activation(activation) + self.act = nn.GELU() self.conv1 = spectral_norm( nn.Conv2d(channels, channels, 3, 1, 1, bias=False) ) @@ -37,7 +32,6 @@ def __init__( in_channels: int = 3, hidden_channels: int = 32, num_blocks: int = 8, - activation: str = "silu", ): super().__init__() @@ -45,7 +39,7 @@ def __init__( blocks = [] for _ in range(num_blocks): - blocks.append(SNResidualBlock(hidden_channels, activation)) + blocks.append(SNResidualBlock(hidden_channels)) self.blocks = nn.Sequential(*blocks) self.conv_out = nn.Conv2d(hidden_channels, 1, 3, 1, 1) diff --git a/examples/gan/models/upscaler.py b/examples/gan/models/upscaler.py index 6366576..00fcaf9 100644 --- a/examples/gan/models/upscaler.py +++ b/examples/gan/models/upscaler.py @@ -1,6 +1,36 @@ import math from torch import nn -from .blocks import get_activation, ResidualBlock + + +class ResidualBlock(nn.Module): + def __init__( + self, + channels: int, + norm_groups: int = 4, + norm_eps: float = 1e-5, + ): + super().__init__() + + self.act = nn.GELU() + + self.norm1 = nn.GroupNorm(norm_groups, channels, norm_eps) + self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1) + + self.norm2 = nn.GroupNorm(norm_groups, channels, norm_eps) + self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1) + + def forward(self, x): + identity = x + + x = self.norm1(x) + x = self.act(x) + x = self.conv1(x) + + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + + return x + identity class Upscaler(nn.Module): @@ -12,23 +42,19 @@ def __init__( hidden_channels: int = 32, num_layers: int = 4, norm_groups: int = 4, - activation: str = "gelu", ): super().__init__() self.conv_in = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1) - features = [] + self.blocks = nn.ModuleList() for _ in range(num_layers): - features.append( + self.blocks.append( ResidualBlock( channels=hidden_channels, norm_groups=norm_groups, - activation=activation, ) ) - features.append(nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1)) - self.features = nn.Sequential(*features) upsample = [] for _ in range(int(math.log2(scale_factor))): @@ -36,18 +62,21 @@ def __init__( upsample.append( nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1) ) - upsample.append(get_activation(activation)) + upsample.append(nn.GELU()) self.upsample = nn.Sequential(*upsample) self.conv_out = nn.Sequential( nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1), - get_activation(activation), + nn.GELU(), nn.Conv2d(hidden_channels, out_channels, 3, 1, 1), ) def forward(self, x): - x = self.conv_in(x) - x = self.features(x) - x = self.upsample(x) - x = self.conv_out(x) - return x + h = self.conv_in(x) + for block in self.blocks: + h = block(h) + h + h = self.upsample(h) + h = self.conv_out(h) + + x = nn.functional.interpolate(x, h.shape[-2:], mode="bilinear") + return nn.functional.sigmoid(x + h) diff --git a/examples/gan/train.py b/examples/gan/train.py index 6061cf4..fc4a1e4 100644 --- a/examples/gan/train.py +++ b/examples/gan/train.py @@ -1,8 +1,8 @@ -from typing import Optional, Sequence +from typing import cast, Optional, Sequence from pathlib import Path import argparse -from frogbox import read_json_config, GANPipeline, Events -from frogbox.callbacks import create_image_logger +from frogbox import read_json_config, GANPipeline, GANConfig +from frogbox.callbacks import ImageLogger def parse_arguments( @@ -28,7 +28,7 @@ def parse_arguments( if __name__ == "__main__": args = parse_arguments() - config = read_json_config(args.config) + config = 
cast(GANConfig, read_json_config(args.config)) pipeline = GANPipeline( config=config, @@ -40,16 +40,12 @@ def parse_arguments( group=args.group, ) - dataset_params = config.datasets["test"].params - image_logger = create_image_logger( - split="test", - normalize_mean=dataset_params["normalize_mean"], - normalize_std=dataset_params["normalize_std"], - denormalize_input=dataset_params["do_normalize"], - ) - pipeline.install_callback( - event=Events.EPOCH_COMPLETED, - callback=image_logger, + ds_conf = config.datasets["train"].params + image_logger = ImageLogger( + denormalize_input=True, + normalize_mean=ds_conf["normalize_mean"], + normalize_std=ds_conf["normalize_std"], ) + pipeline.install_callback(pipeline.log_interval, image_logger) pipeline.run() diff --git a/examples/supervised/configs/example.json b/examples/supervised/configs/example.json index 56091a2..1bc7fa4 100644 --- a/examples/supervised/configs/example.json +++ b/examples/supervised/configs/example.json @@ -1,7 +1,7 @@ { "type": "supervised", "project": "frogbox-example", - "batch_size": 16, + "batch_size": 32, "loader_workers": 4, "max_epochs": 16, "log_interval": "epoch_completed", @@ -10,7 +10,7 @@ "n_saved": 3, "interval": { "event": "iteration_completed", - "every": 500 + "every": 1000 } }, { @@ -21,7 +21,7 @@ ], "datasets": {% include 'datasets.json' %}, "model": { - "class_name": "models.example.ExampleModel", + "class_name": "models.upscaler.Upscaler", "params": { "scale_factor": 2, "in_channels": 3, @@ -39,11 +39,11 @@ }, "metrics": { "SSIM": { - "class_name": "ignite.metrics.SSIM", + "class_name": "torchmetrics.image.StructuralSimilarityIndexMeasure", "params": {"data_range": 1.0} }, "PSNR": { - "class_name": "ignite.metrics.PSNR", + "class_name": "torchmetrics.image.PeakSignalNoiseRatio", "params": {"data_range": 1.0} } }, diff --git a/examples/supervised/datasets/example.py b/examples/supervised/datasets/example.py index c16b5b8..1b373d0 100644 --- a/examples/supervised/datasets/example.py +++ b/examples/supervised/datasets/example.py @@ -65,13 +65,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: image, _ = self.data[idx] image = pil_to_tensor(image) / 255 - X = resize(image, size=16, antialias=True) + x = resize(image, size=16, antialias=True) y = image if self.do_augment: - X = self.augment(X) + x = self.augment(x) if self.do_normalize: - X = self.normalize(X) + x = self.normalize(x) - return X, y + return x, y diff --git a/examples/supervised/models/example.py b/examples/supervised/models/upscaler.py similarity index 63% rename from examples/supervised/models/example.py rename to examples/supervised/models/upscaler.py index 7c804ab..00fcaf9 100644 --- a/examples/supervised/models/example.py +++ b/examples/supervised/models/upscaler.py @@ -2,33 +2,16 @@ from torch import nn -def get_activation(name: str) -> nn.Module: - name = name.lower() - if name == "relu": - return nn.ReLU() - elif name == "gelu": - return nn.GELU() - elif name == "sigmoid": - return nn.Sigmoid() - elif name in ("swish", "silu"): - return nn.SiLU() - elif name == "mish": - return nn.Mish() - else: - raise ValueError(f"Unsupported activation function '{name}'.") - - class ResidualBlock(nn.Module): def __init__( self, channels: int, norm_groups: int = 4, norm_eps: float = 1e-5, - activation: str = "relu", ): super().__init__() - self.act = get_activation(activation) + self.act = nn.GELU() self.norm1 = nn.GroupNorm(norm_groups, channels, norm_eps) self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1) @@ -50,7 +33,7 @@ 
def forward(self, x): return x + identity -class ExampleModel(nn.Module): +class Upscaler(nn.Module): def __init__( self, scale_factor: int = 2, @@ -59,23 +42,19 @@ def __init__( hidden_channels: int = 32, num_layers: int = 4, norm_groups: int = 4, - activation: str = "gelu", ): super().__init__() self.conv_in = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1) - features = [] + self.blocks = nn.ModuleList() for _ in range(num_layers): - features.append( + self.blocks.append( ResidualBlock( channels=hidden_channels, norm_groups=norm_groups, - activation=activation, ) ) - features.append(nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1)) - self.features = nn.Sequential(*features) upsample = [] for _ in range(int(math.log2(scale_factor))): @@ -83,18 +62,21 @@ def __init__( upsample.append( nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1) ) - upsample.append(get_activation(activation)) + upsample.append(nn.GELU()) self.upsample = nn.Sequential(*upsample) self.conv_out = nn.Sequential( nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1), - get_activation(activation), + nn.GELU(), nn.Conv2d(hidden_channels, out_channels, 3, 1, 1), ) def forward(self, x): - x = self.conv_in(x) - x = self.features(x) - x = self.upsample(x) - x = self.conv_out(x) - return x + h = self.conv_in(x) + for block in self.blocks: + h = block(h) + h + h = self.upsample(h) + h = self.conv_out(h) + + x = nn.functional.interpolate(x, h.shape[-2:], mode="bilinear") + return nn.functional.sigmoid(x + h) diff --git a/examples/supervised/service.py b/examples/supervised/service.py deleted file mode 100644 index 87824c9..0000000 --- a/examples/supervised/service.py +++ /dev/null @@ -1,43 +0,0 @@ -from pydantic import BaseModel -from kornia.augmentation import Normalize -import torch -from torchvision.io import read_image, ImageReadMode, write_jpeg -from torchvision.transforms.functional import convert_image_dtype -from frogbox.service import BaseService - - -class Request(BaseModel): - input_path: str - output_path: str - quality: int = 95 - - -class Response(BaseModel): - output_path: str - - -class SRService(BaseService): - def inference(self, request: Request): - model = self.models["sr"] - config = self.configs["sr"] - - ds_conf = config.datasets["train"].params - normalize = Normalize( - mean=ds_conf["normalize_mean"], - std=ds_conf["normalize_std"], - keepdim=True, - ) - - image = read_image(request.input_path, ImageReadMode.RGB) / 255.0 - image = normalize(image) - - with torch.inference_mode(): - pred = model(image[None].to(self.device)) - pred = pred.cpu()[0].clamp(0.0, 1.0) - - output = convert_image_dtype(pred, torch.uint8) - write_jpeg(output, request.output_path, quality=request.quality) - return Response(output_path=request.output_path) - - -app = SRService(Request, Response) diff --git a/examples/supervised/train.py b/examples/supervised/train.py index 7226b10..d2fa0cc 100644 --- a/examples/supervised/train.py +++ b/examples/supervised/train.py @@ -1,8 +1,8 @@ -from typing import Optional, Sequence +from typing import cast, Optional, Sequence from pathlib import Path import argparse -from frogbox import read_json_config, SupervisedPipeline, Events -from frogbox.callbacks import create_image_logger +from frogbox import read_json_config, SupervisedPipeline, SupervisedConfig +from frogbox.callbacks import ImageLogger def parse_arguments( @@ -17,18 +17,18 @@ def parse_arguments( parser.add_argument( "--logging", type=str, - choices=["online", "offline", "disabled"], + choices=["online", "offline"], 
default="online", ) + parser.add_argument("--wandb-id", type=str, required=False) parser.add_argument("--tags", type=str, nargs="+") parser.add_argument("--group", type=str) - parser.add_argument("--wandb-id", type=str, required=False) return parser.parse_args(args) if __name__ == "__main__": args = parse_arguments() - config = read_json_config(args.config) + config = cast(SupervisedConfig, read_json_config(args.config)) pipeline = SupervisedPipeline( config=config, @@ -40,17 +40,12 @@ def parse_arguments( group=args.group, ) - dataset_params = config.datasets["test"].params - image_logger = create_image_logger( - split="test", - normalize_mean=dataset_params["normalize_mean"], - normalize_std=dataset_params["normalize_std"], - denormalize_input=dataset_params["do_normalize"], - ) - - pipeline.install_callback( - event=Events.EPOCH_COMPLETED, - callback=image_logger, + ds_conf = config.datasets["train"].params + image_logger = ImageLogger( + denormalize_input=True, + normalize_mean=ds_conf["normalize_mean"], + normalize_std=ds_conf["normalize_std"], ) + pipeline.install_callback(pipeline.log_interval, image_logger) pipeline.run() diff --git a/frogbox/__init__.py b/frogbox/__init__.py index 6f7c08c..a78fab2 100644 --- a/frogbox/__init__.py +++ b/frogbox/__init__.py @@ -2,10 +2,11 @@ .. include:: ./intro.md """ -__version__ = "0.4.0.dev2" +__version__ = "0.5.0" +from accelerate.utils import set_seed # noqa: F401 +from .config import read_json_config, SupervisedConfig, GANConfig # noqa: F401 +from .engines.events import Event # noqa: F401 from .pipelines.supervised import SupervisedPipeline # noqa: F401 from .pipelines.gan import GANPipeline # noqa: F401 -from .config import read_json_config, SupervisedConfig, GANConfig # noqa: F401 from .utils import load_model_checkpoint # noqa: F401 -from ignite.engine import Events # noqa: F401 diff --git a/frogbox/callbacks/__init__.py b/frogbox/callbacks/__init__.py index 23d056b..9195715 100644 --- a/frogbox/callbacks/__init__.py +++ b/frogbox/callbacks/__init__.py @@ -1,23 +1,22 @@ """ # Custom callbacks -Custom callbacks can be created by implementing a function that accepts the pipeline as its only argument. +Custom callbacks can be created by implementing a function that accepts the pipeline as its first argument. For instance, in the following example a callback is added to unfreeze the model's encoder after 20 epochs: ```python -from frogbox import Events +from frogbox import Event def unfreeze_encoder(pipeline) model = pipeline.model model.encoder.requires_grad_(True) pipeline.install_callback( - event=Events.EPOCH_STARTED(once=20), + event=Event("epoch_started", first=20, last=20), callback=unfreeze_encoder, ) ``` """ # noqa: E501 -from .image_logger import create_image_logger # noqa: F401 -from .ema import EMACallback # noqa: F401 +from .image_logger import ImageLogger # noqa: F401 diff --git a/frogbox/callbacks/callback.py b/frogbox/callbacks/callback.py new file mode 100644 index 0000000..9ab9aff --- /dev/null +++ b/frogbox/callbacks/callback.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod +from ..pipelines.pipeline import Pipeline + + +class Callback(ABC): + @abstractmethod + def __call__(self, pipeline: Pipeline) -> None: ... 
diff --git a/frogbox/callbacks/ema.py b/frogbox/callbacks/ema.py deleted file mode 100644 index eb10391..0000000 --- a/frogbox/callbacks/ema.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional, Union, Dict, Any -import copy -import torch -from ..pipelines.supervised import SupervisedPipeline -from ..pipelines.gan import GANPipeline - - -PipelineType = Union[SupervisedPipeline, GANPipeline] - - -class EMACallback: - def __init__( - self, - pipeline: PipelineType, - decay: float, - handle_buffers: Optional[str] = "copy", - pipeline_model_name: str = "model", - ): - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1.") - if handle_buffers is not None and handle_buffers not in ( - "update", - "copy", - ): - raise ValueError("`handle_buffers` must be in ('update', 'copy').") - - self.decay = decay - self.handle_buffers = handle_buffers - self.pipeline_model_name = pipeline_model_name - - model = self._get_pipeline_model(pipeline) - self.ema_model = copy.deepcopy(model) - for p in self.ema_model.parameters(): - p.detach_() - self.ema_model.eval() - - def _get_pipeline_model(self, pipeline: PipelineType) -> torch.nn.Module: - model = getattr(pipeline, self.pipeline_model_name) - return model - - def __call__(self, pipeline: PipelineType) -> None: - model = self._get_pipeline_model(pipeline) - for ema_p, model_p in zip( - self.ema_model.parameters(), model.parameters() - ): - ema_p.mul_(self.decay).add_(model_p.data, alpha=(1.0 - self.decay)) - - if self.handle_buffers == "update": - for ema_b, model_b in zip( - self.ema_model.buffers(), model.buffers() - ): - try: - ema_b.mul_(self.decay).add_( - model_b.data, alpha=(1 - self.decay) - ) - except RuntimeError: - ema_b.data = model_b.data - elif self.handle_buffers == "copy": - for ema_b, model_b in zip( - self.ema_model.buffers(), model.buffers() - ): - ema_b.data = model_b.data - - def state_dict(self) -> Dict[str, Any]: - return { - "decay": self.decay, - "handle_buffers": self.handle_buffers, - "pipeline_model_name": self.pipeline_model_name, - "ema_model": self.ema_model.state_dict(), - } - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - state_dict = copy.deepcopy(state_dict) - self.decay = state_dict["decay"] - self.handle_buffers = state_dict["handle_buffers"] - self.pipeline_model_name = state_dict["pipeline_model_name"] - self.ema_model.load_state_dict(state_dict["ema_model"]) diff --git a/frogbox/callbacks/image_logger.py b/frogbox/callbacks/image_logger.py index 3e801a3..4847366 100644 --- a/frogbox/callbacks/image_logger.py +++ b/frogbox/callbacks/image_logger.py @@ -1,222 +1,184 @@ -""" -# Logging images - -The simplest way to log images during training is to create an callback with -`frogbox.callbacks.image_logger.create_image_logger`: - -```python -from frogbox import Events -from frogbox.callbacks import create_image_logger - -pipeline.install_callback( - event=Events.EPOCH_COMPLETED, - callback=create_image_logger(), -) -``` - -Images can automatically be denormalized by setting `denormalize_input`/`denormalize_output` -and providing the mean and standard deviation used for normalization. 
- -For instance, if input images are normalized with ImageNet parameters and outputs are in [0, 1]: - -```python -image_logger = create_image_logger( - normalize_mean=[0.485, 0.456, 0.406], - normalize_std=[0.229, 0.224, 0.225], - denormalize_input=True, -) -``` - -More advanced transformations can be made by overriding `input_transform`, `model_transform`, or `output_transform`: - -```python -from torchvision.transforms.functional import hflip - -def flip_input(x, y, y_pred): - x = hflip(x) - return x, y_pred, y - -image_logger = create_image_logger( - output_transform=flip_input, -) -``` -""" # noqa: E501 - -from typing import Callable, Any, Sequence, Optional, Union +from typing import Union, Sequence, Callable, Any, Optional import torch from torchvision.transforms.functional import ( - InterpolationMode, - resize, center_crop, + resize, + InterpolationMode, to_pil_image, ) from torchvision.utils import make_grid -from ignite.utils import convert_tensor -from kornia.enhance import Denormalize -import wandb import tqdm -from ..pipelines.supervised import SupervisedPipeline -from ..pipelines.gan import GANPipeline - - -def create_image_logger( - split: str = "test", - log_label: str = "test/images", - resize_to_fit: bool = True, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - antialias: bool = True, - num_cols: Optional[int] = None, - denormalize_input: bool = False, - denormalize_target: bool = False, - normalize_mean: Sequence[float] = (0.0, 0.0, 0.0), - normalize_std: Sequence[float] = (1.0, 1.0, 1.0), - progress: bool = False, - input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), - model_transform: Callable[[Any], Any] = lambda output: output, - output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: ( - x, - y_pred, - y, - ), -): - """ - Create image logger callback. - - Parameters - ---------- - split : str - Dataset split to evaluate on. Defaults to "test". - log_label : str - Label to log images under in Weights & Biases. - resize_to_fit : bool - If `true` smaller images are resized to fit canvas. - interpolation : torchvision.transforms.functional.InterpolationMode - Interpolation to use for resizing images. - antialias : bool - If `true` antialiasing is used when resizing images. - num_cols : int - Number of columns in image grid. - Defaults to number of elements in returned tuple. - denormalize_input : bool - If `true` input images (x) a denormalized after inference. - denormalize_target : bool - If `true` target images (y and y_pred) are denormalized after inference. - normalize_mean : (float, float, float) - RGB mean values used in image normalization. - normalize_std : (float, float, float) - RGB std.dev. values used in image normalization. - progress : bool - Show progress bar. - input_transform : Callable - Function that receives tensors `y` and `y` and outputs tuple of - tensors `(x, y)`. - model_transform : Callable - Function that receives the output from the model during evaluation - and converts it into the predictions: - `y_pred = model_transform(model(x))`. - output_transform : Callable - Function that receives `x`, `y`, `y_pred` and returns tensors to be - logged as images. Default is returning `(x, y_pred, y)`. 
- """ # noqa: E501 - denormalize = Denormalize( - torch.as_tensor(normalize_mean), - torch.as_tensor(normalize_std), - ) - - def _callback(pipeline: Union[SupervisedPipeline, GANPipeline]): +import wandb +from .callback import Callback +from ..pipelines.pipeline import Pipeline + + +class ImageLogger(Callback): + """Callback for logging images.""" + + def __init__( + self, + split: str = "test", + log_label: str = "test/images", + resize_to_fit: bool = True, + interpolation: Union[str, InterpolationMode] = "nearest", + num_cols: Optional[int] = None, + denormalize_input: bool = False, + denormalize_target: bool = False, + normalize_mean: Sequence[float] = (0.0, 0.0, 0.0), + normalize_std: Sequence[float] = (1.0, 1.0, 1.0), + show_progress: bool = False, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), + model_transform: Callable[[Any], Any] = lambda output: output, + output_transform: Callable[ + [Any, Any, Any], Any + ] = lambda x, y, y_pred: (x, y_pred, y), + ): + """ + Create ImageLogger. + + Parameters + ---------- + split : str + Dataset split to evaluate on. Defaults to "test". + log_label : str + Label to log images under in Weights & Biases. + resize_to_fit : bool + If `true` smaller images are resized to fit canvas. + interpolation : torchvision.transforms.functional.InterpolationMode + Interpolation to use for resizing images. + num_cols : int + Number of columns in image grid. + Defaults to number of elements in returned tuple. + denormalize_input : bool + If `true` input images (x) a denormalized after inference. + denormalize_target : bool + If `true` target images (y and y_pred) are denormalized after + inference. + normalize_mean : (float, float, float) + RGB mean values used in image normalization. + normalize_std : (float, float, float) + RGB std.dev. values used in image normalization. + show_progress : bool + Show progress bar. + input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. + model_transform : Callable + Function that receives the output from the model during evaluation + and converts it into the predictions: + `y_pred = model_transform(model(x))`. + output_transform : Callable + Function that receives `x`, `y`, `y_pred` and returns tensors to be + logged as images. Default is returning `(x, y_pred, y)`. + """ + self._split = split + self._log_label = log_label + self._resize_to_fit = resize_to_fit + self._interpolation = InterpolationMode(interpolation) + self._num_cols = num_cols + self._denormalize_input = denormalize_input + self._denormalize_target = denormalize_target + self._normalize_mean = normalize_mean + self._normalize_std = normalize_std + self._show_progress = show_progress + self._input_transform = input_transform + self._model_transform = model_transform + self._output_transform = output_transform + + def _denormalize(self, x: torch.Tensor) -> torch.Tensor: + mean = torch.as_tensor( + self._normalize_mean, device=x.device, dtype=x.dtype + ).reshape(1, -1, 1, 1) + + std = torch.as_tensor( + self._normalize_std, device=x.device, dtype=x.dtype + ).reshape(1, -1, 1, 1) + + return (x * std) + mean + + def __call__(self, pipeline: Pipeline) -> None: + if not hasattr(pipeline, "model") or not hasattr(pipeline, "loaders"): + raise RuntimeError( + f"ImageLogger not compatible with pipeline {pipeline}." 
+ ) + model = pipeline.model loaders = pipeline.loaders + accelerator = pipeline.accelerator model.eval() - data_iter = iter(loaders[split]) - if progress: + data_iter = loaders[self._split] + if self._show_progress: data_iter = tqdm.tqdm( data_iter, desc="Images", ncols=80, leave=False, - total=len(loaders[split]), + total=len(data_iter), ) images = [] for batch in data_iter: x, y = batch - x, y = input_transform(x, y) + x, y = self._input_transform(x, y) with torch.inference_mode(): - y_pred = model_transform(model(x)) + y_pred = self._model_transform(model(x)) - x, y, y_pred = pipeline.gather_for_metrics((x, y, y_pred)) - x, y, y_pred = convert_tensor( # type: ignore - x=(x, y, y_pred), - device=torch.device("cpu"), - non_blocking=False, - ) + x, y, y_pred = accelerator.gather_for_metrics((x, y, y_pred)) + + x = torch.as_tensor(x, device=torch.device("cpu")) + y = torch.as_tensor(y, device=torch.device("cpu")) + y_pred = torch.as_tensor(y_pred, device=torch.device("cpu")) - if denormalize_input: - x = denormalize(x) - if denormalize_target: - y = denormalize(y) - y_pred = denormalize(y_pred) + if self._denormalize_input: + x = self._denormalize(x) + if self._denormalize_target: + y = self._denormalize(y) + y_pred = self._denormalize(y_pred) - output = output_transform(x, y, y_pred) + output = self._output_transform(x, y, y_pred) batch_sizes = [len(e) for e in output] assert all(s == batch_sizes[0] for s in batch_sizes) for i in range(batch_sizes[0]): - grid = _combine_test_images( - images=[e[i] for e in output], - resize_to_fit=resize_to_fit, - interpolation=interpolation, - antialias=antialias, - num_cols=num_cols, - ) + grid = self._combine_test_images([e[i] for e in output]) images.append(grid) wandb_images = [wandb.Image(to_pil_image(image)) for image in images] - pipeline.log({log_label: wandb_images}) - - return _callback - - -def _combine_test_images( - images: Sequence[torch.Tensor], - resize_to_fit: bool = True, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - antialias: bool = True, - num_cols: Optional[int] = None, -) -> torch.Tensor: - for image in images: - assert len(image.shape) == 3 - assert image.size(0) in (1, 3) - - max_h = max(image.size(1) for image in images) - max_w = max(image.size(2) for image in images) - - transformed = [] - for image in images: - C, H, W = image.shape - if H != max_h or W != max_w: - if resize_to_fit: - image = resize( - image, - size=(max_h, max_w), - interpolation=interpolation, - antialias=antialias, - ) - else: - image = center_crop(image, output_size=(max_h, max_w)) - if C == 1: - image = image.repeat((3, 1, 1)) - image = image.clamp(0.0, 1.0) - transformed.append(image) - - if len(transformed) == 1: - return transformed[0] - else: + pipeline.log({self._log_label: wandb_images}) + + def _combine_test_images( + self, images: Sequence[torch.Tensor] + ) -> torch.Tensor: + for image in images: + assert len(image.shape) == 3 + assert image.size(0) in (1, 3) + + max_h = max(image.size(1) for image in images) + max_w = max(image.size(2) for image in images) + + transformed = [] + for image in images: + c, h, w = image.shape + if (h, w) != (max_h, max_w): + if self._resize_to_fit: + image = resize( + image, + size=(max_h, max_w), + interpolation=self._interpolation, + ) + else: + image = center_crop(image, output_size=(max_h, max_w)) + if c == 1: + image = image.repeat((3, 1, 1)) + image = image.clamp(0.0, 1.0) + transformed.append(image) + return make_grid( tensor=transformed, normalize=False, - nrow=num_cols or 
len(transformed), + nrow=self._num_cols or len(transformed), ) diff --git a/frogbox/cli.py b/frogbox/cli.py index e912b0f..5383e23 100644 --- a/frogbox/cli.py +++ b/frogbox/cli.py @@ -1,12 +1,10 @@ """@private""" -from typing import Optional, Tuple, Sequence from pathlib import Path import os -import json import shutil +import importlib import subprocess -import importlib.resources import click @@ -65,7 +63,7 @@ def project(): "--type", "-t", "type_", - type=click.Choice(["supervised", "gan"]), + type=click.Choice(["supervised"]), default="supervised", help="Pipeline type.", ) @@ -124,26 +122,10 @@ def new_project(type_: str, dir_: Path, overwrite: bool = False): if type_ == "supervised": from .config import SupervisedConfig, ObjectDefinition - config = SupervisedConfig( + config_json = SupervisedConfig( type="supervised", project="example", - datasets={ - "train": ObjectDefinition( - class_name="datasets.example.ExampleDataset" - ), - "val": ObjectDefinition( - class_name="datasets.example.ExampleDataset" - ), - }, model=ObjectDefinition(class_name="models.example.ExampleModel"), - ) - config_json = config.model_dump_json(indent=2, exclude_none=True) - elif type_ == "gan": - from .config import GANConfig, ObjectDefinition - - config = GANConfig( - type="gan", - project="example", datasets={ "train": ObjectDefinition( class_name="datasets.example.ExampleDataset" @@ -152,12 +134,7 @@ def new_project(type_: str, dir_: Path, overwrite: bool = False): class_name="datasets.example.ExampleDataset" ), }, - model=ObjectDefinition(class_name="models.example.ExampleModel"), - disc_model=ObjectDefinition( - class_name="models.example.ExampleModel" - ), - ) - config_json = config.model_dump_json(indent=2, exclude_none=True) + ).model_dump_json(indent=4, exclude_none=True) else: raise RuntimeError(f"Unknown pipeline type {type_}.") @@ -166,240 +143,5 @@ def new_project(type_: str, dir_: Path, overwrite: bool = False): output_path.write_text(config_json) -@cli.group() -def service(): - """Manage service.""" - - -@service.command(name="new") -@click.option( - "--dir", - "-d", - "dir_", - type=click.Path( - exists=False, - file_okay=False, - dir_okay=True, - path_type=Path, - ), - default=Path("."), - help="Project root directory.", -) -@click.option( - "--overwrite", - is_flag=True, - help="Overwrite existing files if present.", -) -def new_service(dir_: Path, overwrite: bool = False): - """Create new service from template.""" - - template_inputs = [ - "service.py", - ] - - template_outputs = [ - dir_ / "service.py", - ] - - # Check if files already exist - if not overwrite: - for path in template_outputs: - if path.exists(): - raise RuntimeError( - f"File '{path}' already exists." - " Use flag --overwrite to overwrite." - ) - - # Create folders and copy template files - resource_files = importlib.resources.files("frogbox.data") - for input_resource, output_path in zip(template_inputs, template_outputs): - file_data = resource_files.joinpath(input_resource).read_text() - output_path.parent.mkdir(exist_ok=True, parents=True) - output_path.write_text(file_data) - - -@service.command(name="serve") -@click.option( - "--checkpoint", - "-c", - "checkpoints", - type=( - str, - click.Path( - exists=True, file_okay=True, dir_okay=False, path_type=Path - ), - ), - multiple=True, - help=( - "Add model checkpoint." - " Add multiple models by repeating this argument." 
- ), - metavar="NAME PATH", -) -@click.option( - "--device", - "-d", - type=str, - default="cpu", - help="CUDA device.", - show_default=True, -) -def serve(checkpoints: Sequence[Tuple[str, Path]], device: str): - """Serve service locally.""" - - import uvicorn - - checkpoints_env = {} - for name, path in checkpoints: - checkpoints_env[name] = str(path) - os.environ["CHECKPOINTS"] = json.dumps(checkpoints_env) - os.environ["DEVICE"] = device - - uvicorn.run("service:app", port=8000, app_dir=".") - - -@service.command(name="dockerfile") -@click.option( - "--checkpoint", - "-c", - "checkpoints", - type=( - str, - click.Path( - exists=True, file_okay=True, dir_okay=False, path_type=Path - ), - ), - multiple=True, - help=( - "Add model checkpoint." - " Add multiple models by repeating this argument." - ), - metavar="NAME PATH", -) -@click.option( - "--requirements", - "-r", - type=click.Path( - exists=True, file_okay=True, dir_okay=False, path_type=Path - ), - default="requirements.txt", - help="Path to service requirements.txt. Defaults to requirements.txt.", - metavar="PATH", -) -@click.option( - "--out", - "-o", - type=click.Path( - exists=False, - file_okay=True, - dir_okay=False, - path_type=Path, - ), - help="Write Dockerfile to file.", -) -def service_dockerfile( - checkpoints: Sequence[Tuple[str, Path]], - requirements: Path, - out: Optional[Path] = None, -): - """Build service Dockerfile.""" - from jinja2 import Environment, PackageLoader - - env = Environment( - loader=PackageLoader("frogbox", "data"), - autoescape=False, - ) - template = env.get_template("Dockerfile") - - ckpt_info = [] - for name, model_path in checkpoints: - config_path = model_path.parent / "config.json" - ckpt_info.append( - dict( - name=name, - model_path=model_path, - config_path=config_path, - parent_path=model_path.parent, - ) - ) - - env_checkpoints = {e["name"]: str(e["model_path"]) for e in ckpt_info} - output = template.render( - checkpoints=ckpt_info, - requirements=str(requirements), - env_checkpoints=json.dumps(env_checkpoints), - ) - - if out: - out.write_text(output) - else: - print(output) - - -@cli.group() -def schema(): - """Work with config schemas.""" - - -@schema.command() -@click.option( - "--file", - "-f", - type=click.Path( - exists=True, - file_okay=True, - dir_okay=False, - path_type=Path, - ), - required=True, - help="Config file path.", -) -def validate(path: Path): - """Validate config file.""" - from .config import read_json_config - - read_json_config(path) - - -@schema.command(name="write") -@click.option( - "--type", - "-t", - "type_", - type=click.Choice(["supervised", "gan"]), - default="supervised", - help="Pipeline type.", -) -@click.option( - "--out", - "-o", - type=click.Path( - exists=False, - file_okay=True, - dir_okay=False, - path_type=Path, - ), - help="Write schema to file.", -) -def write_schema(type_: str, out: Optional[Path] = None): - if type_ == "supervised": - from .config import SupervisedConfig - - schema_json = json.dumps( - SupervisedConfig.model_json_schema(), indent=2 - ) - elif type_ == "gan": - from .config import GANConfig - - schema_json = json.dumps(GANConfig.model_json_schema(), indent=2) - else: - raise RuntimeError(f"Unknown pipeline type {type_}.") - - if out: - out.write_text(schema_json) - else: - print(schema_json) - - if __name__ == "__main__": cli() diff --git a/frogbox/config.py b/frogbox/config.py index 2783ad8..310cb7e 100644 --- a/frogbox/config.py +++ b/frogbox/config.py @@ -1,29 +1,21 @@ -from typing import Union, Dict, Sequence, Any, 
Optional +from typing import Dict, Any, Optional, Union, Sequence from os import PathLike -import warnings from enum import Enum from pathlib import Path -from importlib import import_module import json +from pydantic import BaseModel, Field, field_validator import jinja2 -from pydantic import BaseModel, ConfigDict, Field, field_validator -from ignite.engine import Events, CallableEventWithFilter +from importlib import import_module +from .engines.events import EventStep, Event, MatchableEvent class ConfigType(str, Enum): + """Pipeline configuration type.""" + SUPERVISED = "supervised" GAN = "gan" -class CheckpointMode(str, Enum): - """ - Checkpoint evaluation mode. - """ - - MIN = "min" - MAX = "max" - - class LogInterval(BaseModel): """ Logging interval. @@ -34,10 +26,46 @@ class LogInterval(BaseModel): Event trigger. interval : int How often event should trigger. Defaults to every time (`1`). + first : int + First step where event should trigger (zero-indexed). + last : int + Last step where vent should trigger (zero-indexed). """ - event: Events + event: EventStep every: int = 1 + first: Optional[int] = None + last: Optional[int] = None + + +class CheckpointMode(str, Enum): + """Checkpoint evaluation mode.""" + + MIN = "min" + MAX = "max" + + +class CheckpointDefinition(BaseModel): + """ + Checkpoint definition. + + + Attributes + ---------- + metric : str + Name of metric to compare (optional). + mode : CheckpointMode + Whether to priority maximum or minimum metric value. + num_saved : int + Number of checkpoints to save. + interval : EventStep or LogInterval + Interval between saving checkpoints. + """ + + metric: Optional[str] = None + mode: CheckpointMode = CheckpointMode.MAX + num_saved: int = Field(3, ge=1) + interval: Union[EventStep, LogInterval] = EventStep.EPOCH_COMPLETED class ObjectDefinition(BaseModel): @@ -105,14 +133,6 @@ class LRSchedulerDefinition(BaseModel): warmup_steps: int = Field(0, ge=0) -class CheckpointDefinition(BaseModel): - """Checkpoint definition.""" - metric: Optional[str] = None - mode: CheckpointMode = CheckpointMode.MAX - n_saved: int = Field(3, ge=1) - interval: Union[Events, LogInterval] = Events.EPOCH_COMPLETED - - class Config(BaseModel): """ Base configuration. @@ -123,34 +143,18 @@ class Config(BaseModel): Pipeline type. project : str Project name. - checkpoint_metric : str - Name of metric to use for evaluating checkpoints. - checkpoint_mode : CheckpointMode - Either `min` or `max`. Determines whether to keep the checkpoints - with the greatest or lowest metric score. - checkpoint_n_saved : int - Number of checkpoints to keep. - log_interval : Events or LogInterval + log_interval : EventStep or LogInterval At which interval to log metrics. """ - model_config = ConfigDict(extra="allow") - type: ConfigType project: str - log_interval: Union[Events, LogInterval] = Events.EPOCH_COMPLETED - checkpoints: Sequence[CheckpointDefinition] = ( - CheckpointDefinition( - metric=None, - n_saved=3, - interval=Events.EPOCH_COMPLETED, - ), - ) + log_interval: Union[EventStep, LogInterval] = EventStep.EPOCH_COMPLETED class SupervisedConfig(Config): """ - Trainer configuration. + Supervised pipeline configuration. Attributes ---------- @@ -172,7 +176,7 @@ class SupervisedConfig(Config): loaders : dict of ObjectDefinition Data loader definitions. model : ObjectDefinition - Model object definition. + Model definition. losses : dict of LossDefinition Loss functions. 
metrics : dict of ObjectDefinition @@ -185,43 +189,64 @@ class SupervisedConfig(Config): batch_size: int = Field(32, ge=1) loader_workers: int = Field(0, ge=0) - max_epochs: int = Field(32, ge=1) + max_epochs: int = Field(50, ge=1) clip_grad_norm: Optional[float] = None clip_grad_value: Optional[float] = None gradient_accumulation_steps: int = Field(1, ge=1) - datasets: Dict[str, ObjectDefinition] - loaders: Dict[str, ObjectDefinition] = dict() + metrics: Dict[str, ObjectDefinition] = dict() + checkpoints: Sequence[CheckpointDefinition] = ( + CheckpointDefinition( + metric=None, + num_saved=3, + interval=EventStep.EPOCH_COMPLETED, + ) + ), model: ObjectDefinition losses: Dict[str, LossDefinition] = dict() - metrics: Dict[str, ObjectDefinition] = dict() + datasets: Dict[str, ObjectDefinition] + loaders: Dict[str, ObjectDefinition] = dict() optimizer: ObjectDefinition = ObjectDefinition( - class_name="torch.optim.AdamW" + class_name="torch.optim.AdamW", + params={"lr": 1e-3}, ) lr_scheduler: LRSchedulerDefinition = LRSchedulerDefinition() - @field_validator("datasets") + @field_validator("type") @classmethod - def validate_datasets(cls, v): - assert "train" in v, "'train' missing in datasets definition." - assert "val" in v, "'val' missing in datasets definition." - return v - - @field_validator("losses") - @classmethod - def validate_losses(cls, v): - if len(v) == 0: - warnings.warn("No loss functions defined.") + def check_type(cls, v: ConfigType) -> ConfigType: + assert v == ConfigType.SUPERVISED return v class GANConfig(SupervisedConfig): + """ + GAN pipeline configuration. + + Attributes + ---------- + disc_model : ObjectDefinition + Discriminator model definition. + disc_losses: dict of LossDefinition + Discriminator loss functions. + disc_optimizer : ObjectDefinition + Discriminator Torch optimizer. + disc_lr_scheduler : LRSchedulerDefinition + Discriminator learning rate scheduler. + """ disc_model: ObjectDefinition disc_losses: Dict[str, LossDefinition] = dict() disc_optimizer: ObjectDefinition = ObjectDefinition( - class_name="torch.optim.AdamW" + class_name="torch.optim.AdamW", + params={"lr": 1e-3}, ) disc_lr_scheduler: LRSchedulerDefinition = LRSchedulerDefinition() + @field_validator("type") + @classmethod + def check_type(cls, v: ConfigType) -> ConfigType: + assert v == ConfigType.GAN + return v + def read_json_config(path: Union[str, PathLike]) -> Config: """ @@ -245,21 +270,20 @@ def read_json_config(path: Union[str, PathLike]) -> Config: raise RuntimeError(f"Unknown config type {config['type']}.") -def parse_log_interval( - e: Union[Events, LogInterval] -) -> CallableEventWithFilter: - """ - Create ignite event from string or dictionary configuration. - Dictionary must have a ``event`` entry. 
- """ - if isinstance(e, Events): - return e - - config = dict(e) - event = config.pop("event") - if len(config) > 0: - event = event(**config) - return event +def parse_log_interval(e: Union[str, LogInterval]) -> MatchableEvent: + """Create matchable event from log interval configuration.""" + if isinstance(e, str): + return Event(event=e) + + if isinstance(e, LogInterval): + return Event( + event=e.event, + every=e.every, + first=e.first, + last=e.last, + ) + + raise ValueError(f"Cannot parse log interval {e}.") def _get_class(path: str) -> Any: diff --git a/frogbox/data/Dockerfile b/frogbox/data/Dockerfile deleted file mode 100644 index 920c67c..0000000 --- a/frogbox/data/Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -ARG PYTORCH_VERSION=2.1.2-cuda11.8-cudnn8-runtime -ARG PYTORCH_PLATFORM=linux/amd64 -ARG CHECKPOINT - -FROM --platform=${PYTORCH_PLATFORM} pytorch/pytorch:${PYTORCH_VERSION} - -EXPOSE 8000 - -WORKDIR /app - -RUN apt-get update && apt-get install -y git - -COPY service.py ./ -COPY {{ requirements }} ./requirements.txt -COPY models ./models - -RUN pip install -r requirements.txt -RUN pip install -U uvicorn - -{% for ckpt in checkpoints %} -RUN mkdir -p {{ ckpt.parent_path }} -COPY {{ ckpt.model_path }} {{ ckpt.model_path }} -COPY {{ ckpt.config_path }} {{ ckpt.config_path }} -{% endfor %} - -ENV CHECKPOINTS='{{ env_checkpoints }}' - -CMD ["uvicorn", "service:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/frogbox/data/service.py b/frogbox/data/service.py deleted file mode 100644 index ce9d614..0000000 --- a/frogbox/data/service.py +++ /dev/null @@ -1,18 +0,0 @@ -from pydantic import BaseModel -from frogbox.service import BaseService - - -class Request(BaseModel): - pass - - -class Response(BaseModel): - pass - - -class ExampleService(BaseService): - def inference(self, request: Request): - return Response() - - -app = ExampleService(Request, Response) diff --git a/frogbox/data/train_gan.py b/frogbox/data/train_gan.py deleted file mode 100644 index 106b3ce..0000000 --- a/frogbox/data/train_gan.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import cast, Optional, Sequence -from pathlib import Path -import argparse -from frogbox import read_json_config, GANPipeline, GANConfig - - -def parse_arguments( - args: Optional[Sequence[str]] = None, -) -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument( - "-c", "--config", type=Path, default="configs/example.json" - ) - parser.add_argument("--checkpoint", type=Path) - parser.add_argument("--checkpoint-keys", type=str, nargs="+") - parser.add_argument( - "--logging", - type=str, - choices=["online", "offline"], - default="online", - ) - parser.add_argument("--wandb-id", type=str, required=False) - parser.add_argument("--tags", type=str, nargs="+") - parser.add_argument("--group", type=str) - return parser.parse_args(args) - - -if __name__ == "__main__": - args = parse_arguments() - config = cast(GANConfig, read_json_config(args.config)) - - pipeline = GANPipeline( - config=config, - checkpoint=args.checkpoint, - checkpoint_keys=args.checkpoint_keys, - logging=args.logging, - wandb_id=args.wandb_id, - tags=args.tags, - group=args.group, - ) - - pipeline.run() diff --git a/frogbox/engines/engine.py b/frogbox/engines/engine.py new file mode 100644 index 0000000..a09c5cb --- /dev/null +++ b/frogbox/engines/engine.py @@ -0,0 +1,154 @@ +from typing import ( + Callable, + Iterable, + List, + Any, + Union, + Dict, + Mapping, + Optional, + Sequence, +) +import tqdm +from .events 
import EventStep, MatchableEvent, Event + + +class Engine: + def __init__( + self, + process_fn: Callable, + show_progress: bool = True, + progress_label: Optional[str] = None, + ): + self.process_fn = process_fn + self.show_progress = show_progress + self.progress_label = progress_label + + self.epoch = 0 + self.iteration = 0 + self.max_epochs = 1 + + self.event_handlers: List[EventHandler] = [] + self.output_handlers: List[OutputHandler] = [] + + def _fire_event(self, event: EventStep) -> None: + step = 0 + if event in (EventStep.EPOCH_STARTED, EventStep.EPOCH_COMPLETED): + step = self.epoch + if event in ( + EventStep.ITERATION_STARTED, + EventStep.ITERATION_COMPLETED, + ): + step = self.iteration + + for handler in self.event_handlers: + if handler.event.matches(event, step): + handler.function(*handler.args, **handler.kwargs) + + def _handle_output(self, output: Any) -> None: + for handler in self.output_handlers: + handler.function(output) + + def _get_progress_label(self) -> str: + label = "" + if self.max_epochs > 1: + label += f"[{self.epoch+1}/{self.max_epochs}]" + if self.progress_label is not None and len(self.progress_label) > 0: + label = self.progress_label + " " + label + return label + + def _get_data_iterator( + self, + loader: Iterable, + ) -> Iterable: + iterator = loader + if self.show_progress: + desc = self._get_progress_label() + iterator = tqdm.tqdm( + iterator, + desc=desc, + ncols=80, + leave=False, + ) + return iterator + + def _is_done(self) -> bool: + return self.epoch >= self.max_epochs + + def add_event_handler( + self, + event: Union[str, EventStep, MatchableEvent], + function: Callable[..., None], + *args, + **kwargs, + ): + if isinstance(event, str) or isinstance(event, EventStep): + event = Event(event) + self.event_handlers.append( + EventHandler(event, function, *args, **kwargs) + ) + + def add_output_handler( + self, + function: Callable[[Any], None], + ): + self.output_handlers.append(OutputHandler(function)) + + def run(self, loader: Iterable, max_epochs: int = 1) -> None: + self.max_epochs = max_epochs + + if self._is_done(): + self.epoch = 0 + self.iteration = 0 + + self._fire_event(EventStep.STARTED) + + while not self._is_done(): + self._fire_event(EventStep.EPOCH_STARTED) + + iterations = self._get_data_iterator(loader) + for batch in iterations: + self._fire_event(EventStep.ITERATION_STARTED) + + output = self.process_fn(batch) + self._handle_output(output) + + self.iteration += 1 + self._fire_event(EventStep.ITERATION_COMPLETED) + + self.epoch += 1 + self._fire_event(EventStep.EPOCH_COMPLETED) + + self._fire_event(EventStep.COMPLETED) + + def state_dict(self) -> Dict[str, Any]: + return dict( + epoch=self.epoch, + iteration=self.iteration, + ) + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + self.epoch = state_dict["epoch"] + self.iteration = state_dict["iteration"] + + +class EventHandler: + def __init__( + self, + event: MatchableEvent, + function: Callable[..., None], + *args, + **kwargs, + ): + self.event = event + self.function = function + self.args: Sequence[Any] = args + self.kwargs: Dict[str, Any] = kwargs + + +class OutputHandler: + def __init__( + self, + function: Callable[[Any], None], + ): + self.function = function diff --git a/frogbox/engines/events.py b/frogbox/engines/events.py new file mode 100644 index 0000000..be4dfc9 --- /dev/null +++ b/frogbox/engines/events.py @@ -0,0 +1,64 @@ +from typing import Optional, Union +from abc import ABC, abstractmethod +from enum import Enum + + +class EventStep(str, 
Enum): + STARTED = "started" + EPOCH_STARTED = "epoch_started" + ITERATION_STARTED = "iteration_started" + ITERATION_COMPLETED = "iteration_completed" + EPOCH_COMPLETED = "epoch_completed" + COMPLETED = "completed" + + +class MatchableEvent(ABC): + @abstractmethod + def matches(self, event: EventStep, step: int) -> bool: ... + + +class EventList(MatchableEvent): + def __init__(self): + self.events = [] + + def __or__(self, other: MatchableEvent) -> "EventList": + self.events.append(other) + return self + + def matches(self, event: EventStep, step: int) -> bool: + for entry in self.events: + if entry.matches(event, step): + return True + return False + + +class Event(MatchableEvent): + def __init__( + self, + event: Union[str, EventStep], + every: Optional[int] = None, + first: Optional[int] = None, + last: Optional[int] = None, + ): + self.event = EventStep(event) + self.every = every + self.first = first + self.last = last + + def matches(self, event: EventStep, step: int) -> bool: + if event != self.event: + return False + + if self.first is not None and step < self.first: + return False + + if self.last is not None and step > self.last: + return False + + if self.every is not None and step % self.every != 0: + return False + + return True + + def __or__(self, other: MatchableEvent) -> EventList: + return EventList() | self | other diff --git a/frogbox/engines/gan.py b/frogbox/engines/gan.py index 66c105d..21dd9d8 100644 --- a/frogbox/engines/gan.py +++ b/frogbox/engines/gan.py @@ -1,150 +1,125 @@ -from typing import Callable, Any, Union, Sequence, Tuple, Optional +from typing import Callable, Any, Optional import torch -from ignite.engine import Engine, DeterministicEngine from accelerate import Accelerator +from .engine import Engine -def create_gan_trainer( - accelerator: Accelerator, - model: torch.nn.Module, - disc_model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - disc_optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - disc_scheduler: torch.optim.lr_scheduler.LRScheduler, - loss_fn: Union[Callable, torch.nn.Module], - disc_loss_fn: Union[Callable, torch.nn.Module], - clip_grad_norm: Optional[float] = None, - clip_grad_value: Optional[float] = None, - input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), - model_transform: Callable[[Any], Any] = lambda output: output, - disc_model_transform: Callable[[Any], Any] = lambda output: output, - output_transform: Callable[ - [Any, Any, Any, torch.Tensor, torch.Tensor], Any - ] = lambda x, y, y_pred, loss, disc_loss: ( - loss.item(), - disc_loss.item(), - ), - deterministic: bool = False, -) -> Engine: - """ - Factory function for GAN trainer. - - Parameters - ---------- - model : torch.nn.Module - The model to train. - disc_model : torch.nn.Module - The discriminator to train. - optimizer : torch optimizer - The optimizer to use for model. - disc_optimizer : torch optimizer - The optimizer to use for discriminator. - scheduler : torch LRScheduler - Model learning rate scheduler. - disc_scheduler : torch LRScheduler - Discriminator learning rate scheduler. - loss_fn : torch.nn.Module - The supervised loss function to use for model. - disc_loss_fn : torch.nn.Module - The loss function to use discriminator. - clip_grad_norm : float - Clip gradients to norm if provided. - update_interval : int - How many steps between updating `model`. - disc_update_interval : int - How many steps between updating `disc_model`. 
- input_transform : Callable - Function that receives tensors `y` and `y` and outputs tuple of - tensors `(x, y)`. - model_transform : Callable - Function that receives the output from the model and - convert it into the form as required by the loss function. - disc_model_transform : Callable - Function that receives the output from the discriminator and - convert it into the form as required by the loss function. - output_transform : Callable - Function that receives `x`, `y`, `y_pred`, `loss` and returns value - to be assigned to engine's state.output after each iteration. Default - is returning `(loss.item(), disc_loss.item())`. - deterministic : bool - If `True`, returns `DeterministicEngine`, otherwise `Engine`. - - Returns - ------- - trainer : torch.engine.Engine - A trainer engine with GAN update function. - """ - - def _update( - engine: Engine, batch: Sequence[torch.Tensor] - ) -> Union[Any, Tuple[torch.Tensor]]: - model.train() - disc_model.train() - - x, y = batch - x, y = input_transform(x, y) +class GANTrainer(Engine): + def __init__( + self, + accelerator: Accelerator, + model: torch.nn.Module, + disc_model: torch.nn.Module, + loss_fn: Callable[..., Any], + disc_loss_fn: Callable[..., Any], + optimizer: torch.optim.Optimizer, + disc_optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + disc_scheduler: torch.optim.lr_scheduler.LRScheduler, + clip_grad_norm: Optional[float] = None, + clip_grad_value: Optional[float] = None, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), + model_transform: Callable[[Any], Any] = lambda output: output, + disc_model_transform: Callable[[Any], Any] = lambda output: output, + output_transform: Callable[ + [Any, Any, Any, Any, Any], Any + ] = lambda x, y, y_pred, loss, disc_loss: ( + loss.item(), + disc_loss.item(), + ), + **kwargs, + ): + self.accelerator = accelerator + self.model = model + self.disc_model = disc_model + self.loss_fn = loss_fn + self.disc_loss_fn = disc_loss_fn + self.optimizer = optimizer + self.disc_optimizer = disc_optimizer + self.scheduler = scheduler + self.disc_scheduler = disc_scheduler + + self.clip_grad_norm = clip_grad_norm + self.clip_grad_value = clip_grad_value + + self._input_transform = input_transform + self._model_transform = model_transform + self._disc_model_transform = disc_model_transform + self._output_transform = output_transform + + super().__init__(self.process, **kwargs) + + def process(self, batch): + self.model.train() + self.disc_model.train() + + inputs, targets = batch + inputs, targets = self._input_transform(inputs, targets) # Update discriminator - with accelerator.accumulate(disc_model): - disc_optimizer.zero_grad() - y_pred = model_transform(model(x)).detach() - disc_pred_real = disc_model_transform(disc_model(y)) - disc_pred_fake = disc_model_transform(disc_model(y_pred)) - disc_loss = disc_loss_fn( - y_pred, - y, + with self.accelerator.accumulate(self.disc_model): + self.disc_optimizer.zero_grad() + outputs = self._model_transform(self.model(inputs)).detach() + disc_pred_real = self._disc_model_transform( + self.disc_model(targets) + ) + disc_pred_fake = self._disc_model_transform( + self.disc_model(outputs) + ) + + disc_loss = self.disc_loss_fn( + outputs, + targets, disc_real=disc_pred_real, disc_fake=disc_pred_fake, ) + self.accelerator.backward(disc_loss) - accelerator.backward(disc_loss) - if accelerator.sync_gradients: - if clip_grad_norm: - accelerator.clip_grad_norm_( - parameters=disc_model.parameters(), - 
max_norm=clip_grad_norm, + if self.accelerator.sync_gradients: + if self.clip_grad_norm: + self.accelerator.clip_grad_norm_( + parameters=self.disc_model.parameters(), + max_norm=self.clip_grad_norm, ) - if clip_grad_value: - accelerator.clip_grad_value_( - parameters=disc_model.parameters(), - clip_value=clip_grad_value, + if self.clip_grad_value: + self.accelerator.clip_grad_value_( + parameters=self.disc_model.parameters(), + clip_value=self.clip_grad_value, ) - disc_optimizer.step() - disc_scheduler.step() + self.disc_optimizer.step() + self.disc_scheduler.step() # Update generator - with accelerator.accumulate(model): - optimizer.zero_grad() - y_pred = model_transform(model(x)) - disc_pred_fake = disc_model_transform(disc_model(y_pred)) - loss = loss_fn( - y_pred, - y, + with self.accelerator.accumulate(self.model): + self.optimizer.zero_grad() + outputs = self._model_transform(self.model(inputs)) + disc_pred_fake = self._disc_model_transform( + self.disc_model(outputs) + ) + + loss = self.loss_fn( + outputs, + targets, disc_fake=disc_pred_fake, ) + self.accelerator.backward(loss) - accelerator.backward(loss) - if accelerator.sync_gradients: - if clip_grad_norm: - accelerator.clip_grad_norm_( - parameters=model.parameters(), - max_norm=clip_grad_norm, + if self.accelerator.sync_gradients: + if self.clip_grad_norm: + self.accelerator.clip_grad_norm_( + parameters=self.model.parameters(), + max_norm=self.clip_grad_norm, ) - if clip_grad_value: - accelerator.clip_grad_value_( - parameters=model.parameters(), - clip_value=clip_grad_value, + if self.clip_grad_value: + self.accelerator.clip_grad_value_( + parameters=self.model.parameters(), + clip_value=self.clip_grad_value, ) - optimizer.step() - scheduler.step() - - return output_transform(x, y, y_pred, loss, disc_loss) - - trainer = ( - Engine(_update) if not deterministic else DeterministicEngine(_update) - ) + self.optimizer.step() + self.scheduler.step() - return trainer + return self._output_transform( + inputs, targets, outputs, loss, disc_loss + ) diff --git a/frogbox/engines/supervised.py b/frogbox/engines/supervised.py index 6a5cffd..cda2c04 100644 --- a/frogbox/engines/supervised.py +++ b/frogbox/engines/supervised.py @@ -1,149 +1,106 @@ -from typing import Union, Callable, Any, Sequence, Tuple, Optional, Dict +from typing import Callable, Any, Optional import torch -from ignite.engine import Engine, DeterministicEngine -from ignite.metrics import Metric from accelerate import Accelerator - - -def create_supervised_trainer( - accelerator: Accelerator, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - loss_fn: Union[Callable, torch.nn.Module], - clip_grad_norm: Optional[float] = None, - clip_grad_value: Optional[float] = None, - input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), - model_transform: Callable[[Any], Any] = lambda output: output, - output_transform: Callable[ - [Any, Any, Any, torch.Tensor], Any - ] = lambda x, y, y_pred, loss: loss.item(), - deterministic: bool = False, -) -> Engine: - """ - Factory function for supervised trainer. - - Parameters - ---------- - model : torch.nn.Module - The model to train. - optimizer : torch optimizer - The optimizer to use. - scheduler : torch LRScheduler - Learning rate scheduler. - loss_fn : torch.nn.Module - The loss function to use. - clip_grad_norm : float - Clip gradients to norm if provided. - clip_grad_value : float - Clip gradients to value if provided. 
- input_transform : Callable - Function that receives tensors `y` and `y` and outputs tuple of - tensors `(x, y)`. - model_transform : Callable - Function that receives the output from the model and - convert it into the form as required by the loss function. - output_transform : Callable - Function that receives `x`, `y`, `y_pred`, `loss` and - returns value to be assigned to engine's state.output after each - iteration. Default is returning `loss.item()`. - deterministic : bool - If `True`, returns `DeterministicEngine`, otherwise `Engine`. - - Returns - ------- - trainer : torch.ignite.Engine - A trainer engine with supervised update function. - """ - - def _update( - engine: Engine, batch: Sequence[torch.Tensor] - ) -> Union[Any, Tuple[torch.Tensor]]: - model.train() - - x, y = batch - x, y = input_transform(x, y) - - with accelerator.accumulate(model): - optimizer.zero_grad() - y_pred = model_transform(model(x)) - loss = loss_fn(y_pred, y) - - accelerator.backward(loss) - if accelerator.sync_gradients: - if clip_grad_norm: - accelerator.clip_grad_norm_( - parameters=model.parameters(), - max_norm=clip_grad_norm, +from .engine import Engine + + +class SupervisedTrainer(Engine): + def __init__( + self, + accelerator: Accelerator, + model: torch.nn.Module, + loss_fn: Callable[[Any, Any], Any], + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + clip_grad_norm: Optional[float] = None, + clip_grad_value: Optional[float] = None, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), + model_transform: Callable[[Any], Any] = lambda output: output, + output_transform: Callable[ + [Any, Any, Any, Any], Any + ] = lambda x, y, y_pred, loss: loss.item(), + **kwargs, + ): + self.accelerator = accelerator + self.model = model + self.loss_fn = loss_fn + self.optimizer = optimizer + self.scheduler = scheduler + + self.clip_grad_norm = clip_grad_norm + self.clip_grad_value = clip_grad_value + + self._input_transform = input_transform + self._model_transform = model_transform + self._output_transform = output_transform + + super().__init__(process_fn=self.process, **kwargs) + + def process(self, batch): + self.model.train() + + inputs, targets = batch + inputs, targets = self._input_transform(inputs, targets) + + with self.accelerator.accumulate(self.model): + self.optimizer.zero_grad() + outputs = self._model_transform(self.model(inputs)) + + loss = self.loss_fn(outputs, targets) + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients: + if self.clip_grad_norm: + self.accelerator.clip_grad_norm_( + parameters=self.model.parameters(), + max_norm=self.clip_grad_norm, ) - if clip_grad_value: - accelerator.clip_grad_value_( - parameters=model.parameters(), - clip_value=clip_grad_value, + if self.clip_grad_value: + self.accelerator.clip_grad_value_( + parameters=self.model.parameters(), + clip_value=self.clip_grad_value, ) - optimizer.step() - scheduler.step() - - return output_transform(x, y, y_pred, loss) - - trainer = ( - Engine(_update) if not deterministic else DeterministicEngine(_update) - ) - - return trainer - - -def create_supervised_evaluator( - accelerator: Accelerator, - model: torch.nn.Module, - metrics: Optional[Dict[str, Metric]] = None, - input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), - model_transform: Callable[[Any], Any] = lambda output: output, - output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: ( - y_pred, - y, - ), -) -> Engine: - """ - Factory function for supervised 
evaluator. - - Parameters - ---------- - model : torch.nn.Module - The model to train. - metrics : dict - Dictionary of evaluation metrics. - input_transform : Callable - Function that receives tensors `y` and `y` and outputs tuple of - tensors `(x, y)`. - model_transform : Callable - Function that receives the output from the model and convert it into - the predictions: `y_pred = model_transform(model(x))`. - output_transform : Callable - Function that receives `x`, `y`, `y_pred` and returns value to be - assigned to engine's state.output after each iteration. - Default is returning `(y_pred, y,)` which fits output expected by - metrics. If you change it you should use `output_transform` in metrics. - """ - - def _step( - engine: Engine, batch: Sequence[torch.Tensor] - ) -> Union[Any, Tuple[torch.Tensor]]: - model.eval() - - x, y = batch - x, y = input_transform(x, y) - with torch.no_grad(): - y_pred = model_transform(model(x)) + self.optimizer.step() + self.scheduler.step() + + return self._output_transform(inputs, targets, outputs, loss) + + +class SupervisedEvaluator(Engine): + def __init__( + self, + accelerator: Accelerator, + model: torch.nn.Module, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), + model_transform: Callable[[Any], Any] = lambda output: output, + output_transform: Callable[ + [Any, Any, Any], Any + ] = lambda x, y, y_pred: (y_pred, y), + **kwargs, + ): + self.accelerator = accelerator + self.model = model - x, y, y_pred = accelerator.gather_for_metrics((x, y, y_pred)) - return output_transform(x, y, y_pred) + self._input_transform = input_transform + self._model_transform = model_transform + self._output_transform = output_transform + + super().__init__(process_fn=self.process, **kwargs) + + def process(self, batch): + self.model.eval() + + inputs, targets = batch + inputs, targets = self._input_transform(inputs, targets) + + with torch.no_grad(): + outputs = self._model_transform(self.model(inputs)) - evaluator = Engine(_step) + outputs, targets = self._output_transform(inputs, targets, outputs) - if metrics: - for name, metric in metrics.items(): - metric.attach(evaluator, name) + outputs, targets = self.accelerator.gather_for_metrics( + (outputs, targets) + ) - return evaluator + return outputs, targets diff --git a/frogbox/handlers/__init__.py b/frogbox/handlers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/frogbox/handlers/checkpoint.py b/frogbox/handlers/checkpoint.py new file mode 100644 index 0000000..ae0a387 --- /dev/null +++ b/frogbox/handlers/checkpoint.py @@ -0,0 +1,165 @@ +from typing import ( + Mapping, + Any, + Union, + Callable, + Sequence, + Optional, + List, + Dict, +) +import os +from pathlib import Path +from collections import namedtuple +import torch +from accelerate import Accelerator +from accelerate.utils.other import is_compiled_module +from ..config import Config + + +SavedCheckpoint = namedtuple("SavedCheckpoint", ["filename", "priority"]) + + +class Checkpoint: + def __init__( + self, + accelerator: Accelerator, + config: Config, + to_save: Mapping[str, Any], + output_folder: Union[str, os.PathLike], + global_step_function: Callable[[], int], + score_function: Optional[Callable[[], float]] = None, + score_name: Optional[str] = None, + score_mode: str = "max", + to_unwrap: Optional[Sequence[str]] = None, + filename_prefix: str = "checkpoint", + max_saved: int = 3, + ): + assert score_mode in ("min", "max") + + self._accelerator = accelerator + self._config = config + self._to_save = to_save + 
self._output_folder = output_folder + self._global_step_function = global_step_function + self._score_function = score_function + self._score_name = score_name + self._score_mode = score_mode + self._filename_prefix = filename_prefix + self._max_saved = max_saved + if to_unwrap is None: + to_unwrap = [] + self._to_unwrap = to_unwrap + + self._saved: List[SavedCheckpoint] = [] + + def _get_filename( + self, + step: Optional[int], + score: Optional[float], + ) -> str: + name = str(Path(self._output_folder) / self._filename_prefix) + + if step is not None: + name += "_" + str(step) + + if score is not None: + if self._score_name is not None: + name += f"_{self._score_name}={score:.4f}" + else: + name += f"_{score:.4f}" + + name += ".pt" + return name + + def _save_checkpoint( + self, + filename: str, + ) -> None: + if not self._accelerator.is_local_main_process: + return + + # Create parent directory if it doesn't exist + parent_dir = Path(filename).parent + parent_dir.mkdir(parents=True, exist_ok=True) + + # Save config file + config_json = self._config.model_dump_json(indent=4, exclude_none=True) + with (parent_dir / "config.json").open("w") as fp: + fp.write(config_json) + + # Extract state dicts from objects + state_dicts = {} + for key, obj in self._to_save.items(): + if key in self._to_unwrap: + obj = self._accelerator.unwrap_model(obj) + if is_compiled_module(obj): + obj = obj._orig_mod + state_dicts[key] = obj.state_dict() + + torch.save(state_dicts, filename) + + def __call__(self) -> None: + # Compute checkpoint score/priority + score = None + priority = 0.0 + if self._score_function is not None: + score = self._score_function() + priority = score + if self._score_mode == "min": + priority = -priority + + # Check if new checkpoint should be accepted + if len(self._saved) == self._max_saved: + # Remove old lowest priority checkpoint + min_index = min( + range(len(self._saved)), key=lambda i: self._saved[i].priority + ) + min_priority = self._saved[min_index].priority + + if priority < min_priority: + return + + os.remove(self._saved[min_index].filename) + self._saved.pop(min_index) + + # Get output filename + step = self._global_step_function() + filename = self._get_filename(step, score) + + # Save checkpoint + self._save_checkpoint(filename) + + # Record new checkpoint + self._saved.append(SavedCheckpoint(filename, priority)) + + @staticmethod + def load_checkpoint( + accelerator: Accelerator, + path: Union[str, os.PathLike], + to_load: Mapping[str, Any], + to_unwrap: Optional[Sequence[str]] = None, + ) -> None: + if to_unwrap is None: + to_unwrap = [] + + ckpt = torch.load( + f=path, + map_location="cpu", + weights_only=True, + ) + + for key, obj in to_load.items(): + if key in to_unwrap: + obj = accelerator.unwrap_model(obj) + if is_compiled_module(obj): + obj = obj._orig_mod + + obj.load_state_dict(ckpt[key]) + + def state_dict(self) -> Dict[str, Any]: + saved = list(map(tuple, self._saved)) + return dict(saved=saved) + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + self._saved = [SavedCheckpoint(*e) for e in state_dict["saved"]] diff --git a/frogbox/handlers/composite_loss_logger.py b/frogbox/handlers/composite_loss_logger.py new file mode 100644 index 0000000..5638702 --- /dev/null +++ b/frogbox/handlers/composite_loss_logger.py @@ -0,0 +1,30 @@ +from typing import Optional, Callable, Any, Dict +from ..engines.events import EventStep +from ..engines.engine import Engine +from ..pipelines.composite_loss import CompositeLoss + + +class CompositeLossLogger: 
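+    """Log the most recent value of each CompositeLoss term after every engine iteration."""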
+ def __init__( + self, + loss: CompositeLoss, + log_function: Callable[[Any], None], + prefix: Optional[str] = None, + ): + self._loss = loss + self._log_function = log_function + if prefix is None: + prefix = "" + self._prefix = prefix + + def attach(self, engine: Engine) -> None: + engine.add_event_handler( + event=EventStep.ITERATION_COMPLETED, + function=self._iteration_completed, + ) + + def _iteration_completed(self) -> None: + data: Dict[str, Any] = {} + for label, loss in zip(self._loss.labels, self._loss.last_values): + data[self._prefix + label] = loss + self._log_function(data) diff --git a/frogbox/handlers/metric_logger.py b/frogbox/handlers/metric_logger.py new file mode 100644 index 0000000..076ed3a --- /dev/null +++ b/frogbox/handlers/metric_logger.py @@ -0,0 +1,43 @@ +from typing import Mapping, Callable, Optional, Any +from torchmetrics import Metric +from ..engines.engine import Engine +from ..engines.events import EventStep + + +class MetricLogger: + def __init__( + self, + metrics: Mapping[str, Metric], + log_function: Callable[[Any], None], + prefix: Optional[str] = None, + ): + self._metrics = metrics + self._log_function = log_function + if prefix is None: + prefix = "" + self._prefix = prefix + + def attach(self, engine: Engine) -> None: + engine.add_output_handler(self._handle_output) + engine.add_event_handler( + event=EventStep.EPOCH_STARTED, + function=self._epoch_started, + ) + engine.add_event_handler( + event=EventStep.EPOCH_COMPLETED, + function=self._epoch_completed, + ) + + def _handle_output(self, outputs) -> None: + for metric in self._metrics.values(): + metric(*outputs) + + def _epoch_started(self) -> None: + for metric in self._metrics.values(): + metric.reset() + + def _epoch_completed(self) -> None: + data = {} + for label, metric in self._metrics.items(): + data[self._prefix + label] = metric.compute().item() + self._log_function(data) diff --git a/frogbox/handlers/output_logger.py b/frogbox/handlers/output_logger.py new file mode 100644 index 0000000..641e5eb --- /dev/null +++ b/frogbox/handlers/output_logger.py @@ -0,0 +1,21 @@ +from typing import Callable, Any +from ..engines.engine import Engine + + +class OutputLogger: + def __init__( + self, + label: str, + log_function: Callable[[Any], None], + output_transform: Callable[[Any], Any] = lambda x: x, + ): + self._label = label + self._log_function = log_function + self._output_transform = output_transform + + def attach(self, engine: Engine): + engine.add_output_handler(self._handle_output) + + def _handle_output(self, output: Any) -> None: + output = self._output_transform(output) + self._log_function({self._label: output}) diff --git a/frogbox/pipelines/__init__.py b/frogbox/pipelines/__init__.py index a571bea..2fe9dbe 100644 --- a/frogbox/pipelines/__init__.py +++ b/frogbox/pipelines/__init__.py @@ -10,4 +10,5 @@ The `frogbox.pipelines.gan.GANPipeline` is used for training a generative adversarial model. + """ diff --git a/frogbox/pipelines/gan.py b/frogbox/pipelines/gan.py index a7b33af..01fdf47 100644 --- a/frogbox/pipelines/gan.py +++ b/frogbox/pipelines/gan.py @@ -4,8 +4,8 @@ The GAN pipeline is similar to the supervised pipelines, except that it adds another model, the discriminator, with its own loss function(s). 
-The discriminator model is configured in the `disc_model ` similarly to the -(generator) model: +The discriminator model is configured in the `disc_model` field similarly +to the (generator) model: ```json { @@ -28,8 +28,9 @@ loss function for the generator and `disc_losses` defines the loss function for the disciminator. -The discriminator loss takes two optional arguments, `disc_real` and -`disc_fake`, and the generator loss takes one optional argument, `disc_fake`. +The discriminator loss takes two keyword arguments, `disc_real` and +`disc_fake`. +The generator loss takes one optional argument, `disc_fake`. These tensors contain the predictions from the discriminator model when passed the batch of real and fake data, respectively. @@ -56,53 +57,34 @@ def forward(self, input, target, disc_real, disc_fake): loss_fake = self.loss_fn(disc_fake, torch.zeros_like(disc_fake)) return loss_real + loss_fake ``` - -## Updating models at different intervals - -It is possible to update the generator and disciminator models at different -intervals using the `update_interval` and `disc_update_interval` fields. -For instance, in order to update the discriminator only every five iterations: - -```json -{ - "type": "gan", - "update_interval": 1, - "disc_update_interval": 5, - ... -} -``` """ -from typing import Any, Dict, Callable, Union, Optional, Sequence + +from typing import Dict, Optional, Union, Sequence, Callable, Any from os import PathLike -from pathlib import Path +from functools import partial import torch from torch.utils.data import Dataset, DataLoader -from ignite.engine import Events, CallableEventWithFilter -from ignite.handlers import global_step_from_engine -from ignite.contrib.handlers import ProgressBar from accelerate import Accelerator +from torchmetrics import Metric from .pipeline import Pipeline -from .common import ( - create_composite_loss, - create_lr_scheduler, -) -from ..config import ( - GANConfig, - create_object_from_config, - parse_log_interval, -) -from ..engines.supervised import create_supervised_evaluator -from ..engines.gan import create_gan_trainer +from ..config import GANConfig, create_object_from_config, parse_log_interval +from ..engines.gan import GANTrainer +from ..engines.supervised import SupervisedEvaluator +from .lr_scheduler import create_lr_scheduler from .composite_loss import CompositeLoss -from .logger import AccelerateLogger +from ..handlers.output_logger import OutputLogger +from ..handlers.metric_logger import MetricLogger +from ..handlers.composite_loss_logger import CompositeLossLogger +from ..handlers.checkpoint import Checkpoint class GANPipeline(Pipeline): """GAN pipeline.""" config: GANConfig - log_interval: CallableEventWithFilter + trainer: GANTrainer + evaluator: SupervisedEvaluator datasets: Dict[str, Dataset] loaders: Dict[str, DataLoader] model: torch.nn.Module @@ -113,13 +95,13 @@ class GANPipeline(Pipeline): disc_lr_scheduler: torch.optim.lr_scheduler.LRScheduler loss_fn: CompositeLoss disc_loss_fn: CompositeLoss + metrics: Dict[str, Metric] def __init__( self, config: GANConfig, checkpoint: Optional[Union[str, PathLike]] = None, checkpoint_keys: Optional[Sequence[str]] = None, - checkpoint_dir: Union[str, PathLike] = Path("checkpoints"), logging: str = "online", wandb_id: Optional[str] = None, tags: Optional[Sequence[str]] = None, @@ -133,7 +115,7 @@ def __init__( [Any], Any ] = lambda output: output, trainer_output_transform: Callable[ - [Any, Any, Any, torch.Tensor, torch.Tensor], Any + [Any, Any, Any, Any, Any], Any ] = lambda 
x, y, y_pred, loss, disc_loss: ( loss.item(), disc_loss.item(), @@ -151,60 +133,12 @@ def __init__( ): """ Create GAN pipeline. - - Parameters - ---------- - config : GANConfig - Pipeline configuration. - checkpoint : path-like - Path to experiment checkpoint. - checkpoint_keys : list of str - List of keys for objects to load from checkpoint. - Defaults to all keys. - checkpoint_dir : str or path - Path to directory to store checkpoints. - logging : str - Logging mode. Must be either "online" or "offline". - wandb_id : str - W&B run ID to resume from. - tags : list of str - List of tags to add to the run in W&B. - group : str - Group to add run to in W&B. - trainer_input_transform : Callable - Function that receives tensors `x` and `y` and outputs tuple of - tensors `(x, y)`. - trainer_model_transform : Callable - Function that receives the output from the model during training - and converts it into the form as required by the loss function. - trainer_disc_model_transform : Callable - Function that receives the output from the discriminator - during training and converts it into the form as required - by the loss function. - trainer_output_transform : Callable - Function that receives `x`, `y`, `y_pred`, `loss` and returns value - to be assigned to trainer's `state.output` after each iteration. - Default is returning `loss.item()`. - evaluator_input_transform : Callable - Function that receives tensors `x` and `y` and outputs tuple of - tensors `(x, y)`. - evaluator_model_transform : Callable - Function that receives the output from the model during evaluation - and converts it into the predictions: - `y_pred = model_transform(model(x))`. - evaluator_output_transform : Callable - Function that receives `x`, `y`, `y_pred` and returns value to be - assigned to evaluator's `state.output` after each iteration. - Default is returning `(y_pred, y)` which fits output expected by - metrics. 
""" - logging = logging.lower() - assert logging in ("online", "offline") # Parse config + logging = logging.lower() + assert logging in ("online", "offline") self.config = config - self.log_interval = parse_log_interval(config.log_interval) - self._generate_name() # Create accelerator self.accelerator = Accelerator( @@ -212,6 +146,13 @@ def __init__( log_with="wandb", ) + self._setup_tracking( + mode=logging, + wandb_id=wandb_id, + tags=tags, + group=group, + ) + # Create datasets and data loaders self.datasets, self.loaders = self._create_data_loaders( batch_size=config.batch_size, @@ -220,13 +161,14 @@ def __init__( loaders=config.loaders, ) - # Create models + # Create model self.model = create_object_from_config(config.model) self.optimizer = create_object_from_config( config=config.optimizer, params=self.model.parameters(), ) + # Create discriminator model self.disc_model = create_object_from_config(config.disc_model) self.disc_optimizer = create_object_from_config( config=config.disc_optimizer, @@ -246,14 +188,6 @@ def __init__( max_iterations=max_iterations, ) - # Load model weights before potential compilation - if checkpoint: - self._load_checkpoint( - path=checkpoint, - to_load={"model": self.model, "disc_model": self.disc_model}, - keys=checkpoint_keys, - ) - # Wrap with accelerator self.model, self.optimizer, self.lr_scheduler = ( self.accelerator.prepare( @@ -268,140 +202,120 @@ def __init__( for split in self.loaders.keys(): self.loaders[split] = self.accelerator.prepare(self.loaders[split]) - # Create trainer - self.loss_fn = create_composite_loss( - config.losses, self.accelerator.device - ) - self.disc_loss_fn = create_composite_loss( - config.disc_losses, self.accelerator.device - ) - self.trainer = create_gan_trainer( + self.loss_fn = self._create_composite_loss(config.losses) + self.disc_loss_fn = self._create_composite_loss(config.disc_losses) + self.trainer = GANTrainer( accelerator=self.accelerator, model=self.model, disc_model=self.disc_model, + loss_fn=self.loss_fn, + disc_loss_fn=self.disc_loss_fn, optimizer=self.optimizer, disc_optimizer=self.disc_optimizer, scheduler=self.lr_scheduler, disc_scheduler=self.disc_lr_scheduler, - loss_fn=self.loss_fn, - disc_loss_fn=self.disc_loss_fn, clip_grad_norm=config.clip_grad_norm, clip_grad_value=config.clip_grad_value, input_transform=trainer_input_transform, model_transform=trainer_model_transform, disc_model_transform=trainer_disc_model_transform, output_transform=trainer_output_transform, + progress_label="train", ) - # Create evaluator - metrics = {} - for metric_label, metric_conf in config.metrics.items(): - metrics[metric_label] = create_object_from_config(metric_conf) + OutputLogger("train/loss", self.log, lambda o: o[0]).attach( + self.trainer + ) + OutputLogger("train/disc_loss", self.log, lambda o: o[1]).attach( + self.trainer + ) + CompositeLossLogger(self.loss_fn, self.log, "loss/").attach( + self.trainer + ) + CompositeLossLogger(self.disc_loss_fn, self.log, "disc_loss/").attach( + self.trainer + ) - self.evaluator = create_supervised_evaluator( + # Create evaluator + self.evaluator = SupervisedEvaluator( accelerator=self.accelerator, model=self.model, - metrics=metrics, input_transform=evaluator_input_transform, model_transform=evaluator_model_transform, output_transform=evaluator_output_transform, + progress_label="val", ) - self.trainer.add_event_handler( - event_name=self.log_interval, - handler=lambda: self.evaluator.run(self.loaders["val"]), - ) - - # Set up checkpoints - self._setup_checkpoints( - 
to_save={ - "model": self.model, - "disc_model": self.disc_model, - "optimizer": self.optimizer, - "disc_optimizer": self.disc_optimizer, - "trainer": self.trainer, - "lr_scheduler": self.lr_scheduler, - "disc_lr_scheduler": self.disc_lr_scheduler, - }, - checkpoint_dir=checkpoint_dir, - to_unwrap=["model", "disc_model"], + event=self.log_interval, + function=lambda: self.evaluator.run(self.loaders["val"]), ) - if checkpoint: - self._load_checkpoint( - path=checkpoint, - to_load={ - "optimizer": self.optimizer, - "disc_optimizer": self.disc_optimizer, - "trainer": self.trainer, - "lr_scheduler": self.lr_scheduler, - "disc_lr_scheduler": self.disc_lr_scheduler, - }, - keys=checkpoint_keys, + # Set up metric logging + self.metrics = {} + for metric_label, metric_conf in config.metrics.items(): + self.metrics[metric_label] = create_object_from_config( + config=metric_conf, + sync_on_compute=False, + ).to(self.device) + + MetricLogger( + metrics=self.metrics, + log_function=self.log, + prefix="val/", + ).attach(self.evaluator) + + # Set up checkpoint handlers + to_save = { + "trainer": self.trainer, + "model": self.model, + "disc_model": self.disc_model, + "optimizer": self.optimizer, + "disc_optimizer": self.disc_optimizer, + "scheduler": self.lr_scheduler, + "disc_scheduler": self.disc_lr_scheduler, + } + to_unwrap = ["model", "disc_model"] + output_folder = f"checkpoints/{self.run_name}" + + for ckpt_def in config.checkpoints: + score_function = None + if ckpt_def.metric is not None: + score_function = partial( + lambda metric: metric.compute().item(), + metric=self.metrics[ckpt_def.metric], + ) + + checkpoint_handler = Checkpoint( + accelerator=self.accelerator, + config=self.config, + to_save=to_save, + output_folder=output_folder, + global_step_function=lambda: self.trainer.iteration, + score_function=score_function, + score_name=ckpt_def.metric, + score_mode=ckpt_def.mode, + to_unwrap=to_unwrap, + max_saved=ckpt_def.num_saved, ) - - # Set up logging - self._setup_tracking( - mode=logging, - wandb_id=wandb_id, - tags=tags, - group=group, - ) - - if self.accelerator.is_main_process: - ProgressBar(desc="Train", ncols=80).attach(self.trainer) - ProgressBar(desc="Val", ncols=80).attach(self.evaluator) - - def log_losses(): - fns = (self.loss_fn, self.disc_loss_fn) - prefixes = ("loss", "disc_loss") - for prefix, fn in zip(prefixes, fns): - labels = [f"{prefix}/{label}" for label in fn.labels] - losses = dict(zip(labels, fn.last_values)) - self.log(losses) - self.trainer.add_event_handler( - Events.ITERATION_COMPLETED, log_losses + event=parse_log_interval(ckpt_def.interval), + function=checkpoint_handler, ) - self.logger = AccelerateLogger(self.accelerator) - self.logger.attach_output_handler( - engine=self.trainer, - event_name=Events.ITERATION_COMPLETED, - tag="train", - output_transform=lambda losses: { - "loss": losses[0], - "disc_loss": losses[1], - }, - ) - self.logger.attach_opt_params_handler( - engine=self.trainer, - event_name=Events.ITERATION_COMPLETED, - optimizer=self.optimizer, - tag="optimizer", - param_name="lr", - ) - self.logger.attach_opt_params_handler( - engine=self.trainer, - event_name=Events.ITERATION_COMPLETED, - optimizer=self.disc_optimizer, - tag="disc_optimizer", - param_name="lr", - ) - self.logger.attach_output_handler( - engine=self.evaluator, - event_name=Events.COMPLETED, - tag="val", - metric_names="all", - global_step_transform=global_step_from_engine( - self.trainer, Events.ITERATION_COMPLETED - ), + # Load checkpoint + if checkpoint is not None: + 
self._load_checkpoint( + path=checkpoint, + to_load=to_save, + to_unwrap=to_unwrap, + keys=checkpoint_keys, ) def run(self) -> None: + """Run pipeline.""" try: self.trainer.run( - data=self.loaders["train"], + loader=self.loaders["train"], max_epochs=self.config.max_epochs, ) except KeyboardInterrupt: diff --git a/frogbox/pipelines/logger.py b/frogbox/pipelines/logger.py deleted file mode 100644 index 9f63264..0000000 --- a/frogbox/pipelines/logger.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Any, Callable, List, Optional, Union -from torch.optim import Optimizer -from ignite.engine import Engine, Events -from ignite.handlers.base_logger import ( - BaseLogger, - BaseOptimizerParamsHandler, - BaseOutputHandler, -) -from accelerate import Accelerator - - -class AccelerateLogger(BaseLogger): - def __init__(self, accelerator: Accelerator): - self._accelerator = accelerator - - def _create_output_handler( - self, *args: Any, **kwargs: Any - ) -> "OutputHandler": - return OutputHandler(*args, **kwargs) - - def _create_opt_params_handler( - self, *args: Any, **kwargs: Any - ) -> "OptimizerParamsHandler": - return OptimizerParamsHandler(*args, **kwargs) - - -class OutputHandler(BaseOutputHandler): - def __init__( - self, - tag: str, - metric_names: Optional[List[str]] = None, - output_transform: Optional[Callable] = None, - global_step_transform: Optional[ - Callable[[Engine, Union[str, Events]], int] - ] = None, - state_attributes: Optional[List[str]] = None, - ): - super().__init__( - tag, - metric_names, - output_transform, - global_step_transform, - state_attributes, - ) - - def __call__( - self, - engine: Engine, - logger: AccelerateLogger, - event_name: Union[str, Events], - ) -> None: - if not isinstance(logger, AccelerateLogger): - raise RuntimeError( - f"Handler '{self.__class__.__name__}'" - " works only with AccelerateLogger." - ) - - global_step = self.global_step_transform(engine, event_name) - if not isinstance(global_step, int): - raise TypeError( - f"global_step must be int, got {type(global_step)}." - " Please check the output of global_step_transform." - ) - - metrics = self._setup_output_metrics_state_attrs( - engine, log_text=True, key_tuple=False - ) - logger._accelerator.log(metrics, step=global_step) - - -class OptimizerParamsHandler(BaseOptimizerParamsHandler): - def __init__( - self, - optimizer: Optimizer, - param_name: str = "lr", - tag: Optional[str] = None, - ): - super(OptimizerParamsHandler, self).__init__( - optimizer, param_name, tag - ) - - def __call__( - self, - engine: Engine, - logger: AccelerateLogger, - event_name: Union[str, Events], - ) -> None: - if not isinstance(logger, AccelerateLogger): - raise RuntimeError( - "Handler OptimizerParamsHandler works" - " only with AccelerateLogger." 
- ) - - global_step = engine.state.get_event_attrib_value(event_name) - tag_prefix = f"{self.tag}/" if self.tag else "" - params = { - f"{tag_prefix}{self.param_name}/group_{i}": float( - param_group[self.param_name] - ) - for i, param_group in enumerate(self.optimizer.param_groups) - } - logger._accelerator.log(params, step=global_step) diff --git a/frogbox/pipelines/common.py b/frogbox/pipelines/lr_scheduler.py similarity index 65% rename from frogbox/pipelines/common.py rename to frogbox/pipelines/lr_scheduler.py index ff61270..ec3ad6a 100644 --- a/frogbox/pipelines/common.py +++ b/frogbox/pipelines/lr_scheduler.py @@ -1,32 +1,5 @@ -from typing import Dict, Union import torch -from ..config import ( - LossDefinition, - SchedulerType, - LRSchedulerDefinition, - create_object_from_config, -) -from .composite_loss import CompositeLoss - - -def create_composite_loss( - config: Dict[str, LossDefinition], - device: Union[str, torch.device], -) -> CompositeLoss: - loss_labels = [] - loss_modules = [] - loss_weights = [] - for loss_label, loss_conf in config.items(): - loss_labels.append(loss_label) - loss_modules.append(create_object_from_config(loss_conf)) - loss_weights.append(loss_conf.weight) - - loss_fn = CompositeLoss( - labels=loss_labels, - losses=loss_modules, - weights=loss_weights, - ).to(torch.device(device)) - return loss_fn +from ..config import LRSchedulerDefinition, SchedulerType def create_lr_scheduler( diff --git a/frogbox/pipelines/pipeline.py b/frogbox/pipelines/pipeline.py index 7edb5cd..30ce860 100644 --- a/frogbox/pipelines/pipeline.py +++ b/frogbox/pipelines/pipeline.py @@ -1,34 +1,31 @@ from typing import ( - Union, - Optional, + Dict, Any, - Callable, - Sequence, + Optional, Tuple, - Dict, + Mapping, + Sequence, + Union, + Callable, ) -from abc import ABC, abstractmethod from os import PathLike -from pathlib import Path +from abc import ABC, abstractmethod import datetime -from functools import partial import torch from torch.utils.data import Dataset, DataLoader -from ignite.engine import Engine, Events, CallableEventWithFilter -from ignite.handlers import global_step_from_engine, Checkpoint -from ignite.handlers.checkpoint import BaseSaveHandler -from ignite.handlers.base_logger import BaseLogger from accelerate import Accelerator -from accelerate.utils.other import is_compiled_module -from .save_handler import NoneSaveHandler, AccelerateDiskSaver -from .name_generation import generate_name +from ..engines.engine import Engine +from ..engines.events import MatchableEvent from ..config import ( Config, - CheckpointMode, ObjectDefinition, - create_object_from_config, + LossDefinition, parse_log_interval, + create_object_from_config, ) +from ..handlers.checkpoint import Checkpoint +from .composite_loss import CompositeLoss +from .name_generation import generate_name class Pipeline(ABC): @@ -37,25 +34,15 @@ class Pipeline(ABC): config: Config accelerator: Accelerator trainer: Engine - evaluator: Engine - logger: BaseLogger - run_name: str - @abstractmethod - def run(self) -> None: ... 
- - def _generate_name(self) -> None: - suffix = generate_name() - now = datetime.datetime.now(datetime.timezone.utc) - timestamp = now.strftime("%Y%m%d-%H%M") - self.run_name = f"{timestamp}-{suffix}" + _run_name: Optional[str] = None def _create_data_loaders( self, batch_size: int, loader_workers: int, - datasets: Dict[str, ObjectDefinition], - loaders: Optional[Dict[str, ObjectDefinition]] = None, + datasets: Mapping[str, ObjectDefinition], + loaders: Optional[Mapping[str, ObjectDefinition]] = None, ) -> Tuple[Dict[str, Dataset], Dict[str, DataLoader]]: if loaders is None: loaders = {} @@ -85,100 +72,25 @@ def _create_data_loaders( return out_datasets, out_loaders - def _setup_checkpoints( + def _create_composite_loss( self, - to_save: Dict[str, Any], - checkpoint_dir: Union[str, PathLike], - to_unwrap: Optional[Sequence[str]] = None, - ) -> None: - run_dir = Path(checkpoint_dir) / self.run_name - - save_handler: BaseSaveHandler - if self.accelerator.is_main_process: - save_handler = AccelerateDiskSaver( - dirname=str(run_dir), - accelerator=self.accelerator, - ) - - config_json = self.config.model_dump_json( - indent=True, exclude_none=True - ) - run_dir.mkdir(parents=True, exist_ok=True) - with (run_dir / "config.json").open("w") as fp: - fp.write(config_json) - else: - save_handler = NoneSaveHandler() - - def evaluator_score_fn( - engine: Engine, - evaluator: Engine, - metric: str, - sign: float, - ): - return sign * evaluator.state.metrics[metric] - - def unwrap_to_save_fn(): - output = {} - for k, v in to_save.items(): - if k in to_unwrap: - v = self.accelerator.unwrap_model(v) - if is_compiled_module(v): - v = v._orig_mod - output[k] = v - return output - - def handler_fn(engine: Engine, handler: Checkpoint): - handler.to_save = unwrap_to_save_fn() - return handler(engine) - - for checkpoint in self.config.checkpoints: - score_function = None - if checkpoint.metric: - score_sign = ( - 1.0 if checkpoint.mode == CheckpointMode.MAX else -1.0 - ) - score_function = partial( - evaluator_score_fn, - evaluator=self.evaluator, - metric=checkpoint.metric, - sign=score_sign, - ) - - log_interval = parse_log_interval(checkpoint.interval) - handler = Checkpoint( - to_save=unwrap_to_save_fn(), - save_handler=save_handler, - score_name=checkpoint.metric, - score_function=score_function, - n_saved=checkpoint.n_saved, - global_step_transform=global_step_from_engine( - self.trainer, - Events(log_interval.value), - ), - ) - - self.trainer.add_event_handler( - log_interval, partial(handler_fn, handler=handler) - ) - - def _load_checkpoint( - self, - path: Union[str, PathLike], - to_load: Dict[str, Any], - keys: Optional[Sequence[str]] = None, - ) -> None: - if keys is None: - keys = list(to_load.keys()) - to_load = {k: to_load[k] for k in keys if k in to_load} - - Checkpoint.load_objects( - to_load=to_load, - checkpoint=torch.load( - str(path), map_location="cpu", weights_only=True - ), + config: Mapping[str, LossDefinition], + ) -> CompositeLoss: + loss_labels = [] + loss_modules = [] + loss_weights = [] + for loss_label, loss_conf in config.items(): + loss_labels.append(loss_label) + loss_modules.append(create_object_from_config(loss_conf)) + loss_weights.append(loss_conf.weight) + + loss_fn = CompositeLoss( + labels=loss_labels, + losses=loss_modules, + weights=loss_weights, ) - - self.accelerator.wait_for_everyone() + loss_fn = loss_fn.to(self.device) + return loss_fn def _setup_tracking( self, @@ -188,7 +100,7 @@ def _setup_tracking( group: Optional[str] = None, ) -> None: 
self.accelerator.init_trackers( - self.config.project, + project_name=self.config.project, config=self.config.model_dump(), init_kwargs={ "wandb": { @@ -201,34 +113,79 @@ def _setup_tracking( }, ) + def _load_checkpoint( + self, + path: Union[str, PathLike], + to_load: Mapping[str, Any], + to_unwrap: Optional[Sequence[str]] = None, + keys: Optional[Sequence[str]] = None, + ) -> None: + """ + Load checkpoint from file. + + Attributes + ---------- + path : path-like + Path to checkpoint file. + to_load : mapping + Mapping with objects to load. + to_unwrap : list of str + Keys for objects to unwrap before loading. + keys : list of str (optional) + List of keys to filter. + """ + if keys is None: + keys = list(to_load.keys()) + to_load = {k: to_load[k] for k in keys} + + Checkpoint.load_checkpoint( + accelerator=self.accelerator, + path=path, + to_load=to_load, + to_unwrap=to_unwrap, + ) + self.accelerator.wait_for_everyone() + def install_callback( self, - event: CallableEventWithFilter, + event: MatchableEvent, callback: Callable[["Pipeline"], None], + engine: str = "trainer", only_main_process: bool = False, + **kwargs, ) -> None: """ Install callback in pipeline. Parameters ---------- - event : Events - Event to trigger callback at. + event : MatchableEvent + Event to trigger callback. callback : callable - Callback to install. + Callback function. + Should take a single argument `pipeline` and return nothing. + engine : str + Which engine to install callback in. Defaults to "trainer". only_main_process : bool Install only in main process. Only affects distributed setups. + kwargs : keyword arguments + Optional keyword arguments to be passed to callback. """ if not only_main_process or self.accelerator.is_main_process: - self.trainer.add_event_handler( - event_name=event, - handler=callback, - pipeline=self, - ) + target = getattr(self, engine) + if not isinstance(target, Engine): + raise ValueError( + f"'{engine}' is not an engine. Cannot install callback." + ) + + target.add_event_handler(event, callback, self, **kwargs) + + @abstractmethod + def run(self) -> None: ... 
def log(self, data: Dict[str, Any]) -> None: """Log data to tracker(s).""" - self.accelerator.log(data, step=self.trainer.state.iteration) + self.accelerator.log(data, step=self.trainer.iteration) def print(self, *args, **kwargs) -> None: """Drop in replacement of `print()` to only print once per server.""" @@ -260,3 +217,17 @@ def is_main_process(self) -> bool: def is_local_main_process(self) -> bool: """True for one process per server.""" return self.accelerator.is_local_main_process + + @property + def run_name(self) -> str: + """Get name of current run.""" + if self._run_name is None: + suffix = generate_name() + now = datetime.datetime.now(datetime.timezone.utc) + timestamp = now.strftime("%Y%m%d-%H%M") + self._run_name = f"{timestamp}-{suffix}" + return self._run_name + + @property + def log_interval(self) -> MatchableEvent: + return parse_log_interval(self.config.log_interval) diff --git a/frogbox/pipelines/save_handler.py b/frogbox/pipelines/save_handler.py deleted file mode 100644 index a293b7b..0000000 --- a/frogbox/pipelines/save_handler.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Union, Optional, Callable, Mapping -import os -from pathlib import Path -import tempfile -import stat -from accelerate import Accelerator -from ignite.handlers.checkpoint import BaseSaveHandler - - -class NoneSaveHandler(BaseSaveHandler): - """@private""" - - def __call__( - self, - checkpoint: Mapping, - filename: str, - metadata: Optional[Mapping] = None, - ) -> None: - pass - - def remove(self, filename: str) -> None: - pass - - -class AccelerateDiskSaver(BaseSaveHandler): - """@private""" - - def __init__( - self, - dirname: Union[str, os.PathLike], - accelerator: Accelerator, - atomic: bool = True, - **kwargs, - ): - self.dirname = Path(dirname).expanduser() - self.accelerator = accelerator - self.atomic = atomic - self.kwargs = kwargs - - if not self.dirname.exists(): - self.dirname.mkdir(parents=True) - - def __call__( - self, - checkpoint: Mapping, - filename: str, - metadata: Optional[Mapping] = None, - ) -> None: - path = self.dirname / filename - self._save_func(checkpoint, path, self.accelerator.save) - - def _save_func( - self, checkpoint: Mapping, path: Path, func: Callable - ) -> None: - if not self.atomic: - func(checkpoint, path, **self.kwargs) - else: - tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname) - tmp_file = tmp.file - tmp_name = tmp.name - try: - func(checkpoint, tmp_file, **self.kwargs) - except BaseException: - tmp.close() - os.remove(tmp_name) - raise - else: - tmp.close() - os.replace(tmp.name, path) - # append group/others read mode - os.chmod( - path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH - ) - - def remove(self, filename: str) -> None: - path = self.dirname / filename - path.unlink() diff --git a/frogbox/pipelines/supervised.py b/frogbox/pipelines/supervised.py index a8ba440..5b07ac3 100644 --- a/frogbox/pipelines/supervised.py +++ b/frogbox/pipelines/supervised.py @@ -1,48 +1,44 @@ -from typing import Any, Dict, Optional, Union, Sequence, Callable +from typing import Dict, Optional, Sequence, Union, Callable, Any from os import PathLike -from pathlib import Path +from functools import partial import torch from torch.utils.data import Dataset, DataLoader -from ignite.engine import Events, CallableEventWithFilter -from ignite.handlers import global_step_from_engine -from ignite.contrib.handlers import ProgressBar from accelerate import Accelerator +from torchmetrics import Metric from .pipeline import Pipeline -from .common import 
( - create_composite_loss, - create_lr_scheduler, -) from ..config import ( SupervisedConfig, create_object_from_config, parse_log_interval, ) -from ..engines.supervised import ( - create_supervised_trainer, - create_supervised_evaluator, -) +from ..engines.supervised import SupervisedTrainer, SupervisedEvaluator +from .lr_scheduler import create_lr_scheduler from .composite_loss import CompositeLoss -from .logger import AccelerateLogger +from ..handlers.output_logger import OutputLogger +from ..handlers.metric_logger import MetricLogger +from ..handlers.composite_loss_logger import CompositeLossLogger +from ..handlers.checkpoint import Checkpoint class SupervisedPipeline(Pipeline): """Supervised pipeline.""" config: SupervisedConfig - log_interval: CallableEventWithFilter + trainer: SupervisedTrainer + evaluator: SupervisedEvaluator datasets: Dict[str, Dataset] loaders: Dict[str, DataLoader] model: torch.nn.Module optimizer: torch.optim.Optimizer lr_scheduler: torch.optim.lr_scheduler.LRScheduler loss_fn: CompositeLoss + metrics: Dict[str, Metric] def __init__( self, config: SupervisedConfig, checkpoint: Optional[Union[str, PathLike]] = None, checkpoint_keys: Optional[Sequence[str]] = None, - checkpoint_dir: Union[str, PathLike] = Path("checkpoints"), logging: str = "online", wandb_id: Optional[str] = None, tags: Optional[Sequence[str]] = None, @@ -53,7 +49,7 @@ def __init__( ), trainer_model_transform: Callable[[Any], Any] = lambda output: output, trainer_output_transform: Callable[ - [Any, Any, Any, torch.Tensor], Any + [Any, Any, Any, Any], Any ] = lambda x, y, y_pred, loss: loss.item(), evaluator_input_transform: Callable[[Any, Any], Any] = lambda x, y: ( x, @@ -78,8 +74,6 @@ def __init__( checkpoint_keys : list of str List of keys for objects to load from checkpoint. Defaults to all keys. - checkpoint_dir : str or path - Path to directory to store checkpoints. logging : str Logging mode. Must be either "online" or "offline". wandb_id : str @@ -88,36 +82,34 @@ def __init__( List of tags to add to the run in W&B. group : str Group to add run to in W&B. - trainer_input_transform : Callable + trainer_input_transform : callable Function that receives tensors `x` and `y` and outputs tuple of tensors `(x, y)`. - trainer_model_transform : Callable + trainer_model_transform : callable Function that receives the output from the model during training and converts it into the form as required by the loss function. - trainer_output_transform : Callable + trainer_output_transform : callable Function that receives `x`, `y`, `y_pred`, `loss` and returns value to be assigned to trainer's `state.output` after each iteration. Default is returning `loss.item()`. - evaluator_input_transform : Callable + evaluator_input_transform : callable Function that receives tensors `x` and `y` and outputs tuple of tensors `(x, y)`. - evaluator_model_transform : Callable + evaluator_model_transform : callable Function that receives the output from the model during evaluation and converts it into the predictions: `y_pred = model_transform(model(x))`. - evaluator_output_transform : Callable + evaluator_output_transform : callable Function that receives `x`, `y`, `y_pred` and returns value to be - assigned to evaluator's `state.output` after each iteration. + passed to output handlers after each iteration. Default is returning `(y_pred, y)` which fits output expected by metrics. 
""" - logging = logging.lower() - assert logging in ("online", "offline") # Parse config + logging = logging.lower() + assert logging in ("online", "offline") self.config = config - self.log_interval = parse_log_interval(config.log_interval) - self._generate_name() # Create accelerator self.accelerator = Accelerator( @@ -125,6 +117,13 @@ def __init__( log_with="wandb", ) + self._setup_tracking( + mode=logging, + wandb_id=wandb_id, + tags=tags, + group=group, + ) + # Create datasets and data loaders self.datasets, self.loaders = self._create_data_loaders( batch_size=config.batch_size, @@ -148,14 +147,6 @@ def __init__( max_iterations=max_iterations, ) - # Load model weights before potential compilation - if checkpoint: - self._load_checkpoint( - path=checkpoint, - to_load={"model": self.model}, - keys=checkpoint_keys, - ) - # Wrap with accelerator self.model, self.optimizer, self.lr_scheduler = ( self.accelerator.prepare( @@ -166,114 +157,103 @@ def __init__( self.loaders[split] = self.accelerator.prepare(self.loaders[split]) # Create trainer - self.loss_fn = create_composite_loss( - config=config.losses, - device=self.accelerator.device, - ) - self.trainer = create_supervised_trainer( + self.loss_fn = self._create_composite_loss(config.losses) + self.trainer = SupervisedTrainer( accelerator=self.accelerator, model=self.model, + loss_fn=self.loss_fn, optimizer=self.optimizer, scheduler=self.lr_scheduler, - loss_fn=self.loss_fn, clip_grad_norm=config.clip_grad_norm, clip_grad_value=config.clip_grad_value, input_transform=trainer_input_transform, model_transform=trainer_model_transform, output_transform=trainer_output_transform, + progress_label="train", ) - # Create evaluator - metrics = {} - for metric_label, metric_conf in config.metrics.items(): - metrics[metric_label] = create_object_from_config(metric_conf) + OutputLogger("train/loss", self.log).attach(self.trainer) + CompositeLossLogger(self.loss_fn, self.log, "loss/").attach( + self.trainer + ) - self.evaluator = create_supervised_evaluator( + # Create evaluator + self.evaluator = SupervisedEvaluator( accelerator=self.accelerator, model=self.model, - metrics=metrics, input_transform=evaluator_input_transform, model_transform=evaluator_model_transform, output_transform=evaluator_output_transform, + progress_label="val", ) - self.trainer.add_event_handler( - event_name=self.log_interval, - handler=lambda: self.evaluator.run(self.loaders["val"]), + event=self.log_interval, + function=lambda: self.evaluator.run(self.loaders["val"]), ) - # Set up checkpoints - self._setup_checkpoints( - to_save={ - "model": self.model, - "optimizer": self.optimizer, - "trainer": self.trainer, - "lr_scheduler": self.lr_scheduler, - }, - checkpoint_dir=checkpoint_dir, - to_unwrap=["model"], - ) + # Set up metric logging + self.metrics = {} + for metric_label, metric_conf in config.metrics.items(): + self.metrics[metric_label] = create_object_from_config( + config=metric_conf, + sync_on_compute=False, + ).to(self.device) - # Load checkpoint - if checkpoint: - self._load_checkpoint( - path=checkpoint, - to_load={ - "optimizer": self.optimizer, - "trainer": self.trainer, - "lr_scheduler": self.lr_scheduler, - }, - keys=checkpoint_keys, - ) + MetricLogger( + metrics=self.metrics, + log_function=self.log, + prefix="val/", + ).attach(self.evaluator) - # Set up logging - self._setup_tracking( - mode=logging, - wandb_id=wandb_id, - tags=tags, - group=group, - ) + # Set up checkpoint handlers + to_save = { + "trainer": self.trainer, + "model": self.model, + 
"optimizer": self.optimizer, + "scheduler": self.lr_scheduler, + } + to_unwrap = ["model"] + output_folder = f"checkpoints/{self.run_name}" - if self.accelerator.is_main_process: - ProgressBar(desc="Train", ncols=80).attach(self.trainer) - ProgressBar(desc="Val", ncols=80).attach(self.evaluator) - - def log_losses(): - labels = [f"loss/{label}" for label in self.loss_fn.labels] - losses = dict(zip(labels, self.loss_fn.last_values)) - self.log(losses) + for ckpt_def in config.checkpoints: + score_function = None + if ckpt_def.metric is not None: + score_function = partial( + lambda metric: metric.compute().item(), + metric=self.metrics[ckpt_def.metric], + ) + checkpoint_handler = Checkpoint( + accelerator=self.accelerator, + config=self.config, + to_save=to_save, + output_folder=output_folder, + global_step_function=lambda: self.trainer.iteration, + score_function=score_function, + score_name=ckpt_def.metric, + score_mode=ckpt_def.mode, + to_unwrap=to_unwrap, + max_saved=ckpt_def.num_saved, + ) self.trainer.add_event_handler( - Events.ITERATION_COMPLETED, log_losses + event=parse_log_interval(ckpt_def.interval), + function=checkpoint_handler, ) - self.logger = AccelerateLogger(self.accelerator) - self.logger.attach_output_handler( - engine=self.trainer, - event_name=Events.ITERATION_COMPLETED, - tag="train", - output_transform=lambda loss: {"loss": loss}, - ) - self.logger.attach_opt_params_handler( - engine=self.trainer, - event_name=Events.ITERATION_COMPLETED, - optimizer=self.optimizer, - param_name="lr", - ) - self.logger.attach_output_handler( - engine=self.evaluator, - event_name=Events.COMPLETED, - tag="val", - metric_names="all", - global_step_transform=global_step_from_engine( - self.trainer, Events.ITERATION_COMPLETED - ), + # Load checkpoint + if checkpoint is not None: + self._load_checkpoint( + path=checkpoint, + to_load=to_save, + to_unwrap=to_unwrap, + keys=checkpoint_keys, ) def run(self) -> None: + """Run pipeline.""" try: self.trainer.run( - data=self.loaders["train"], + loader=self.loaders["train"], max_epochs=self.config.max_epochs, ) except KeyboardInterrupt: diff --git a/frogbox/service.py b/frogbox/service.py deleted file mode 100644 index 8d879d7..0000000 --- a/frogbox/service.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Dict, Type -from contextlib import asynccontextmanager -from fastapi import FastAPI -from pydantic import BaseModel -from pydantic_settings import BaseSettings -import torch -from .config import Config -from .utils import load_model_checkpoint - - -class BaseServiceSettings(BaseSettings): - checkpoints: Dict[str, str] = {} - device: str = "cpu" - - -class HealthResponse(BaseModel): - status: str = "OK" - - -class BaseService(FastAPI): - configs: Dict[str, Config] - models: Dict[str, torch.nn.Module] - device: torch.device - - def __init__( - self, - request_class: Type[BaseModel], - response_class: Type[BaseModel], - ): - self.settings = BaseServiceSettings() - - super().__init__(lifespan=self._lifespan) - - @self.post("/inference") - async def do_inference( - request: request_class, # type: ignore - ) -> response_class: # type: ignore - """Run model inference.""" - if not hasattr(self, "inference"): - raise RuntimeError("inference method not implemented.") - return self.inference(request) - - @self.get("/health") - async def get_health() -> HealthResponse: - """Get service health.""" - return HealthResponse(status="OK") - - @asynccontextmanager - async def _lifespan(self, app: FastAPI): - self.on_startup() - yield - self.on_shutdown() - - 
def on_startup(self): - self.configs = {} - self.models = {} - self.device = torch.device(self.settings.device) - - for name, path in self.settings.checkpoints.items(): - model, config = load_model_checkpoint(path) - - model = model.eval().to(self.device) - self.models[name] = model - self.configs[name] = config - - def on_shutdown(self): - del self.models - torch.cuda.empty_cache() diff --git a/frogbox/utils.py b/frogbox/utils.py index 863056f..ccf7ce8 100644 --- a/frogbox/utils.py +++ b/frogbox/utils.py @@ -1,40 +1,20 @@ -""" -# Loading a trained model - -Trained models can be loaded with `frogbox.utils.load_model_checkpoint`. -The function returns the trained model as well the trainer configuration. - -```python -import torch -from frogbox.utils import load_model_checkpoint - -device = torch.device("cuda:0") - -model, config = load_model_checkpoint( - "checkpoints/smooth-jazz-123/best_checkpoint_1_PSNR=26.6363.pt" -) -model = model.eval().to(device) - -x = torch.rand((1, 3, 16, 16), device=device) -with torch.inference_mode(): - pred = model(x) -``` -""" - -from typing import Any, Union, Tuple, cast +from typing import cast, Union, Tuple, Any, Optional from os import PathLike from pathlib import Path -import torch from .config import ( read_json_config, - create_object_from_config, Config, SupervisedConfig, GANConfig, + create_object_from_config, ) +import torch -def load_model_checkpoint(path: Union[str, PathLike]) -> Tuple[Any, Config]: +def load_model_checkpoint( + path: Union[str, PathLike], + config_path: Optional[Union[str, PathLike]] = None, +) -> Tuple[Any, Config]: """ Load model from checkpoint. @@ -42,6 +22,9 @@ def load_model_checkpoint(path: Union[str, PathLike]) -> Tuple[Any, Config]: ---------- path : path-like Path to checkpoint file. + config_path : path-like + Path to config file. If empty config will be read from "config.json" + in the same folder as `path`. Returns ------- @@ -49,8 +32,8 @@ def load_model_checkpoint(path: Union[str, PathLike]) -> Tuple[Any, Config]: Model checkpoint and config. """ path = Path(path) - - config_path = path.parent / "config.json" + if config_path is None: + config_path = path.parent / "config.json" base_config = read_json_config(config_path) ckpt = torch.load(path, map_location="cpu", weights_only=True) diff --git a/mypy.ini b/mypy.ini index d662f52..a0f7a15 100644 --- a/mypy.ini +++ b/mypy.ini @@ -9,6 +9,3 @@ ignore_missing_imports = True [mypy-accelerate.*] ignore_missing_imports = True - -[mypy-wandb.*] -ignore_missing_imports = True \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 6d9f1ed..0b290db 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ -flake8>=6.1.0 -mypy>=1.6.0 -pdoc>=14.5.0 -pytest>=7.4.2 -black==24.4.2 +flake8>=7.1.1 +black>=24.10.0 +mypy>=1.13.0 +pdoc>=15.0.0 +pytest>=8.3.4 diff --git a/requirements.txt b/requirements.txt index 2a4e9d2..86c939c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,9 @@ -torch>=2.0.0 -torchvision>=0.15.0 -pytorch-ignite>=0.4.12 -kornia>=0.7.0 -wandb>=0.18.3 -tqdm>=4.66.1 +torch>=2.5.1 +torchvision>=0.20.1 +accelerate>=1.2.0 +torchmetrics>=1.6.0 +pydantic>=2.10.3 +tqdm>=4.67.1 +wandb>=0.18.7 +jinja2>=3.1.4 click>=8.1.7 -pydantic>=2.4.2 -pydantic-settings>=2.1.0 -fastapi>=0.109.0 -uvicorn>=0.27.0 -accelerate==1.0.0rc1
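
As a quick sanity check of the new event system, the sketch below wires a callback into a `SupervisedPipeline` using `Event` and `EventStep` from `frogbox.engines.events` together with `Pipeline.install_callback`. The config path and the learning-rate callback are illustrative assumptions (any JSON describing a supervised experiment would do); only the imported classes and the `install_callback`, `log`, and `run` calls come from the code added in this patch.

```python
from frogbox.config import read_json_config
from frogbox.engines.events import Event, EventStep
from frogbox.pipelines.supervised import SupervisedPipeline


def log_learning_rate(pipeline: SupervisedPipeline) -> None:
    # Callbacks receive the pipeline; log() writes to the configured tracker
    # at the trainer's current iteration.
    lr = pipeline.optimizer.param_groups[0]["lr"]
    pipeline.log({"optimizer/lr": lr})


# Hypothetical experiment config; assumes the JSON describes a supervised run.
config = read_json_config("configs/example.json")
pipeline = SupervisedPipeline(config=config, logging="offline")

# Fire every 100 training iterations and once more when training finishes.
when = Event(EventStep.ITERATION_COMPLETED, every=100) | Event(EventStep.COMPLETED)
pipeline.install_callback(event=when, callback=log_learning_rate, only_main_process=True)

pipeline.run()
```

Because `Event.__or__` returns an `EventList`, several trigger conditions can be combined into a single handler registration, and `install_callback` accepts any `MatchableEvent`, so the same pattern works for handlers on the evaluator by passing `engine="evaluator"`.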