36 changes: 32 additions & 4 deletions monai/networks/blocks/patchembedding.py
@@ -29,6 +29,32 @@
 SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


+class _PatchRearrange(nn.Module):
+    """Fallback patch rearrangement using pure PyTorch, for einops compatibility."""
+
+    def __init__(self, spatial_dims: int, patch_size: tuple) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.patch_size = patch_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channels = x.shape[0], x.shape[1]
+        sp = x.shape[2:]
+        g = tuple(s // p for s, p in zip(sp, self.patch_size))
+        v: list[int] = [batch, channels]
+        for gi, pi in zip(g, self.patch_size):
+            v += [gi, pi]
+        x = x.view(*v)
+        n = self.spatial_dims
+        gdims = list(range(2, 2 + 2 * n, 2))
+        pdims = list(range(3, 3 + 2 * n, 2))
+        x = x.permute(0, *gdims, *pdims, 1).contiguous()
+        n_patches = 1
+        for gi in g:
+            n_patches *= gi
+        return x.reshape(batch, n_patches, -1)
+
+
 class PatchEmbeddingBlock(nn.Module):
     """
     A patch embedding block, based on: "Dosovitskiy et al.,
@@ -97,14 +123,16 @@ def __init__(
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
         elif self.proj_type == "perceptron":
-            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
             to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
             axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
-            self.patch_embeddings = nn.Sequential(
-                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
-            )
+            try:
+                rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len)
+            except TypeError:
+                rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size))
+            self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size))
Comment on lines +129 to +138

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

einops Rearrange pattern integer literals support 0.8

💡 Result:

No. einops Rearrange patterns do not accept standalone integer literals (e.g., 'b 2 h w') as axis specifications. Patterns are built from named axes (such as 'b', 'h'), parenthesized groups of axes (such as '(b1 b2)'), and an ellipsis (...); the underscore placeholder belongs to parse_shape rather than to rearrange patterns. Integer lengths are supplied through the axes_lengths keyword arguments (e.g., rearrange(x, 'b (h1 h) w c -> ...', h1=2)) or read from tensor shapes with parse_shape. This syntax has been consistent across versions, including 0.8.0+ (released 2024-04-28): the release notes, changelogs, and docs show no change that adds integer-literal support to rearrange patterns, and the examples in the official documentation and tutorials use named axes.


🏁 Script executed:

# Get the file and examine _PatchRearrange class and test coverage
head -150 monai/networks/blocks/patchembedding.py | tail -100

Repository: Project-MONAI/MONAI

Length of output: 4640


🏁 Script executed:

# Check for existing tests covering the fallback path
find . -name "*.py" -path "*/test*" -type f | xargs grep -l "patchembedding" | head -5

Repository: Project-MONAI/MONAI

Length of output: 130


🏁 Script executed:

# Get the _PatchRearrange class definition (around lines 32-55)
sed -n '32,56p' monai/networks/blocks/patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 958


🏁 Script executed:

# Check the test file for coverage of _PatchRearrange and the fallback path
cat monai/tests/networks/blocks/test_patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 145


🏁 Script executed:

# Find the test file with correct path
find . -name "test_patchembedding.py" -type f 2>/dev/null

Repository: Project-MONAI/MONAI

Length of output: 109


🏁 Script executed:

# Also check for any tests that import or test _PatchRearrange
rg "_PatchRearrange" --type py

Repository: Project-MONAI/MONAI

Length of output: 275


🏁 Script executed:

# Read the test file
cat ./tests/networks/blocks/test_patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 7510


🏁 Script executed:

# Check einops version constraints and any comments about the fallback logic
head -30 monai/networks/blocks/patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 1308


🏁 Script executed:

# Verify if x.view() in _PatchRearrange could fail on non-contiguous tensors
rg "x\.view\(" monai/networks/blocks/patchembedding.py -A 2 -B 2

Repository: Project-MONAI/MONAI

Length of output: 233


Retract the simplification suggestion; the current try/except approach is necessary.

einops Rearrange does not support integer literals in patterns—numeric axis values must be passed via axes_lengths (the current axes_len dict). The suggested approach of embedding integers as literals (e.g., "b c (h 16) (w 16) (d 16)") is not feasible. The fallback with _PatchRearrange is the correct strategy for version compatibility.
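
For concreteness, the pattern string and the axes_len mapping that the perceptron branch passes to Rearrange can be reproduced in isolation. This is a minimal sketch that copies the string-building lines from the diff; the helper name `build_pattern` is hypothetical and exists only for illustration.

```python
def build_pattern(spatial_dims: int, patch_size: tuple[int, ...]) -> tuple[str, dict[str, int]]:
    """Rebuild the pattern and axes_len the same way the perceptron branch does."""
    chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
    from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
    to_chars = f"b ({' '.join(c[0] for c in chars)}) ({' '.join(c[1] for c in chars)} c)"
    # Patch sizes are passed as named lengths, never written into the pattern.
    axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
    return f"{from_chars} -> {to_chars}", axes_len


pattern_2d, axes_2d = build_pattern(2, (16, 16))
pattern_3d, axes_3d = build_pattern(3, (16, 16, 16))
print(pattern_2d, axes_2d)  # b c (h p1) (w p2) -> b (h w) (p1 p2 c) {'p1': 16, 'p2': 16}
print(pattern_3d, axes_3d)
```

The pattern contains only symbolic names; the integers travel separately as keyword arguments, which is the call shape the try/except in the diff guards.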

However, address these remaining issues in _PatchRearrange:

  1. Missing Google-style docstrings: Add docstrings to __init__ and forward methods describing arguments, return values, and behavior per coding guidelines.
  2. Type hint specificity: Change patch_size: tuple to patch_size: tuple[int, ...].
  3. Use reshape() instead of view(): Line 47 uses x.view(*v), which raises if the input's strides are incompatible with the requested shape; reshape() falls back to a copy in that case and is the safer choice.
  4. Incomplete test coverage: The test suite only exercises the Rearrange path (since einops is installed). The fallback is never deterministically validated. Add a test that directly instantiates and tests _PatchRearrange independently.
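
One way items 1 to 4 could look together, offered as a sketch under stated assumptions rather than the PR's final code. `PatchRearrangeSketch` and `reference_patches_2d` are hypothetical names; the reference is a plain slicing loop, so the check does not depend on einops being installed.

```python
import math

import torch
from torch import nn


class PatchRearrangeSketch(nn.Module):
    """Hypothetical revision of _PatchRearrange: splits an input into flattened patches."""

    def __init__(self, spatial_dims: int, patch_size: tuple[int, ...]) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input.
            patch_size: patch length along each spatial dimension.
        """
        super().__init__()
        self.spatial_dims = spatial_dims
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input of shape (batch, channels, *spatial); each spatial size
                must be divisible by the matching patch size.

        Returns:
            Tensor of shape (batch, n_patches, prod(patch_size) * channels).
        """
        batch, channels = x.shape[0], x.shape[1]
        grid = tuple(s // p for s, p in zip(x.shape[2:], self.patch_size))
        split_shape = [batch, channels]
        for g, p in zip(grid, self.patch_size):
            split_shape += [g, p]
        x = x.reshape(*split_shape)  # reshape rather than view (item 3)
        n = self.spatial_dims
        grid_dims = list(range(2, 2 + 2 * n, 2))
        patch_dims = list(range(3, 3 + 2 * n, 2))
        # Reorder to (batch, grid..., patch..., channels), then flatten.
        x = x.permute(0, *grid_dims, *patch_dims, 1)
        return x.reshape(batch, math.prod(grid), -1)


def reference_patches_2d(x: torch.Tensor, patch_size: tuple[int, int]) -> torch.Tensor:
    """Slow 2D reference for "b c (h p1) (w p2) -> b (h w) (p1 p2 c)"."""
    b, _, height, width = x.shape
    ph, pw = patch_size
    rows = []
    for i in range(height // ph):
        for j in range(width // pw):
            patch = x[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]  # (b, c, ph, pw)
            rows.append(patch.permute(0, 2, 3, 1).reshape(b, -1))  # (b, ph * pw * c)
    return torch.stack(rows, dim=1)


# Non-contiguous input of shape (2, 3, 4, 6), made by permuting a channels-last tensor.
x = torch.arange(2 * 4 * 6 * 3, dtype=torch.float32).reshape(2, 4, 6, 3).permute(0, 3, 1, 2)
out = PatchRearrangeSketch(spatial_dims=2, patch_size=(2, 3))(x)
expected = reference_patches_2d(x, (2, 3))
print(out.shape, torch.equal(out, expected))
```

A unit test along these lines instantiates the fallback directly, so it runs the same way regardless of which einops version CI installs.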
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/blocks/patchembedding.py` around lines 126 - 135, The
_PatchRearrange fallback needs fixes: add Google-style docstrings to the class
methods __init__ and forward describing arguments, return values, and behavior;
change the type hint patch_size: tuple to patch_size: tuple[int, ...]; replace
any use of x.view(*v) with x.reshape(*v) to avoid errors on non-contiguous
tensors; and add a deterministic unit test that directly instantiates and
exercises _PatchRearrange (independent of einops/Rearrange) to validate its
behavior for representative spatial_dims/patch_size combinations.

Comment on lines +134 to +138

⚠️ Potential issue | 🟡 Minor

Fallback path isn't deterministically covered by tests.

_PatchRearrange only executes when Rearrange(..., **axes_len) raises TypeError, i.e. only on einops ≥ 0.8. Existing test_shape cases (tests/networks/blocks/test_patchembedding.py) exercise whichever branch the installed einops version selects, so CI never compares both paths in the same run. Suggest a targeted test that forces the fallback — e.g. monkey-patch Rearrange to raise TypeError, or instantiate _PatchRearrange directly and compare its output against the Rearrange path for a known input.

As per coding guidelines: "Ensure new or modified definitions will be covered by existing or new unit tests."
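
The monkey-patch approach can be sketched without MONAI or einops. Everything named here is a hypothetical stand-in: `PrimaryRearrange` plays the role of einops' Rearrange (implemented with `torch.nn.functional.unfold`, 2D only), `FallbackRearrange` plays the role of _PatchRearrange, and `build_layer` mirrors the try/except from the diff.

```python
from unittest import mock

import torch
import torch.nn.functional as F
from torch import nn


class PrimaryRearrange(nn.Module):
    """Stand-in for the einops layer (assumption: 2D inputs only)."""

    def __init__(self, patch_size: tuple[int, int]) -> None:
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c = x.shape[:2]
        ph, pw = self.patch_size
        cols = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)  # (b, c*ph*pw, L)
        cols = cols.transpose(1, 2).reshape(b, -1, c, ph, pw)  # (b, L, c, ph, pw)
        return cols.permute(0, 1, 3, 4, 2).reshape(b, cols.shape[1], -1)  # (b, L, ph*pw*c)


class FallbackRearrange(nn.Module):
    """Stand-in for _PatchRearrange (assumption: 2D inputs only)."""

    def __init__(self, patch_size: tuple[int, int]) -> None:
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        ph, pw = self.patch_size
        x = x.reshape(b, c, h // ph, ph, w // pw, pw)
        return x.permute(0, 2, 4, 3, 5, 1).reshape(b, (h // ph) * (w // pw), -1)


def build_layer(patch_size: tuple[int, int]) -> nn.Module:
    """Mirror of the diff's try/except: prefer the primary layer, fall back on TypeError."""
    try:
        return PrimaryRearrange(patch_size)
    except TypeError:
        return FallbackRearrange(patch_size)


def build_layer_forced_fallback(patch_size: tuple[int, int]) -> nn.Module:
    """Temporarily replace the primary constructor with one that raises TypeError."""
    raising = mock.Mock(side_effect=TypeError("forced for the test"))
    with mock.patch.dict(build_layer.__globals__, {"PrimaryRearrange": raising}):
        return build_layer(patch_size)


x = torch.randn(2, 3, 8, 12)
primary = build_layer((4, 3))
fallback = build_layer_forced_fallback((4, 3))
print(type(primary).__name__, type(fallback).__name__, torch.equal(primary(x), fallback(x)))
```

In the real test the patch target would presumably be the Rearrange name as imported in monai/networks/blocks/patchembedding.py, with PatchEmbeddingBlock built once normally and once under the patch so both branches are compared in the same run.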

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/blocks/patchembedding.py` around lines 131 - 135, Tests don't
deterministically exercise the fallback _PatchRearrange because current tests
run whichever einops version is installed; add a unit test that forces the
fallback by monkey-patching the Rearrange symbol to raise TypeError (or by
directly instantiating _PatchRearrange) and assert that outputs of the fallback
match the normal Rearrange path for a representative input shape; target names
to modify/assert are _PatchRearrange, Rearrange and the patch embedding behavior
(e.g., the patch_embeddings sequence) in the existing
tests/networks/blocks/test_patchembedding.py so both branches are covered and
compared.

         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
