Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 238 additions & 8 deletions monai/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,167 @@
__all__ = ["AutoencoderKL"]


def _validate_kernel_stride_parameters(
kernel_size: int | tuple[int, ...] | None,
stride: int | tuple[int, ...] | None,
spatial_dims: int,
param_name: str = "parameter",
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""
Validate and normalize kernel_size and stride parameters.

Args:
kernel_size: int or tuple of ints representing kernel size
stride: int or tuple of ints representing stride
spatial_dims: number of spatial dimensions
param_name: name of parameter for error messages

Returns:
Tuple of (normalized_kernel_size, normalized_stride)

Raises:
ValueError: if parameters are invalid
"""
if kernel_size is None or stride is None:
return None, None

# Normalize kernel_size to tuple
if isinstance(kernel_size, int):
kernel_size_tuple = (kernel_size,) * spatial_dims
else:
kernel_size_tuple = tuple(kernel_size)

# Normalize stride to tuple
if isinstance(stride, int):
stride_tuple = (stride,) * spatial_dims
else:
stride_tuple = tuple(stride)

# Validate lengths
if len(kernel_size_tuple) != spatial_dims:
raise ValueError(f"{param_name} kernel_size must have length {spatial_dims}, got {len(kernel_size_tuple)}")
if len(stride_tuple) != spatial_dims:
raise ValueError(f"{param_name} stride must have length {spatial_dims}, got {len(stride_tuple)}")

# Validate kernel sizes are odd
for i, k in enumerate(kernel_size_tuple):
if k % 2 == 0:
raise ValueError(f"{param_name} kernel_size at dimension {i} must be odd, got {k}")

# Validate all values are positive integers
for i, (k, s) in enumerate(zip(kernel_size_tuple, stride_tuple)):
if not isinstance(k, int) or k <= 0:
raise ValueError(f"{param_name} kernel_size at dimension {i} must be positive int, got {k}")
if not isinstance(s, int) or s <= 0:
raise ValueError(f"{param_name} stride at dimension {i} must be positive int, got {s}")

return kernel_size_tuple, stride_tuple


def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]:
"""
Compute symmetric padding from kernel size.

For odd kernel sizes, padding = kernel_size // 2 on all sides.

Args:
kernel_size: tuple of odd integers

Returns:
Tuple of padding values (one per dimension)
"""
padding = tuple(k // 2 for k in kernel_size)
return padding
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _normalize_downsample_parameters(
    downsample_parameters: list[dict] | dict | None,
    num_levels: int,
    spatial_dims: int,
    default_kernel_size: int = 3,
    default_stride: int = 2,
) -> list[dict]:
    """
    Normalize downsampling parameters to canonical internal representation.

    Accepts:
        - None: use defaults for all levels
        - Single dict: apply same params to all levels
        - List of dicts: one dict per level

    Each dict can specify:
        - "kernel_size": int or tuple (default used if omitted or None)
        - "stride": int or tuple (default used if omitted or None)
        - "padding": int or tuple (auto-computed from kernel_size if omitted or None)

    Args:
        downsample_parameters: user-provided parameters in one of the forms above.
        num_levels: number of downsampling levels expected.
        spatial_dims: number of spatial dimensions.
        default_kernel_size: kernel size used when a level does not specify one.
        default_stride: stride used when a level does not specify one.

    Returns:
        List of dicts with normalized keys:
            - Each dict has "kernel_size", "stride", "padding" as tuples
            - Length equals num_levels

    Raises:
        ValueError: if parameters are invalid or inconsistent
    """
    if downsample_parameters is None:
        # Default: use provided defaults for all levels
        default_ks_tuple, default_s_tuple = _validate_kernel_stride_parameters(
            default_kernel_size, default_stride, spatial_dims
        )
        default_padding = _compute_padding(default_ks_tuple)
        return [
            {"kernel_size": default_ks_tuple, "stride": default_s_tuple, "padding": default_padding}
            for _ in range(num_levels)
        ]

    # If single dict, apply to all levels
    if isinstance(downsample_parameters, dict):
        params_list = [downsample_parameters] * num_levels
    else:
        params_list = list(downsample_parameters)

    # Validate we have the right number of levels
    if len(params_list) != num_levels:
        raise ValueError(f"Expected {num_levels} downsampling parameter dicts (one per level), got {len(params_list)}")

    # Normalize each dict
    normalized = []
    for i, params in enumerate(params_list):
        if not isinstance(params, dict):
            raise ValueError(f"Downsampling parameters at level {i} must be dict, got {type(params)}")

        # An explicit None is treated like a missing key; otherwise the validator
        # would return (None, None) and the padding computation below would fail.
        kernel_size = params.get("kernel_size")
        if kernel_size is None:
            kernel_size = default_kernel_size
        stride = params.get("stride")
        if stride is None:
            stride = default_stride
        padding = params.get("padding")

        # Validate and normalize kernel_size and stride
        ks_tuple, s_tuple = _validate_kernel_stride_parameters(kernel_size, stride, spatial_dims, f"Level {i}")

        # Compute padding if not provided, otherwise normalize the provided value
        if padding is None:
            padding_tuple = _compute_padding(ks_tuple)
        elif isinstance(padding, int):
            padding_tuple = (padding,) * spatial_dims
        else:
            padding_tuple = tuple(padding)

        if len(padding_tuple) != spatial_dims:
            raise ValueError(f"Level {i} padding must have length {spatial_dims}, got {len(padding_tuple)}")

        # Reject bad padding values here so the error names the offending level
        # instead of surfacing later from the convolution layer.
        for dim, pad in enumerate(padding_tuple):
            if not isinstance(pad, int) or pad < 0:
                raise ValueError(f"Level {i} padding at dimension {dim} must be non-negative int, got {pad}")

        normalized.append({"kernel_size": ks_tuple, "stride": s_tuple, "padding": padding_tuple})

    return normalized


class AsymmetricPad(nn.Module):
"""
Pad the input tensor asymmetrically along every spatial dimension.

.. deprecated:: 0.10.0
This class is deprecated and no longer used by `AEKLDownsample`.
Use configurable kernel_size and stride parameters instead (see `AEKLDownsample`).

Args:
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
"""
Expand All @@ -49,24 +206,46 @@ class AEKLDownsample(nn.Module):
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels.
kernel_size: kernel size for the convolution. Can be int or tuple. Default: 3.
stride: stride for the convolution. Can be int or tuple. Default: 2.
padding: padding for the convolution. If None, computed from kernel_size. Default: None.
"""

def __init__(self, spatial_dims: int, in_channels: int) -> None:
def __init__(
self,
spatial_dims: int,
in_channels: int,
kernel_size: int | tuple[int, ...] = 3,
stride: int | tuple[int, ...] = 2,
padding: int | tuple[int, ...] | None = None,
) -> None:
super().__init__()
self.pad = AsymmetricPad(spatial_dims=spatial_dims)

# Validate and normalize kernel_size and stride
kernel_size_tuple, stride_tuple = _validate_kernel_stride_parameters(
kernel_size, stride, spatial_dims, "AEKLDownsample"
)

# Compute padding if not provided
if padding is None:
padding_tuple = _compute_padding(kernel_size_tuple)
else:
if isinstance(padding, int):
padding_tuple = (padding,) * spatial_dims
else:
padding_tuple = tuple(padding)

self.conv = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=2,
kernel_size=3,
padding=0,
strides=stride_tuple,
kernel_size=kernel_size_tuple,
padding=padding_tuple,
conv_only=True,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pad(x)
x = self.conv(x)
return x

Expand Down Expand Up @@ -160,6 +339,7 @@ class Encoder(nn.Module):
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
downsample_parameters: list of dicts specifying kernel_size, stride, padding for each downsampling level.
"""

def __init__(
Expand All @@ -176,6 +356,7 @@ def __init__(
include_fc: bool = True,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
downsample_parameters: list[dict] | None = None,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
Expand All @@ -187,6 +368,15 @@ def __init__(
self.norm_eps = norm_eps
self.attention_levels = attention_levels

# Normalize downsampling parameters
num_downsample_levels = len(channels) - 1
normalized_downsample_params = _normalize_downsample_parameters(
downsample_parameters, num_downsample_levels, spatial_dims
)

# Store for decoder to use
self.downsample_parameters = normalized_downsample_params

blocks: list[nn.Module] = []
# Initial convolution
blocks.append(
Expand All @@ -203,6 +393,7 @@ def __init__(

# Residual and downsampling blocks
output_channel = channels[0]
downsample_idx = 0
for i in range(len(channels)):
input_channel = output_channel
output_channel = channels[i]
Expand Down Expand Up @@ -233,7 +424,19 @@ def __init__(
)

if not is_final_block:
blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel))
# Use downsampling parameters for this level
downsample_params = normalized_downsample_params[downsample_idx]
blocks.append(
AEKLDownsample(
spatial_dims=spatial_dims,
in_channels=input_channel,
kernel_size=downsample_params["kernel_size"],
stride=downsample_params["stride"],
padding=downsample_params["padding"],
)
)
downsample_idx += 1

# Non-local attention block
if with_nonlocal_attn is True:
blocks.append(
Expand Down Expand Up @@ -307,6 +510,7 @@ class Decoder(nn.Module):
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
downsample_parameters: list of dicts with encoder downsampling parameters (strides).
"""

def __init__(
Expand All @@ -324,6 +528,7 @@ def __init__(
include_fc: bool = True,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
downsample_parameters: list[dict] | None = None,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
Expand All @@ -335,6 +540,12 @@ def __init__(
self.norm_eps = norm_eps
self.attention_levels = attention_levels

# Normalize downsampling parameters to get strides for upsampling
num_downsample_levels = len(channels) - 1
normalized_downsample_params = _normalize_downsample_parameters(
downsample_parameters, num_downsample_levels, spatial_dims
)

reversed_block_out_channels = list(reversed(channels))

blocks: list[nn.Module] = []
Expand Down Expand Up @@ -387,6 +598,10 @@ def __init__(
reversed_attention_levels = list(reversed(attention_levels))
reversed_num_res_blocks = list(reversed(num_res_blocks))
block_out_ch = reversed_block_out_channels[0]

# Reverse downsample parameters for use during upsampling
reversed_downsample_params = list(reversed(normalized_downsample_params))

for i in range(len(reversed_block_out_channels)):
block_in_ch = block_out_ch
block_out_ch = reversed_block_out_channels[i]
Expand Down Expand Up @@ -418,6 +633,10 @@ def __init__(
)

if not is_final_block:
# Use stride from encoder downsample as scale_factor for upsampling
# reversed_downsample_params[i] corresponds to the downsampling level we need to upsample
upsampling_stride = reversed_downsample_params[i]["stride"]

if use_convtranspose:
blocks.append(
Upsample(
Expand All @@ -441,7 +660,7 @@ def __init__(
in_channels=block_in_ch,
out_channels=block_in_ch,
interp_mode="nearest",
scale_factor=2.0,
scale_factor=tuple(float(s) for s in upsampling_stride),
post_conv=post_conv,
align_corners=None,
)
Expand Down Expand Up @@ -492,6 +711,10 @@ class AutoencoderKL(nn.Module):
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
downsample_parameters: downsampling parameters for each level. Can be:
- None: use default (kernel_size=3, stride=2 for all levels)
- dict: apply same parameters to all levels (e.g., {"kernel_size": (3,3,1), "stride": (2,2,1)})
- list of dicts: one dict per downsampling level with keys "kernel_size", "stride", "padding"
"""

def __init__(
Expand All @@ -512,6 +735,7 @@ def __init__(
include_fc: bool = True,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
downsample_parameters: list[dict] | dict | None = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -544,7 +768,12 @@ def __init__(
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
downsample_parameters=downsample_parameters,
)

# Get downsampling parameters from encoder to ensure decoder uses the same strides
encoder_downsample_params = self.encoder.downsample_parameters

self.decoder: nn.Module = Decoder(
spatial_dims=spatial_dims,
channels=channels,
Expand All @@ -559,6 +788,7 @@ def __init__(
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
downsample_parameters=encoder_downsample_params,
)
self.quant_conv_mu = Convolution(
spatial_dims=spatial_dims,
Expand Down
Loading
Loading