Skip to content

Commit

Permalink
Refactor layer normalization parameters for consistency and clarity i…
Browse files Browse the repository at this point in the history
…n Restormer model and update assert in forward layer to support 3D images
  • Loading branch information
phisanti committed Feb 7, 2025
1 parent f520e99 commit 39d1edf
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions monai/networks/nets/restormer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ def __init__(
num_heads: int,
ffn_expansion_factor: float,
bias: bool,
layer_norm_type: str = "BiasFree",
layer_norm_use_bias: bool = False,
flash_attention: bool = False,
):
super().__init__()
use_bias = layer_norm_type != "BiasFree"
self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)
self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -53,11 +52,11 @@ class OverlapPatchEmbed(nn.Module):
Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""

def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__()
self.proj = Convolution(
spatial_dims=spatial_dims,
in_channels=in_c,
in_channels=in_channels,
out_channels=embed_dim,
kernel_size=3,
strides=1,
Expand Down Expand Up @@ -89,31 +88,31 @@ class Restormer(nn.Module):
def __init__(
self,
spatial_dims: int = 2,
inp_channels: int = 3,
in_channels: int = 3,
out_channels: int = 3,
dim: int = 48,
num_blocks: tuple[int, ...] = (1, 1, 1, 1),
heads: tuple[int, ...] = (1, 1, 1, 1),
num_refinement_blocks: int = 4,
ffn_expansion_factor: float = 2.66,
bias: bool = False,
layer_norm_type: str = "WithBias",
layer_norm_use_bias: str = True,
dual_pixel_task: bool = False,
flash_attention: bool = False,
) -> None:
super().__init__()
"""Initialize Restormer model.
Args:
inp_channels: Number of input image channels
in_channels: Number of input image channels
out_channels: Number of output image channels
dim: Base feature dimension
num_blocks: Number of transformer blocks at each scale
num_refinement_blocks: Number of final refinement blocks
heads: Number of attention heads at each scale
ffn_expansion_factor: Expansion factor for feed-forward network
bias: Whether to use bias in convolutions
layer_norm_type: Type of normalization ('WithBias' or 'BiasFree')
layer_norm_use_bias: Whether to use bias in layer normalization. Default is True.
dual_pixel_task: Enable dual-pixel specific processing
flash_attention: Use flash attention if available
"""
Expand All @@ -123,7 +122,7 @@ def __init__(
assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"

# Initial feature extraction
self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim)
self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim)
self.encoder_levels = nn.ModuleList()
self.downsamples = nn.ModuleList()
self.decoder_levels = nn.ModuleList()
Expand All @@ -147,7 +146,7 @@ def __init__(
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
layer_norm_type=layer_norm_type,
layer_norm_use_bias=layer_norm_use_bias,
flash_attention=flash_attention,
)
for _ in range(num_blocks[n])
Expand Down Expand Up @@ -176,7 +175,7 @@ def __init__(
num_heads=heads[num_steps],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
layer_norm_type=layer_norm_type,
layer_norm_use_bias=layer_norm_use_bias,
flash_attention=flash_attention,
)
for _ in range(num_blocks[num_steps])
Expand Down Expand Up @@ -224,7 +223,7 @@ def __init__(
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
layer_norm_type=layer_norm_type,
layer_norm_use_bias=layer_norm_use_bias,
flash_attention=flash_attention,
)
for _ in range(num_blocks[n])
Expand All @@ -241,7 +240,7 @@ def __init__(
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
layer_norm_type=layer_norm_type,
layer_norm_use_bias=layer_norm_use_bias,
flash_attention=flash_attention,
)
for _ in range(num_refinement_blocks)
Expand Down Expand Up @@ -272,14 +271,14 @@ def forward(self, x) -> torch.Tensor:
"""Forward pass of Restormer.
Processes input through encoder-decoder architecture with skip connections.
Args:
inp_img: Input image tensor of shape (B, C, H, W)
inp_img: Input image tensor of shape (B, C, H, W, [D])
Returns:
Restored image tensor of shape (B, C, H, W)
Restored image tensor of shape (B, C, H, W, [D])
"""
assert (
x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2**self.num_steps
), "Input dimensions should be larger than 2^number_of_step"
assert all(
x.shape[-i] > 2 ** self.num_steps for i in range(1, self.spatial_dims + 1)
), "All spatial dimensions should be larger than 2^number_of_step"

# Patch embedding
x = self.patch_embed(x)
Expand Down

0 comments on commit 39d1edf

Please sign in to comment.