From 0a9077364536d7c81b902a8225d5a7540b931635 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 17 May 2026 19:41:49 +0530 Subject: [PATCH 1/3] Add configurable anisotropic downsampling support to AutoencoderKL and relevant testcases Signed-off-by: Shubham Chandravanshi --- monai/networks/nets/autoencoderkl.py | 246 ++++++++++++++++- tests/networks/nets/test_autoencoderkl.py | 316 ++++++++++++++++++++++ 2 files changed, 554 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 11b4fcfc9e..6c9e93a633 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -25,10 +25,167 @@ __all__ = ["AutoencoderKL"] +def _validate_kernel_stride_parameters( + kernel_size: int | tuple[int, ...] | None, + stride: int | tuple[int, ...] | None, + spatial_dims: int, + param_name: str = "parameter", +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Validate and normalize kernel_size and stride parameters. + + Args: + kernel_size: int or tuple of ints representing kernel size + stride: int or tuple of ints representing stride + spatial_dims: number of spatial dimensions + param_name: name of parameter for error messages + + Returns: + Tuple of (normalized_kernel_size, normalized_stride) + + Raises: + ValueError: if parameters are invalid + """ + if kernel_size is None or stride is None: + return None, None + + # Normalize kernel_size to tuple + if isinstance(kernel_size, int): + kernel_size_tuple = (kernel_size,) * spatial_dims + else: + kernel_size_tuple = tuple(kernel_size) + + # Normalize stride to tuple + if isinstance(stride, int): + stride_tuple = (stride,) * spatial_dims + else: + stride_tuple = tuple(stride) + + # Validate lengths + if len(kernel_size_tuple) != spatial_dims: + raise ValueError(f"{param_name} kernel_size must have length {spatial_dims}, got {len(kernel_size_tuple)}") + if len(stride_tuple) != spatial_dims: + raise ValueError(f"{param_name} stride must have length {spatial_dims}, got {len(stride_tuple)}") + + # Validate kernel sizes are odd + for i, k in enumerate(kernel_size_tuple): + if k % 2 == 0: + raise ValueError(f"{param_name} kernel_size at dimension {i} must be odd, got {k}") + + # Validate all values are positive integers + for i, (k, s) in enumerate(zip(kernel_size_tuple, stride_tuple)): + if not isinstance(k, int) or k <= 0: + raise ValueError(f"{param_name} kernel_size at dimension {i} must be positive int, got {k}") + if not isinstance(s, int) or s <= 0: + raise ValueError(f"{param_name} stride at dimension {i} must be positive int, got {s}") + + return kernel_size_tuple, stride_tuple + + +def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]: + """ + Compute symmetric padding from kernel size. + + For odd kernel sizes, padding = kernel_size // 2 on all sides. + + Args: + kernel_size: tuple of odd integers + + Returns: + Tuple of padding values (one per dimension) + """ + padding = tuple(k // 2 for k in kernel_size) + return padding + + +def _normalize_downsample_parameters( + downsample_parameters: list[dict] | dict | None, + num_levels: int, + spatial_dims: int, + default_kernel_size: int = 3, + default_stride: int = 2, +) -> list[dict]: + """ + Normalize downsampling parameters to canonical internal representation. + + Accepts: + - None: use defaults for all levels + - Single dict: apply same params to all levels + - List of dicts: one dict per level + + Each dict can specify: + - "kernel_size": int or tuple + - "stride": int or tuple + - "padding": int or tuple (auto-computed if omitted) + + Returns: + List of dicts with normalized keys: + - Each dict has "kernel_size", "stride", "padding" as tuples + - Length equals num_levels + + Raises: + ValueError: if parameters are invalid or inconsistent + """ + if downsample_parameters is None: + # Default: use provided defaults for all levels + default_ks_tuple, default_s_tuple = _validate_kernel_stride_parameters( + default_kernel_size, default_stride, spatial_dims + ) + default_padding = _compute_padding(default_ks_tuple) + return [ + {"kernel_size": default_ks_tuple, "stride": default_s_tuple, "padding": default_padding} + for _ in range(num_levels) + ] + + # If single dict, apply to all levels + if isinstance(downsample_parameters, dict): + params_list = [downsample_parameters] * num_levels + else: + params_list = list(downsample_parameters) + + # Validate we have the right number of levels + if len(params_list) != num_levels: + raise ValueError(f"Expected {num_levels} downsampling parameter dicts (one per level), got {len(params_list)}") + + # Normalize each dict + normalized = [] + for i, params in enumerate(params_list): + if not isinstance(params, dict): + raise ValueError(f"Downsampling parameters at level {i} must be dict, got {type(params)}") + + kernel_size = params.get("kernel_size", default_kernel_size) + stride = params.get("stride", default_stride) + padding = params.get("padding", None) + + # Validate and normalize kernel_size and stride + ks_tuple, s_tuple = _validate_kernel_stride_parameters(kernel_size, stride, spatial_dims, f"Level {i}") + + # Compute padding if not provided + if padding is None: + padding_tuple = _compute_padding(ks_tuple) + else: + # Normalize provided padding + if isinstance(padding, int): + padding_tuple = (padding,) * spatial_dims + else: + padding_tuple = tuple(padding) + + if len(padding_tuple) != spatial_dims: + raise ValueError(f"Level {i} padding must have length {spatial_dims}, got {len(padding_tuple)}") + + normalized.append({"kernel_size": ks_tuple, "stride": s_tuple, "padding": padding_tuple}) + + return normalized + + class AsymmetricPad(nn.Module): """ Pad the input tensor asymmetrically along every spatial dimension. + .. deprecated:: 0.10.0 + This class is deprecated and no longer used by `AEKLDownsample`. + Use configurable kernel_size and stride parameters instead (see `AEKLDownsample`). + Args: spatial_dims: number of spatial dimensions, could be 1, 2, or 3. """ @@ -49,24 +206,46 @@ class AEKLDownsample(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. + kernel_size: kernel size for the convolution. Can be int or tuple. Default: 3. + stride: stride for the convolution. Can be int or tuple. Default: 2. + padding: padding for the convolution. If None, computed from kernel_size. Default: None. """ - def __init__(self, spatial_dims: int, in_channels: int) -> None: + def __init__( + self, + spatial_dims: int, + in_channels: int, + kernel_size: int | tuple[int, ...] = 3, + stride: int | tuple[int, ...] = 2, + padding: int | tuple[int, ...] | None = None, + ) -> None: super().__init__() - self.pad = AsymmetricPad(spatial_dims=spatial_dims) + + # Validate and normalize kernel_size and stride + kernel_size_tuple, stride_tuple = _validate_kernel_stride_parameters( + kernel_size, stride, spatial_dims, "AEKLDownsample" + ) + + # Compute padding if not provided + if padding is None: + padding_tuple = _compute_padding(kernel_size_tuple) + else: + if isinstance(padding, int): + padding_tuple = (padding,) * spatial_dims + else: + padding_tuple = tuple(padding) self.conv = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, - strides=2, - kernel_size=3, - padding=0, + strides=stride_tuple, + kernel_size=kernel_size_tuple, + padding=padding_tuple, conv_only=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.pad(x) x = self.conv(x) return x @@ -160,6 +339,7 @@ class Encoder(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: list of dicts specifying kernel_size, stride, padding for each downsampling level. """ def __init__( @@ -176,6 +356,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -187,6 +368,15 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels + # Normalize downsampling parameters + num_downsample_levels = len(channels) - 1 + normalized_downsample_params = _normalize_downsample_parameters( + downsample_parameters, num_downsample_levels, spatial_dims + ) + + # Store for decoder to use + self.downsample_parameters = normalized_downsample_params + blocks: list[nn.Module] = [] # Initial convolution blocks.append( @@ -203,6 +393,7 @@ def __init__( # Residual and downsampling blocks output_channel = channels[0] + downsample_idx = 0 for i in range(len(channels)): input_channel = output_channel output_channel = channels[i] @@ -233,7 +424,19 @@ def __init__( ) if not is_final_block: - blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + # Use downsampling parameters for this level + downsample_params = normalized_downsample_params[downsample_idx] + blocks.append( + AEKLDownsample( + spatial_dims=spatial_dims, + in_channels=input_channel, + kernel_size=downsample_params["kernel_size"], + stride=downsample_params["stride"], + padding=downsample_params["padding"], + ) + ) + downsample_idx += 1 + # Non-local attention block if with_nonlocal_attn is True: blocks.append( @@ -307,6 +510,7 @@ class Decoder(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: list of dicts with encoder downsampling parameters (strides). """ def __init__( @@ -324,6 +528,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -335,6 +540,12 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels + # Normalize downsampling parameters to get strides for upsampling + num_downsample_levels = len(channels) - 1 + normalized_downsample_params = _normalize_downsample_parameters( + downsample_parameters, num_downsample_levels, spatial_dims + ) + reversed_block_out_channels = list(reversed(channels)) blocks: list[nn.Module] = [] @@ -387,6 +598,10 @@ def __init__( reversed_attention_levels = list(reversed(attention_levels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) block_out_ch = reversed_block_out_channels[0] + + # Reverse downsample parameters for use during upsampling + reversed_downsample_params = list(reversed(normalized_downsample_params)) + for i in range(len(reversed_block_out_channels)): block_in_ch = block_out_ch block_out_ch = reversed_block_out_channels[i] @@ -418,6 +633,10 @@ def __init__( ) if not is_final_block: + # Use stride from encoder downsample as scale_factor for upsampling + # reversed_downsample_params[i] corresponds to the downsampling level we need to upsample + upsampling_stride = reversed_downsample_params[i]["stride"] + if use_convtranspose: blocks.append( Upsample( @@ -441,7 +660,7 @@ def __init__( in_channels=block_in_ch, out_channels=block_in_ch, interp_mode="nearest", - scale_factor=2.0, + scale_factor=tuple(float(s) for s in upsampling_stride), post_conv=post_conv, align_corners=None, ) @@ -492,6 +711,10 @@ class AutoencoderKL(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: downsampling parameters for each level. Can be: + - None: use default (kernel_size=3, stride=2 for all levels) + - dict: apply same parameters to all levels (e.g., {"kernel_size": (3,3,1), "stride": (2,2,1)}) + - list of dicts: one dict per downsampling level with keys "kernel_size", "stride", "padding" """ def __init__( @@ -512,6 +735,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | dict | None = None, ) -> None: super().__init__() @@ -544,7 +768,12 @@ def __init__( include_fc=include_fc, use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, + downsample_parameters=downsample_parameters, ) + + # Get downsampling parameters from encoder to ensure decoder uses the same strides + encoder_downsample_params = self.encoder.downsample_parameters + self.decoder: nn.Module = Decoder( spatial_dims=spatial_dims, channels=channels, @@ -559,6 +788,7 @@ def __init__( include_fc=include_fc, use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, + downsample_parameters=encoder_downsample_params, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index af0c55d6ec..cb66c74a0e 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -428,6 +428,322 @@ def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False" ) + # New tests for downsampling parameters + def test_backward_compatibility_default_behavior(self): + """Test that default behavior (no downsample_parameters) is unchanged.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + # Test with standard input shape + x = torch.randn(1, 1, 16, 16).to(device) + result = net.forward(x) + # With default stride=2 and 2 downsampling levels (for 3 channel groups), + # latent shape should be 16 / 2 / 2 = 4 + self.assertEqual(result[0].shape, (1, 1, 16, 16)) + self.assertEqual(result[1].shape, (1, 4, 4, 4)) + + def test_anisotropic_stride_2d(self): + """Test 2D anisotropic stride (2,1) at first level.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Downsampling: level 0 uses (2,1), level 1 uses (2,2) + downsample_params = [{"kernel_size": 3, "stride": (2, 1)}, {"kernel_size": 3, "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/1=32 + # After level 1: 16/2=8, 32/2=16 + self.assertEqual(result[0].shape, (1, 1, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 16)) + + def test_anisotropic_stride_3d(self): + """Test 3D anisotropic stride (2,2,1) - common for thick slice spacing.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Preserve z-dimension with stride=1 + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32, 64).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/2=16, 64/1=64 + # After level 1: 16/2=8, 16/2=8, 64/1=64 + self.assertEqual(result[0].shape, (1, 1, 32, 32, 64)) + self.assertEqual(result[1].shape, (1, 4, 8, 8, 64)) + + def test_mixed_anisotropic_downsample_parameters(self): + """Test per-level configuration with mixed parameters.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Level 0: preserve z, Level 1: isotropic + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/2=16, 32/1=32 + # After level 1: 16/2=8, 16/2=8, 32/2=16 + self.assertEqual(result[0].shape, (1, 1, 32, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 8, 16)) + + def test_single_dict_applied_to_all_levels(self): + """Test that single dict is applied to all downsampling levels.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Single dict: apply (3,3) kernel with stride (2,1) to all levels + downsample_params = {"kernel_size": (3, 3), "stride": (2, 1)} + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/1=32 + # After level 1: 16/2=8, 32/1=32 + self.assertEqual(result[0].shape, (1, 1, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 32)) + + def test_validation_even_kernel_raises_error(self): + """Test that even kernel sizes raise ValueError.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [{"kernel_size": 4, "stride": 2}] # Even kernel + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_validation_invalid_tuple_length_raises_error(self): + """Test that invalid tuple length raises ValueError.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # 3D but only 2 values in tuple + downsample_params = [{"kernel_size": (3, 3), "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_validation_wrong_num_levels_raises_error(self): + """Test that wrong number of downsampling parameter dicts raises error.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), # 3 channels = 2 downsampling levels + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Only 1 dict but need 2 + downsample_params = [{"kernel_size": 3, "stride": 2}] + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_reconstruction_with_anisotropic_downsampling(self): + """Test that reconstruction shape matches input with anisotropic downsampling.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 64, 64, 128).to(device) + reconstruction = net.reconstruct(x) + self.assertEqual(reconstruction.shape, x.shape) + + def test_encode_decode_with_anisotropic_downsampling(self): + """Test encode/decode cycle with anisotropic downsampling.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [{"kernel_size": (3, 3), "stride": (2, 1)}, {"kernel_size": (3, 3), "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + z_mu, z_sigma = net.encode(x) + z = net.sampling(z_mu, z_sigma) + reconstruction = net.decode(z) + self.assertEqual(reconstruction.shape, x.shape) + + def test_reconstruction_robustness_anisotropic_non_power_of_two_odd_dims(self): + """ + Test reconstruction shape consistency with: + - Anisotropic multi-level downsampling config + - Non-power-of-two spatial dimensions (but stride-compatible) + - Mixed even/odd dimensions + + This rigorously validates encoder-decoder symmetry under challenging conditions. + + Note: Dimensions must be compatible with the stride pattern: + - Stride (2,2,1) -> (2,2,2) means dims must be divisible by (4,4,2) + - Using 60 (=4*15), 68 (=4*17), 96 (=2*48) to maximize coverage + """ + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + + # Anisotropic config: preserve Z dimension at level 0, isotropic at level 1 + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + # Stride-compatible dimensions: + # Level 0: stride (2,2,1) -> need height/width divisible by 2 + # Level 1: stride (2,2,2) -> need result divisible by 2 again + # Final requirement: dims divisible by (4, 4, 2) + # Using: 60=4*15 (not power of 2), 68=4*17 (not power of 2), 96=2*48 + x = torch.randn(1, 1, 60, 68, 96).to(device) + + # Forward pass + z_mu, z_sigma = net.encode(x) + z = net.sampling(z_mu, z_sigma) + reconstruction = net.decode(z) + + # Verify shape consistency - reconstruction should match input exactly + self.assertEqual( + reconstruction.shape, + x.shape, + f"Reconstruction shape {reconstruction.shape} does not match input shape {x.shape}", + ) + + # Also test via reconstruct method + reconstruction2 = net.reconstruct(x) + self.assertEqual( + reconstruction2.shape, + x.shape, + f"Reconstruct shape {reconstruction2.shape} does not match input shape {x.shape}", + ) + + # Verify latent shape makes sense: + # 60 -> 30 (stride=2) -> 15 (stride=2) + # 68 -> 34 (stride=2) -> 17 (stride=2) + # 96 -> 96 (stride=1) -> 48 (stride=2) + expected_latent_h = 15 + expected_latent_w = 17 + expected_latent_d = 48 + + self.assertEqual( + z_mu.shape[2], + expected_latent_h, + f"Latent H shape mismatch: got {z_mu.shape[2]}, expected {expected_latent_h}", + ) + self.assertEqual( + z_mu.shape[3], + expected_latent_w, + f"Latent W shape mismatch: got {z_mu.shape[3]}, expected {expected_latent_w}", + ) + self.assertEqual( + z_mu.shape[4], + expected_latent_d, + f"Latent D shape mismatch: got {z_mu.shape[4]}, expected {expected_latent_d}", + ) + if __name__ == "__main__": unittest.main() From 6bc4ebc226d33dba05c6a0e347742178d67b0a02 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 17 May 2026 20:41:27 +0530 Subject: [PATCH 2/3] Update AutoencoderKL test configuration and docstrings Signed-off-by: Shubham Chandravanshi --- monai/networks/nets/autoencoderkl.py | 83 ++++++++++++++++++++--- tests/networks/nets/test_autoencoderkl.py | 30 +++++++- 2 files changed, 103 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 6c9e93a633..66aa59564b 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -30,7 +30,7 @@ def _validate_kernel_stride_parameters( stride: int | tuple[int, ...] | None, spatial_dims: int, param_name: str = "parameter", -) -> tuple[tuple[int, ...], tuple[int, ...]]: +) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]: """ Validate and normalize kernel_size and stride parameters. @@ -84,15 +84,16 @@ def _validate_kernel_stride_parameters( def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]: """ - Compute symmetric padding from kernel size. + Compute symmetric padding for odd kernel sizes. - For odd kernel sizes, padding = kernel_size // 2 on all sides. + Padding is derived as: + padding[d] = kernel_size[d] // 2 Args: - kernel_size: tuple of odd integers + kernel_size: Kernel size for each spatial dimension. Returns: - Tuple of padding values (one per dimension) + Tuple of padding values for each spatial dimension. """ padding = tuple(k // 2 for k in kernel_size) return padding @@ -119,9 +120,9 @@ def _normalize_downsample_parameters( - "padding": int or tuple (auto-computed if omitted) Returns: - List of dicts with normalized keys: - - Each dict has "kernel_size", "stride", "padding" as tuples - - Length equals num_levels + List of dicts with normalized keys: + - Each dict has "kernel_size", "stride", "padding" as tuples + - Length equals num_levels Raises: ValueError: if parameters are invalid or inconsistent @@ -195,6 +196,15 @@ def __init__(self, spatial_dims: int) -> None: self.pad = (0, 1) * spatial_dims def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply asymmetric padding to the input tensor. + + Args: + x: Input tensor. + + Returns: + Padded tensor. + """ x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) return x @@ -246,6 +256,15 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply convolutional downsampling. + + Args: + x: Input tensor. + + Returns: + Downsampled tensor. + """ x = self.conv(x) return x @@ -486,6 +505,15 @@ def __init__( self.blocks = nn.ModuleList(blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward input through encoder blocks. + + Args: + x: Input tensor. + + Returns: + Encoded latent representation. + """ for block in self.blocks: x = block(x) return x @@ -682,6 +710,15 @@ def __init__( self.blocks = nn.ModuleList(blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward latent representation through decoder blocks. + + Args: + x: Latent tensor. + + Returns: + Reconstructed image tensor. + """ for block in self.blocks: x = block(x) return x @@ -890,17 +927,47 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: return dec def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encode, sample, and reconstruct an input image. + + Args: + x: Input tensor of shape BxCx[SPATIAL_DIMS]. + + Returns: + Tuple containing: + - reconstructed image + - latent mean + - latent standard deviation + """ z_mu, z_sigma = self.encode(x) z = self.sampling(z_mu, z_sigma) reconstruction = self.decode(z) return reconstruction, z_mu, z_sigma def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode an input image into latent space representation. + + Args: + x: Input tensor. + + Returns: + Sampled latent tensor. + """ z_mu, z_sigma = self.encode(x) z = self.sampling(z_mu, z_sigma) return z def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode latent representation into image space. + + Args: + z: Latent tensor. + + Returns: + Decoded image tensor. + """ image = self.decode(z) return image diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index cb66c74a0e..36272bd7cf 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -440,6 +440,8 @@ def test_backward_compatibility_default_behavior(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -462,6 +464,8 @@ def test_anisotropic_stride_2d(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Downsampling: level 0 uses (2,1), level 1 uses (2,2) downsample_params = [{"kernel_size": 3, "stride": (2, 1)}, {"kernel_size": 3, "stride": (2, 2)}] @@ -487,6 +491,8 @@ def test_anisotropic_stride_3d(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Preserve z-dimension with stride=1 downsample_params = [ @@ -515,6 +521,8 @@ def test_mixed_anisotropic_downsample_parameters(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Level 0: preserve z, Level 1: isotropic downsample_params = [ @@ -543,6 +551,8 @@ def test_single_dict_applied_to_all_levels(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Single dict: apply (3,3) kernel with stride (2,1) to all levels downsample_params = {"kernel_size": (3, 3), "stride": (2, 1)} @@ -568,8 +578,11 @@ def test_validation_even_kernel_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } - downsample_params = [{"kernel_size": 4, "stride": 2}] # Even kernel + + downsample_params = [{"kernel_size": 4, "stride": 2}, {"kernel_size": 3, "stride": 2}] # Even kernel input_param["downsample_parameters"] = downsample_params with self.assertRaises(ValueError): @@ -586,9 +599,14 @@ def test_validation_invalid_tuple_length_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # 3D but only 2 values in tuple - downsample_params = [{"kernel_size": (3, 3), "stride": (2, 2)}] + downsample_params = [ + {"kernel_size": (3, 3), "stride": (2, 2)}, # Invalid: 2 values for 3D + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] input_param["downsample_parameters"] = downsample_params with self.assertRaises(ValueError): @@ -605,6 +623,8 @@ def test_validation_wrong_num_levels_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Only 1 dict but need 2 downsample_params = [{"kernel_size": 3, "stride": 2}] @@ -624,6 +644,8 @@ def test_reconstruction_with_anisotropic_downsampling(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } downsample_params = [ {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, @@ -648,6 +670,8 @@ def test_encode_decode_with_anisotropic_downsampling(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } downsample_params = [{"kernel_size": (3, 3), "stride": (2, 1)}, {"kernel_size": (3, 3), "stride": (2, 2)}] input_param["downsample_parameters"] = downsample_params @@ -682,6 +706,8 @@ def test_reconstruction_robustness_anisotropic_non_power_of_two_odd_dims(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Anisotropic config: preserve Z dimension at level 0, isotropic at level 1 From 9ca15505bf3a54702bb633d1aedd832771aaa39a Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 24 May 2026 03:13:03 +0530 Subject: [PATCH 3/3] Update AutoencoderKL coderabbit issues, mypy type error fixes and added some relavent testcases Signed-off-by: Shubham Chandravanshi --- monai/networks/nets/autoencoderkl.py | 157 +++++++++++++---- tests/networks/nets/test_autoencoderkl.py | 196 +++++++++++++++++++++- 2 files changed, 319 insertions(+), 34 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 66aa59564b..ea1f73772d 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -30,7 +30,7 @@ def _validate_kernel_stride_parameters( stride: int | tuple[int, ...] | None, spatial_dims: int, param_name: str = "parameter", -) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]: +) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Validate and normalize kernel_size and stride parameters. @@ -46,8 +46,6 @@ def _validate_kernel_stride_parameters( Raises: ValueError: if parameters are invalid """ - if kernel_size is None or stride is None: - return None, None # Normalize kernel_size to tuple if isinstance(kernel_size, int): @@ -132,7 +130,7 @@ def _normalize_downsample_parameters( default_ks_tuple, default_s_tuple = _validate_kernel_stride_parameters( default_kernel_size, default_stride, spatial_dims ) - default_padding = _compute_padding(default_ks_tuple) + default_padding: tuple[int, ...] = _compute_padding(default_ks_tuple) return [ {"kernel_size": default_ks_tuple, "stride": default_s_tuple, "padding": default_padding} for _ in range(num_levels) @@ -163,7 +161,7 @@ def _normalize_downsample_parameters( # Compute padding if not provided if padding is None: - padding_tuple = _compute_padding(ks_tuple) + padding_tuple: tuple[int, ...] = _compute_padding(ks_tuple) else: # Normalize provided padding if isinstance(padding, int): @@ -209,6 +207,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class _RecordShapeHook(nn.Module): + """Helper module to record spatial shapes during encoding for decoder restoration.""" + + def __init__(self, shape_list: list[tuple[int, ...]]) -> None: + super().__init__() + self.shape_list = shape_list + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Record spatial dimensions and pass through.""" + self.shape_list.append(tuple(x.shape[2:])) + return x + + +class _ShapeRestoringUpsample(nn.Module): + """Upsample to exact target size (recorded by encoder) instead of using scale_factor. + + This handles arbitrary input dimensions (odd, non-power-of-2, anisotropic) by restoring + to the exact pre-downsampling shape recorded during encoding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + post_conv: nn.Module, + shape_index: int, + downsample_shapes_ref: list, + scale_factor: tuple[int, ...] | None = None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.post_conv = post_conv + self.shape_index = shape_index + self.downsample_shapes_ref = downsample_shapes_ref # Reference to the shared list, NOT a module + self.scale_factor = scale_factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample to exact target size, then apply post-convolution.""" + # Get target shape from downsample_shapes (in reverse order) + if self.downsample_shapes_ref and self.shape_index < len(self.downsample_shapes_ref): + # Shapes are stored in order, but we're using them in reverse + target_shape_index = len(self.downsample_shapes_ref) - 1 - self.shape_index + target_shape = self.downsample_shapes_ref[target_shape_index] + x = F.interpolate(x, size=target_shape, mode="nearest") + elif self.scale_factor is not None: + # Fallback for standalone decode (no encoder run): use scale_factor + sf = tuple(float(s) for s in self.scale_factor) + x = F.interpolate(x, scale_factor=sf, mode="nearest") + + x = self.post_conv(x) + return x + + class AEKLDownsample(nn.Module): """ Convolution-based downsampling layer. @@ -219,6 +273,8 @@ class AEKLDownsample(nn.Module): kernel_size: kernel size for the convolution. Can be int or tuple. Default: 3. stride: stride for the convolution. Can be int or tuple. Default: 2. padding: padding for the convolution. If None, computed from kernel_size. Default: None. + use_legacy_padding: if True and padding is None, use asymmetric padding (0,1) for each dimension + to match the original MONAI Generative implementation. Default: False. """ def __init__( @@ -228,6 +284,7 @@ def __init__( kernel_size: int | tuple[int, ...] = 3, stride: int | tuple[int, ...] = 2, padding: int | tuple[int, ...] | None = None, + use_legacy_padding: bool = False, ) -> None: super().__init__() @@ -236,22 +293,28 @@ def __init__( kernel_size, stride, spatial_dims, "AEKLDownsample" ) - # Compute padding if not provided - if padding is None: - padding_tuple = _compute_padding(kernel_size_tuple) + self.use_legacy_padding = use_legacy_padding and (padding is None) + if self.use_legacy_padding: + # Legacy behavior: asymmetric padding (0, 1) per dimension + conv with padding=0 + self.pad = (0, 1) * spatial_dims + padding_tuple = (0,) * spatial_dims else: - if isinstance(padding, int): - padding_tuple = (padding,) * spatial_dims + # New behavior: compute symmetric padding if not provided + if padding is None: + padding_tuple: tuple[int, ...] = _compute_padding(kernel_size_tuple) else: - padding_tuple = tuple(padding) + if isinstance(padding, int): + padding_tuple = (padding,) * spatial_dims + else: + padding_tuple = tuple(padding) self.conv = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, - strides=stride_tuple, - kernel_size=kernel_size_tuple, - padding=padding_tuple, + strides=tuple(stride_tuple), + kernel_size=tuple(kernel_size_tuple), + padding=tuple(padding_tuple), conv_only=True, ) @@ -265,6 +328,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Downsampled tensor. """ + if self.use_legacy_padding: + x = F.pad(x, self.pad, mode="constant", value=0.0) x = self.conv(x) return x @@ -375,7 +440,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, - downsample_parameters: list[dict] | None = None, + downsample_parameters: list[dict] | dict | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -389,12 +454,15 @@ def __init__( # Normalize downsampling parameters num_downsample_levels = len(channels) - 1 + use_legacy_padding = downsample_parameters is None # Track if using legacy defaults normalized_downsample_params = _normalize_downsample_parameters( downsample_parameters, num_downsample_levels, spatial_dims ) # Store for decoder to use self.downsample_parameters = normalized_downsample_params + self.use_legacy_padding = use_legacy_padding + self.downsample_shapes: list[tuple[int, ...]] = [] # Track shapes before each downsample blocks: list[nn.Module] = [] # Initial convolution @@ -443,6 +511,8 @@ def __init__( ) if not is_final_block: + # Record shape before downsampling (for decoder to restore exact size) + blocks.append(_RecordShapeHook(self.downsample_shapes)) # Use downsampling parameters for this level downsample_params = normalized_downsample_params[downsample_idx] blocks.append( @@ -452,6 +522,7 @@ def __init__( kernel_size=downsample_params["kernel_size"], stride=downsample_params["stride"], padding=downsample_params["padding"], + use_legacy_padding=use_legacy_padding, ) ) downsample_idx += 1 @@ -556,7 +627,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, - downsample_parameters: list[dict] | None = None, + downsample_parameters: list[dict] | dict | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -570,10 +641,15 @@ def __init__( # Normalize downsampling parameters to get strides for upsampling num_downsample_levels = len(channels) - 1 + use_legacy_padding = downsample_parameters is None # Track if using legacy defaults normalized_downsample_params = _normalize_downsample_parameters( downsample_parameters, num_downsample_levels, spatial_dims ) + # Will be populated by encoder with shapes before each downsample + self.downsample_shapes: list[tuple[int, ...]] = [] + self.use_legacy_padding = use_legacy_padding + reversed_block_out_channels = list(reversed(channels)) blocks: list[nn.Module] = [] @@ -661,10 +737,6 @@ def __init__( ) if not is_final_block: - # Use stride from encoder downsample as scale_factor for upsampling - # reversed_downsample_params[i] corresponds to the downsampling level we need to upsample - upsampling_stride = reversed_downsample_params[i]["stride"] - if use_convtranspose: blocks.append( Upsample( @@ -672,6 +744,8 @@ def __init__( ) ) else: + # For nontrainable upsampling: use exact target size from encoder + # This handles arbitrary input dimensions (odd, non-power-of-2, etc.) post_conv = Convolution( spatial_dims=spatial_dims, in_channels=block_in_ch, @@ -681,16 +755,17 @@ def __init__( padding=1, conv_only=True, ) + # pass scale_factor from reversed_downsample_params as fallback + sf = tuple(reversed_downsample_params[i]["stride"]) blocks.append( - Upsample( + _ShapeRestoringUpsample( spatial_dims=spatial_dims, - mode="nontrainable", in_channels=block_in_ch, out_channels=block_in_ch, - interp_mode="nearest", - scale_factor=tuple(float(s) for s in upsampling_stride), post_conv=post_conv, - align_corners=None, + shape_index=i, # index into reversed downsample_shapes + downsample_shapes_ref=self.downsample_shapes, # will be updated by AutoencoderKL + scale_factor=sf, ) ) @@ -827,6 +902,17 @@ def __init__( use_flash_attention=use_flash_attention, downsample_parameters=encoder_downsample_params, ) + + # Link encoder shapes to decoder for exact size restoration + # This must be done AFTER decoder creation so that _ShapeRestoringUpsample blocks + # reference the shared list (not the empty list created during decoder init) + self.decoder.downsample_shapes = self.encoder.downsample_shapes + + # Update all _ShapeRestoringUpsample blocks to reference the shared list + for block in self.decoder.blocks: + if isinstance(block, _ShapeRestoringUpsample): + block.downsample_shapes_ref = self.encoder.downsample_shapes + self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, in_channels=latent_channels, @@ -1006,12 +1092,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # fix the attention blocks attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") - new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") - new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") - new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") - new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") - new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + if f"{block}.to_q.weight" in old_state_dict: + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + if f"{block}.to_k.weight" in old_state_dict: + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + if f"{block}.to_v.weight" in old_state_dict: + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + if f"{block}.to_q.bias" in old_state_dict: + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + if f"{block}.to_k.bias" in old_state_dict: + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + if f"{block}.to_v.bias" in old_state_dict: + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") out_w = f"{block}.attn.out_proj.weight" out_b = f"{block}.attn.out_proj.bias" @@ -1051,7 +1143,8 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict.pop(old_name) + if old_name in old_state_dict: + new_state_dict[k] = old_state_dict.pop(old_name) if verbose: # print all remaining keys in old_state_dict print("remaining keys in old_state_dict:", old_state_dict.keys()) diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index 36272bd7cf..0eb4a8de66 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -770,6 +770,198 @@ def test_reconstruction_robustness_anisotropic_non_power_of_two_odd_dims(self): f"Latent D shape mismatch: got {z_mu.shape[4]}, expected {expected_latent_d}", ) + def test_exact_reconstruction_odd_dimensions(self): + """ + Critical test: Verify exact reconstruction for truly odd/non-divisible dimensions. + + This directly demonstrates the shape restoration architecture upgrade. + Before: would fail or produce mismatched shapes + After: exact reconstruction guaranteed + """ + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "downsample_parameters": [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ], + } + + net = AutoencoderKL(**input_param).to(device) + + # Truly odd dimensions that would fail with naive stride-based approach + x = torch.randn(1, 1, 65, 67, 17).to(device) + + with eval_mode(net): + reconstruction, z_mu, z_sigma = net(x) + + # This is the key assertion proving shape restoration works + self.assertEqual( + reconstruction.shape, x.shape, f"Reconstruction shape {reconstruction.shape} != input shape {x.shape}" + ) + + def test_multi_level_anisotropic_non_divisible_dimensions(self): + """ + Test multi-level anisotropic downsampling with non-divisible dimensions. + + Validates that shape restoration handles: + - Different stride per level + - Odd dimensions on multiple axes + - Complex spatial transforms + """ + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "downsample_parameters": [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, # Preserve Z + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, # Isotropic + ], + } + + net = AutoencoderKL(**input_param).to(device) + + # Non-divisible dimensions that would fail with scale_factor approach + x = torch.randn(1, 1, 61, 73, 19).to(device) + + with eval_mode(net): + reconstruction = net.reconstruct(x) + + self.assertEqual(reconstruction.shape, x.shape) + + def test_convtranspose_path_unchanged(self): + """ + Verify ConvTranspose upsampling path remains untouched by shape restoration. + + Shape restoration only affects nontrainable upsampling path. + ConvTranspose should maintain original behavior. + """ + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "use_convtranspose": True, # Use trainable upsampling + "downsample_parameters": [{"kernel_size": 3, "stride": 2}, {"kernel_size": 3, "stride": 2}], + } -if __name__ == "__main__": - unittest.main() + net = AutoencoderKL(**input_param).to(device) + + # Standard power-of-2 size + x = torch.randn(1, 1, 64, 64).to(device) + + with eval_mode(net): + reconstruction = net.reconstruct(x) + + # Should not crash and shape should be preserved + self.assertEqual(reconstruction.shape, x.shape) + + def test_multiple_forward_passes_different_odd_shapes(self): + """ + Test multiple forward passes with different odd-dimensional inputs. + + Validates that shape state is properly maintained/reset between passes. + Catches potential stale-state bugs in shape recording. + """ + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "downsample_parameters": [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ], + } + + net = AutoencoderKL(**input_param).to(device) + + # First odd shape + x1 = torch.randn(1, 1, 65, 67, 17).to(device) + + with eval_mode(net): + reconstruction1 = net.reconstruct(x1) + + self.assertEqual(reconstruction1.shape, x1.shape) + + # Different odd shape + x2 = torch.randn(1, 1, 71, 79, 23).to(device) + + with eval_mode(net): + reconstruction2 = net.reconstruct(x2) + + self.assertEqual(reconstruction2.shape, x2.shape) + + # Verify they're actually different shapes + self.assertNotEqual(x1.shape, x2.shape) + + def test_legacy_default_behavior_with_odd_dimensions(self): + """ + Test that legacy default behavior (downsample_parameters=None) preserves asymmetric padding + and produces correct reconstruction even with odd dimensions. + + This ensures checkpoint compatibility: models using default parameters continue to work + identically after the padding changes. + """ + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + # Explicitly no downsample_parameters - should use legacy defaults + } + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + # Test with odd dimensions - crucial for verifying legacy asymmetric padding + x = torch.randn(1, 1, 17, 19).to(device) + reconstruction, z_mu, z_sigma = net(x) + + # Reconstruction should match input shape exactly + self.assertEqual( + reconstruction.shape, + x.shape, + f"Legacy default behavior with odd dims: reconstruction {reconstruction.shape} != input {x.shape}", + ) + + # Also test with even dimensions to ensure no regression + x_even = torch.randn(1, 1, 16, 20).to(device) + reconstruction_even, _, _ = net(x_even) + self.assertEqual( + reconstruction_even.shape, + x_even.shape, + f"Legacy default behavior with even dims: reconstruction {reconstruction_even.shape} != input {x_even.shape}", + )