Skip to content

Commit

Permalink
Replace Ignite with custom event system (#5)
Browse files Browse the repository at this point in the history
* Remove Torch Ignite, replace with own engine/event mechanism

* Add GAN pipeline and example project

* Update README.md
  • Loading branch information
SimonLarsen authored Dec 16, 2024
1 parent 92dac75 commit 329abd9
Show file tree
Hide file tree
Showing 45 changed files with 1,390 additions and 1,949 deletions.
25 changes: 9 additions & 16 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ ipython_config.py
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
Expand Down Expand Up @@ -159,22 +161,13 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# pdoc
docs/*.html
docs/*.js
docs/frogbox/

# Virtual environment symlink
venv

# Training data
wandb/
checkpoints/

# Example projects
examples/supervised/wandb/
examples/supervised/checkpoints/
examples/supervised/data/
examples/gan/wandb/
examples/gan/checkpoints/
examples/gan/data/
examples/supervised/data
examples/supervised/checkpoints
examples/supervised/wandb
examples/gan/data
examples/gan/checkpoints
examples/gan/wandb
4 changes: 1 addition & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
include frogbox/data/*.json
include frogbox/data/*.py
include frogbox/data/Dockerfile
include frogbox/data/*.py
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
frogbox
</h1>

Frogbox is an opinionated machine learning framework for PyTorch built for rapid prototyping and research.
Frogbox is an opinionated PyTorch machine learning framework built for rapid prototyping and research.

## Features

* Built around [Torch Ignite](https://pytorch-ignite.ai) and [Accelerate](https://huggingface.co/docs/accelerate/index) to support automatic mixed precision (AMP) and distributed training.
* Experiments are defined using JSON files and support [jinja2](https://jinja.palletsprojects.com) templates.
* Flexible event system inspired by [Ignite](https://pytorch.org/ignite).
* Automatic experiment tracking. Currently only [Weights & Biases](https://wandb.ai/) is supported with other platforms planned.
* CLI tool for easy project management. Just type `frogbox project new -t supervised` to get started.
* Integrates [Accelerate](https://huggingface.co/docs/accelerate/index) to support automatic mixed precision (AMP) and distributed training.

## Installation

```sh
pip install git+https://[email protected]/SimonLarsen/frogbox.git@v0.3.3
pip install git+https://[email protected]/SimonLarsen/frogbox.git@v0.5.0
```

## Getting started
Expand Down
11 changes: 4 additions & 7 deletions examples/gan/configs/example.json
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
{
"type": "gan",
"project": "frogbox-example",
"amp": true,
"batch_size": 32,
"loader_workers": 4,
"max_epochs": 16,
"log_interval": "epoch_completed",
"checkpoints": [
{
"metric": "PSNR",
"mode": "max"
}
],
"log_interval": "epoch_completed",
"datasets": {% include 'datasets.json' %},
"disc_update_interval": 2,
"model": {
"class_name": "models.upscaler.Upscaler",
"params": {
Expand All @@ -30,8 +28,7 @@
"params": {
"in_channels": 3,
"hidden_channels": 32,
"num_blocks": 8,
"activation": "silu"
"num_blocks": 8
}
},
"losses": {
Expand All @@ -52,11 +49,11 @@
},
"metrics": {
"SSIM": {
"class_name": "ignite.metrics.SSIM",
"class_name": "torchmetrics.image.StructuralSimilarityIndexMeasure",
"params": {"data_range": 1.0}
},
"PSNR": {
"class_name": "ignite.metrics.PSNR",
"class_name": "torchmetrics.image.PeakSignalNoiseRatio",
"params": {"data_range": 1.0}
}
},
Expand Down
8 changes: 4 additions & 4 deletions examples/gan/datasets/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
image, _ = self.data[idx]
image = pil_to_tensor(image) / 255

X = resize(image, size=16, antialias=True)
x = resize(image, size=16, antialias=True)
y = image

if self.do_augment:
X = self.augment(X)
x = self.augment(x)

if self.do_normalize:
X = self.normalize(X)
x = self.normalize(x)

return X, y
return x, y
49 changes: 0 additions & 49 deletions examples/gan/models/blocks.py

This file was deleted.

12 changes: 3 additions & 9 deletions examples/gan/models/discriminator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from torch import nn
from torch.nn.utils import spectral_norm
from .blocks import get_activation


class SNResidualBlock(nn.Module):
def __init__(
self,
channels: int,
activation: str = "relu",
):
def __init__(self, channels: int):
super().__init__()

self.act = get_activation(activation)
self.act = nn.GELU()
self.conv1 = spectral_norm(
nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
)
Expand All @@ -37,15 +32,14 @@ def __init__(
in_channels: int = 3,
hidden_channels: int = 32,
num_blocks: int = 8,
activation: str = "silu",
):
super().__init__()

self.conv_in = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)

blocks = []
for _ in range(num_blocks):
blocks.append(SNResidualBlock(hidden_channels, activation))
blocks.append(SNResidualBlock(hidden_channels))
self.blocks = nn.Sequential(*blocks)

self.conv_out = nn.Conv2d(hidden_channels, 1, 3, 1, 1)
Expand Down
57 changes: 43 additions & 14 deletions examples/gan/models/upscaler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import math
from torch import nn
from .blocks import get_activation, ResidualBlock


class ResidualBlock(nn.Module):
def __init__(
self,
channels: int,
norm_groups: int = 4,
norm_eps: float = 1e-5,
):
super().__init__()

self.act = nn.GELU()

self.norm1 = nn.GroupNorm(norm_groups, channels, norm_eps)
self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)

self.norm2 = nn.GroupNorm(norm_groups, channels, norm_eps)
self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)

def forward(self, x):
identity = x

x = self.norm1(x)
x = self.act(x)
x = self.conv1(x)

x = self.norm2(x)
x = self.act(x)
x = self.conv2(x)

return x + identity


class Upscaler(nn.Module):
Expand All @@ -12,42 +42,41 @@ def __init__(
hidden_channels: int = 32,
num_layers: int = 4,
norm_groups: int = 4,
activation: str = "gelu",
):
super().__init__()

self.conv_in = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)

features = []
self.blocks = nn.ModuleList()
for _ in range(num_layers):
features.append(
self.blocks.append(
ResidualBlock(
channels=hidden_channels,
norm_groups=norm_groups,
activation=activation,
)
)
features.append(nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1))
self.features = nn.Sequential(*features)

upsample = []
for _ in range(int(math.log2(scale_factor))):
upsample.append(nn.Upsample(scale_factor=2, mode="nearest"))
upsample.append(
nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1)
)
upsample.append(get_activation(activation))
upsample.append(nn.GELU())
self.upsample = nn.Sequential(*upsample)

self.conv_out = nn.Sequential(
nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1),
get_activation(activation),
nn.GELU(),
nn.Conv2d(hidden_channels, out_channels, 3, 1, 1),
)

def forward(self, x):
x = self.conv_in(x)
x = self.features(x)
x = self.upsample(x)
x = self.conv_out(x)
return x
h = self.conv_in(x)
for block in self.blocks:
h = block(h) + h
h = self.upsample(h)
h = self.conv_out(h)

x = nn.functional.interpolate(x, h.shape[-2:], mode="bilinear")
return nn.functional.sigmoid(x + h)
24 changes: 10 additions & 14 deletions examples/gan/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Sequence
from typing import cast, Optional, Sequence
from pathlib import Path
import argparse
from frogbox import read_json_config, GANPipeline, Events
from frogbox.callbacks import create_image_logger
from frogbox import read_json_config, GANPipeline, GANConfig
from frogbox.callbacks import ImageLogger


def parse_arguments(
Expand All @@ -28,7 +28,7 @@ def parse_arguments(

if __name__ == "__main__":
args = parse_arguments()
config = read_json_config(args.config)
config = cast(GANConfig, read_json_config(args.config))

pipeline = GANPipeline(
config=config,
Expand All @@ -40,16 +40,12 @@ def parse_arguments(
group=args.group,
)

dataset_params = config.datasets["test"].params
image_logger = create_image_logger(
split="test",
normalize_mean=dataset_params["normalize_mean"],
normalize_std=dataset_params["normalize_std"],
denormalize_input=dataset_params["do_normalize"],
)
pipeline.install_callback(
event=Events.EPOCH_COMPLETED,
callback=image_logger,
ds_conf = config.datasets["train"].params
image_logger = ImageLogger(
denormalize_input=True,
normalize_mean=ds_conf["normalize_mean"],
normalize_std=ds_conf["normalize_std"],
)
pipeline.install_callback(pipeline.log_interval, image_logger)

pipeline.run()
10 changes: 5 additions & 5 deletions examples/supervised/configs/example.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"type": "supervised",
"project": "frogbox-example",
"batch_size": 16,
"batch_size": 32,
"loader_workers": 4,
"max_epochs": 16,
"log_interval": "epoch_completed",
Expand All @@ -10,7 +10,7 @@
"n_saved": 3,
"interval": {
"event": "iteration_completed",
"every": 500
"every": 1000
}
},
{
Expand All @@ -21,7 +21,7 @@
],
"datasets": {% include 'datasets.json' %},
"model": {
"class_name": "models.example.ExampleModel",
"class_name": "models.upscaler.Upscaler",
"params": {
"scale_factor": 2,
"in_channels": 3,
Expand All @@ -39,11 +39,11 @@
},
"metrics": {
"SSIM": {
"class_name": "ignite.metrics.SSIM",
"class_name": "torchmetrics.image.StructuralSimilarityIndexMeasure",
"params": {"data_range": 1.0}
},
"PSNR": {
"class_name": "ignite.metrics.PSNR",
"class_name": "torchmetrics.image.PeakSignalNoiseRatio",
"params": {"data_range": 1.0}
}
},
Expand Down
Loading

0 comments on commit 329abd9

Please sign in to comment.