Skip to content

Commit

Permalink
Clarify input tensor shape in pixelshuffle and pixelunshuffle functio…
Browse files Browse the repository at this point in the history
…ns and simplify ValueError message in pixelunshuffle
  • Loading branch information
phisanti committed Feb 7, 2025
1 parent 61efefb commit 091887b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
Args:
x: Input tensor
x: Input tensor with shape BCHW[D]
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
scale_factor: factor to rescale the spatial dimensions by, must be >=1
Expand Down Expand Up @@ -423,7 +423,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
Args:
x: Input tensor
x: Input tensor with shape BCHW[D]
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
scale_factor: factor to reduce the spatial dimensions by, must be >=1
Expand All @@ -443,7 +443,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor

if any(d % factor != 0 for d in input_size[2:]):
raise ValueError(
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
)
output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])
Expand Down

0 comments on commit 091887b

Please sign in to comment.