diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py
index 7070a56..6d64df4 100644
--- a/penzai/core/named_axes.py
+++ b/penzai/core/named_axes.py
@@ -1457,8 +1457,8 @@ def check_valid(self) -> None:
         isinstance(size, int) for size in self.named_axes.values()
     ):
       raise ValueError(
-          "NamedArray.named_axes must be an ordered dictionary of named"
-          " axis shapes"
+          "NamedArray.named_axes must be an ordered dictionary of named axis"
+          " shapes"
       )
     if any(isinstance(name, int) for name in self.named_axes.keys()):
diff --git a/penzai/nn/linear_and_affine.py b/penzai/nn/linear_and_affine.py
index f696faf..116105a 100644
--- a/penzai/nn/linear_and_affine.py
+++ b/penzai/nn/linear_and_affine.py
@@ -15,14 +15,13 @@
 """Generalized linear operator layer and associated utilities."""
 
 from __future__ import annotations
-
+import abc
 import collections
 import dataclasses
 import functools
 import itertools
-from typing import Any, Literal, Protocol, Sequence
-
+from typing import Any, Literal, Protocol, Sequence, cast
 import jax
 import jax.numpy as jnp
 from penzai.core import named_axes
 from penzai.core import shapecheck
@@ -32,6 +31,7 @@
 from penzai.nn import layer as layer_base
 from penzai.nn import parameters
 
+
 NamedArray = named_axes.NamedArray
 Parameter = variables.Parameter
 ParameterValue = variables.ParameterValue
@@ -168,6 +168,7 @@ def variance_scaling_initializer(
     distribution="uniform",
 )
 
+
 xavier_normal_initializer = functools.partial(
     variance_scaling_initializer,
     scale=1.0,
@@ -247,6 +248,37 @@ def treescope_color(self) -> tuple[str, str]:
     return "#eba875", "color-mix(in oklab, #eba875 25%, white)"
 
 
+@struct.pytree_dataclass
+class ConvInPlace(grouping.Sequential):
+  """Container for "in-place" convolution operators that preserve axis names.
+
+  This is used when initializing `Conv` layers that have overlapping names in
+  their input and output specifications. We subclass `Sequential` to make
+  this layer type easier to identify and manipulate.
+  """
+
+  sublayers: list[layer_base.Layer]
+
+  def treescope_color(self) -> tuple[str, str]:
+    return "#79eb75", "color-mix(in oklab, #79eb75 25%, white)"
+
+
+@struct.pytree_dataclass
+class ConvTransposeInPlace(grouping.Sequential):
+  """Container for "in-place" transposed convolutions that preserve axis names.
+
+  This is used when initializing `ConvTranspose` layers that have overlapping
+  names in their input and output specifications. We subclass `Sequential` to
+  make this layer type easier to identify and manipulate.
+  """
+
+  sublayers: list[layer_base.Layer]
+
+  def treescope_color(self) -> tuple[str, str]:
+    return "#c7eb75", "color-mix(in oklab, #c7eb75 25%, white)"
+
+
 def contract(
     names: str | Sequence[named_axes.AxisName],
     left: NamedArray,
@@ -395,12 +427,82 @@ def _output_structure(self) -> shapecheck.StructureAnnotation:
     )
 
+
+def _maybe_rename_output_axes(
+    input_axes: dict[str, int],
+    output_axes: dict[str, int],
+    parallel_axes: dict[str, int],
+    parallel_broadcast_axes: dict[str, int],
+    rename_outputs_if_necessary: bool,
+):
+  """Checks for name overlap between input and output axes, renaming if needed.
+
+  Args:
+    input_axes: Names and lengths for axes that the linear operator should
+      contract over.
+    output_axes: Names and lengths for new axes that the linear operator
+      should produce.
+    parallel_axes: Names and lengths for axes that should be processed in
+      parallel. These axes should appear in both the input and the output,
+      and the resulting linear operator will apply a different operator to
+      each slice. (This is similar to a block-diagonal matrix.)
+    parallel_broadcast_axes: Names and lengths for axes that should be treated
+      like `parallel_axes` but will only appear in the output. The input will
+      be implicitly broadcast over these axes.
+    rename_outputs_if_necessary: If True, renames output axes that overlap
+      with input axes by appending "_out" to their names.
+
+  Returns:
+    A tuple ``(output_axes_after_rename, primed_names, original_names)``. The
+    two name lists are None unless a rename actually happened.
+  """
+
+  # By default, no rename is needed and no wrapper layer will be added.
+  output_axes_after_rename = output_axes
+  primed_names, original_names = None, None
+
+  if any(name in input_axes for name in output_axes):
+    # Name overlap!
+    if rename_outputs_if_necessary:
+      output_axes_after_rename = {}
+      original_names = []
+      primed_names = []
+
+      for old_name in output_axes.keys():
+        if old_name in input_axes:
+          primed_name = old_name + "_out"
+          if primed_name in input_axes:
+            raise ValueError(
+                f"Tried to rename {old_name} to {primed_name} to avoid a"
+                " conflict, but both names are already in input_axes. Please"
+                " rename axes manually to avoid this conflict."
+            )
+          original_names.append(old_name)
+          primed_names.append(primed_name)
+          output_axes_after_rename[primed_name] = output_axes[old_name]
+        else:
+          output_axes_after_rename[old_name] = output_axes[old_name]
+    else:
+      raise ValueError(
+          "input_axes and output_axes must not overlap if"
+          " rename_outputs_if_necessary is not set; got"
+          f" input_axes={input_axes}, output_axes={output_axes}."
+      )
+
+  if set(parallel_axes).intersection(set(input_axes).union(output_axes)) or set(
+      parallel_broadcast_axes
+  ).intersection(set(input_axes).union(output_axes, parallel_axes)):
+    raise ValueError(
+        "parallel_axes and parallel_broadcast_axes must not overlap with"
+        f" each other or with input/output axes; got input_axes={input_axes},"
+        f" output_axes={output_axes}, parallel_axes={parallel_axes},"
+        f" parallel_broadcast_axes={parallel_broadcast_axes}."
+    )
+  return output_axes_after_rename, primed_names, original_names
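
For intuition, a minimal sketch of the helper's behavior (axis names and sizes below are illustrative, not from the patch):

# Overlap on "foo" triggers a rename to "foo_out"; "bar" passes through.
renamed, primed, original = _maybe_rename_output_axes(
    input_axes={"foo": 3},
    output_axes={"foo": 5, "bar": 2},
    parallel_axes={},
    parallel_broadcast_axes={},
    rename_outputs_if_necessary=True,
)
assert renamed == {"foo_out": 5, "bar": 2}
assert primed == ["foo_out"] and original == ["foo"]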
+
+
 @struct.pytree_dataclass
 class Linear(layer_base.Layer):
   """A generalized linear (not affine) operator, for named arrays.
 
   Applies an arbitrary contraction to the input `NamedArray` and a weight
-  parameter. This can be used to express an arbitrary linear operator.
+  parameter. This can be used to express an arbitrary dense linear operator.
 
   ``Linear`` layers are often (but not always) followed by `AddBias` to make
   an affine transformation.
@@ -504,80 +606,45 @@ def from_config(
       parallel_axes = {}
     if parallel_broadcast_axes is None:
       parallel_broadcast_axes = {}
-    if any(name in input_axes for name in output_axes):
-      # Name overlap!
-      if rename_outputs_if_necessary:
-        output_axes_after_rename = {}
-        original_names = []
-        primed_names = []
-
-        for old_name in output_axes.keys():
-          if old_name in input_axes:
-            primed_name = old_name + "_out"
-            if primed_name in input_axes:
-              raise ValueError(
-                  f"Tried to rename {old_name} to {primed_name} to avoid a"
-                  " conflict, but both names are already in input_axes. Please"
-                  " rename axes manually to avoid this conflict."
-              )
-            original_names.append(old_name)
-            primed_names.append(primed_name)
-            output_axes_after_rename[primed_name] = output_axes[old_name]
-          else:
-            output_axes_after_rename[old_name] = output_axes[old_name]
-
-        return LinearInPlace(
-            sublayers=[
-                cls.from_config(
-                    name=name,
-                    init_base_rng=init_base_rng,
-                    input_axes=input_axes,
-                    output_axes=output_axes_after_rename,
-                    parallel_axes=parallel_axes,
-                    parallel_broadcast_axes=parallel_broadcast_axes,
-                    initializer=initializer,
-                    dtype=dtype,
-                    rename_outputs_if_necessary=False,
-                ),
-                RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
-            ],
-        )
-      else:
-        raise ValueError(
-            "input_axes and output_axes must not overlap if"
-            " rename_outputs_if_necessary is not set; got"
-            f" input_axes={input_axes}, output_axes={output_axes}."
-        )
-    if set(parallel_axes).intersection(
-        set(input_axes).union(output_axes)
-    ) or set(parallel_broadcast_axes).intersection(
-        set(input_axes).union(output_axes, parallel_axes)
-    ):
-      raise ValueError(
-          "parallel_axes and parallel_broadcast_axes must not overlap with"
-          f" each other or with input/output axes; got input_axes={input_axes},"
-          f" output_axes={output_axes}, parallel_axes={parallel_axes},"
-          f" parallel_broadcast_axes={parallel_broadcast_axes}."
-      )
+    output_axes_after_rename, primed_names, original_names = (
+        _maybe_rename_output_axes(
+            input_axes,
+            output_axes,
+            parallel_axes,
+            parallel_broadcast_axes,
+            rename_outputs_if_necessary,
+        )
+    )
 
-    return cls(
+    core_layer = cls(
         weights=parameters.make_parameter(
             f"{name}.weights",
             init_base_rng,
             initializer,
             input_axes=input_axes,
-            output_axes=output_axes,
+            output_axes=output_axes_after_rename,
             parallel_axes={**parallel_axes, **parallel_broadcast_axes},
             convolution_spatial_axes={},
             dtype=dtype,
         ),
         in_axis_names=tuple(input_axes.keys()),
         out_axis_names=(
-            tuple(output_axes.keys()) + tuple(parallel_broadcast_axes.keys())
+            tuple(output_axes_after_rename.keys())
+            + tuple(parallel_broadcast_axes.keys())
        ),
     )
 
+    # If names overlapped, wrap the core layer so the primed output axes are
+    # renamed back to the requested names.
+    if primed_names is not None and original_names is not None:
+      return LinearInPlace(
+          sublayers=[
+              core_layer,
+              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
+          ],
+      )
+    return core_layer
+
   def _input_structure(self):
     known_in_axes = {
         name: size
@@ -771,3 +838,674 @@ class ConstantRescale(layer_base.Layer):
   def __call__(self, value: Any, **_unused_side_inputs) -> Any:
     """Scales its input by the scaling factor."""
     return jax.tree_util.tree_map(lambda x: x * self.by, value)
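
With the helper factored out, `Linear.from_config` always builds the core layer first and only wraps it when a rename happened. A sketch of the resulting structure, assuming the same overlapping-axes setup the tests use:

# "features" appears on both sides, so from_config returns a LinearInPlace
# wrapping [Linear (contracts "features", emits "features_out"), RenameAxes].
layer = Linear.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    input_axes={"features": 8},
    output_axes={"features": 8},
    rename_outputs_if_necessary=True,
)
assert isinstance(layer, LinearInPlace)
assert isinstance(layer.sublayers[1], RenameAxes)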
+
+
+def _prepare_for_conv(
+    inputs: NamedArray,
+    kernel: NamedArray,
+    spatial_axis_names: Sequence[str],
+    in_axis_names: Sequence[str],
+    out_axis_names: Sequence[str],
+) -> tuple[NamedArray, NamedArray]:
+  """Preprocesses lhs and rhs for the jax convolution operator.
+
+  Merges the in axes of the inputs into a single in channel axis, and merges
+  the out axes of the kernel into a single out channel axis. This is necessary
+  to use the jax convolution operator, which expects the inputs to have a
+  single in channel axis and the kernel to have a single out channel axis.
+
+  Args:
+    inputs: The input named array.
+    kernel: The kernel named array.
+    spatial_axis_names: Names of the spatial axes in the input and kernel.
+    in_axis_names: Names of the input axes that will be contracted with the
+      kernel.
+    out_axis_names: Names of the output axes that will be produced by the
+      convolution.
+
+  Returns:
+    A tuple of two named arrays. The first is the convolution input with the
+    in axes merged into a single in channel axis; its positional axis layout
+    is [spatial_axes..., channel_axis]. The second is the convolution kernel
+    with the in axes merged into a single in channel axis and the out axes
+    merged into a single out channel axis; its positional axis layout is
+    [spatial_axes..., in_channel_axis, out_channel_axis].
+  """
+
+  lhs = inputs
+  rhs = kernel
+
+  in_axis_name = named_axes.TmpPosAxisMarker()
+  out_axis_name = named_axes.TmpPosAxisMarker()
+
+  # Merge the in axes into one in-channel axis for both the inputs and the
+  # kernel.
+  lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name)
+  rhs = rhs.untag(*in_axis_names).flatten().tag(in_axis_name)
+
+  # Merge the out axes into one out-channel axis for the jax convolution.
+  rhs = rhs.untag(*out_axis_names).flatten().tag(out_axis_name)
+
+  # Expose the spatial and channel axes positionally.
+  lhs = lhs.untag(*spatial_axis_names, in_axis_name)
+  rhs = rhs.untag(*spatial_axis_names, in_axis_name, out_axis_name)
+  return lhs, rhs
+
+
+def _get_named_axis_back_after_conv(
+    result: NamedArray,
+    spatial_axis_names: Sequence[str],
+    out_axis_names: Sequence[str],
+    out_axis_shape: Sequence[int],
+) -> NamedArray:
+  """Postprocesses the result of the jax convolution operator.
+
+  Restores the spatial axes and output axes of the result of the jax
+  convolution operator. The spatial axes are tagged back, and the output axes
+  are reshaped to their original shape and tagged back. This assumes that the
+  result has a positional axis layout of [spatial_axes..., out_axis], with
+  out_axis of size equal to the product of the dimensions in out_axis_shape.
+  This is necessary to restore the desired shape of the output after the
+  convolution operator has been applied, since the convolution operates on
+  positional spatial axes and only outputs a single out_axis.
+
+  Args:
+    result: The result of the jax convolution operator.
+    spatial_axis_names: Names of the spatial axes in the input and kernel.
+    out_axis_names: Names of the output axes that will be produced by the
+      convolution.
+    out_axis_shape: The shape of the output axes, which will be used to
+      reshape the result back to the original shape.
+
+  Returns:
+    A named array with the spatial axes and output axes tagged back, and the
+    output axes reshaped to their original shape.
+  """
+  return (
+      result.tag_prefix(*spatial_axis_names)
+      .reshape(out_axis_shape)
+      .tag(*out_axis_names)
+  )
+
+
+def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
+  """Broadcasts a value to a sequence of the given count.
+
+  If the value is an integer, it will be repeated `count` times. If the value
+  is already a sequence, it will be returned as is.
+
+  Args:
+    value: The value to broadcast, either an integer or a sequence of
+      integers.
+    count: The number of times to repeat the value if it is an integer.
+
+  Returns:
+    A sequence of integers: the value repeated `count` times if it was an
+    integer, or the original sequence if it was already a sequence.
+  """
+
+  if isinstance(value, int):
+    return [value] * count
+  else:
+    assert (
+        len(value) == count
+    ), "If value is a sequence, it must match the count."
+    return value
+
+
+def _get_dimension_numbers(ndim) -> jax.lax.ConvDimensionNumbers:
+  """Returns the dimension numbers for a convolution operator.
+
+  Args:
+    ndim: The number of spatial dimensions of the convolution operator.
+
+  Returns:
+    A `jax.lax.ConvDimensionNumbers` object that specifies the dimension
+    numbers for the convolution operator. It assumes that the input and output
+    have the positional axis layout [B, Spatial..., C] and that the kernel has
+    the positional axis layout [Spatial..., I, O], where B is the batch axis,
+    C is the channel axis, I is the input channel axis, and O is the output
+    channel axis. This matches the layout produced by `_prepare_for_conv`.
+  """
+
+  return jax.lax.ConvDimensionNumbers(
+      # lhs layout is [B, Spatial..., C]: batch at 0, features at ndim + 1.
+      lhs_spec=(0, ndim + 1) + tuple(range(1, ndim + 1)),
+      # rhs layout is [Spatial..., I, O]: out features, in features, spatial.
+      rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),
+      # out layout matches the lhs layout, [B, Spatial..., C].
+      out_spec=(0, ndim + 1) + tuple(range(1, ndim + 1)),
+  )
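
For intuition, instantiating these dimension numbers for a 2D convolution reproduces JAX's familiar "NHWC"/"HWIO"/"NHWC" convention (the shapes in the commented cross-check are illustrative):

# ndim == 2: lhs layout [B, H, W, C], rhs layout [H, W, I, O].
dn = _get_dimension_numbers(ndim=2)
assert dn.lhs_spec == (0, 3, 1, 2)  # batch=0, features=3, spatial=(1, 2)
assert dn.rhs_spec == (3, 2, 0, 1)  # out=3, in=2, spatial=(0, 1)
assert dn.out_spec == (0, 3, 1, 2)
# Equivalent to:
# jax.lax.conv_dimension_numbers(
#     (1, 8, 8, 4), (3, 3, 4, 16), ("NHWC", "HWIO", "NHWC"))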
+
+
+@struct.pytree_dataclass
+class AbstractGeneralConv(layer_base.Layer):
+  """Shared base class for the `Conv` and `ConvTranspose` layers."""
+
+  kernel: parameters.ParameterLike[NamedArray]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  spatial_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  in_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  out_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  def __call__(self, in_array: NamedArray, **_side_inputs) -> NamedArray:
+    """Runs the convolution operator."""
+    in_struct = self._input_structure()
+
+    # pytype: disable=attribute-error
+    if isinstance(
+        self.kernel,
+        Parameter | ParameterValue,
+    ) and self.kernel.label.endswith(".kernel"):
+      error_prefix = f"({self.kernel.label[:-7]}) "
+    else:
+      error_prefix = ""
+    # pytype: enable=attribute-error
+
+    dimvars = shapecheck.check_structure(
+        in_array, in_struct, error_prefix=error_prefix
+    )
+
+    lhs, rhs = _prepare_for_conv(
+        in_array,
+        self.kernel.value,
+        self.spatial_axis_names,
+        self.in_axis_names,
+        self.out_axis_names,
+    )
+
+    if self._is_transposed():
+      # Perform the actual transposed convolution.
+      result = named_axes.nmap(
+          lambda lhs, rhs: jax.lax.conv_transpose(
+              lhs=lhs[None, ...],
+              rhs=rhs,
+              strides=self.strides,
+              padding=self.padding,
+              rhs_dilation=self.kernel_dilation,
+              dimension_numbers=_get_dimension_numbers(
+                  ndim=len(self.spatial_axis_names)
+              ),
+          )[0]
+      )(lhs, rhs)
+    else:
+      # Perform the actual convolution.
+      result = named_axes.nmap(
+          lambda lhs, rhs: jax.lax.conv_general_dilated(
+              lhs=lhs[None, ...],
+              rhs=rhs,
+              window_strides=self.strides,
+              padding=self.padding,
+              lhs_dilation=self.inputs_dilation,
+              rhs_dilation=self.kernel_dilation,
+              dimension_numbers=_get_dimension_numbers(
+                  ndim=len(self.spatial_axis_names)
+              ),
+          )[0]
+      )(lhs, rhs)
+
+    result = _get_named_axis_back_after_conv(
+        result,
+        self.spatial_axis_names,
+        self.out_axis_names,
+        [self.output_axes[name] for name in self.out_axis_names],
+    )
+
+    out_struct = self._output_structure()
+    shapecheck.check_structure(
+        result, out_struct, known_vars=dimvars, error_prefix=error_prefix
+    )
+    return result
+
+  @classmethod
+  def _from_config(
+      cls,
+      inplace_class: type[ConvInPlace | ConvTransposeInPlace],
+      name: str,
+      init_base_rng: jax.Array | None,
+      input_axes: dict[str, int],
+      output_axes: dict[str, int],
+      convolution_spatial_axes: dict[str, int],
+      strides: int | Sequence[int] = 1,
+      padding: str | Sequence[tuple[int, int]] = "SAME",
+      inputs_dilation: int | Sequence[int] = 1,
+      kernel_dilation: int | Sequence[int] = 1,
+      parallel_axes: dict[str, int] | None = None,
+      parallel_broadcast_axes: dict[str, int] | None = None,
+      initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+      dtype: jax.typing.DTypeLike = jnp.float32,
+      rename_outputs_if_necessary: bool = True,
+  ) -> AbstractGeneralConv | ConvInPlace | ConvTransposeInPlace:
+    """Constructs an ``AbstractGeneralConv`` layer from a configuration.
+
+    This can be used when building a new convolution or transposed convolution
+    operator at the start of training. For more details see `Conv` or
+    `ConvTranspose`.
+    """
+
+    spatial_dim_count = len(convolution_spatial_axes)
+
+    strides = _maybe_broadcast(strides, spatial_dim_count)
+    inputs_dilation = _maybe_broadcast(inputs_dilation, spatial_dim_count)
+    kernel_dilation = _maybe_broadcast(kernel_dilation, spatial_dim_count)
+
+    if parallel_axes is None:
+      parallel_axes = {}
+    if parallel_broadcast_axes is None:
+      parallel_broadcast_axes = {}
+
+    output_axes_after_rename, primed_names, original_names = (
+        _maybe_rename_output_axes(
+            input_axes,
+            output_axes,
+            parallel_axes,
+            parallel_broadcast_axes,
+            rename_outputs_if_necessary,
+        )
+    )
+
+    core_layer = cls(
+        kernel=parameters.make_parameter(
+            f"{name}.kernel",
+            init_base_rng,
+            initializer,
+            input_axes=input_axes,
+            output_axes=output_axes_after_rename,
+            parallel_axes={**parallel_axes, **parallel_broadcast_axes},
+            convolution_spatial_axes=convolution_spatial_axes,
+            dtype=dtype,
+        ),
+        strides=strides,
+        padding=padding,
+        inputs_dilation=inputs_dilation,
+        kernel_dilation=kernel_dilation,
+        spatial_axis_names=tuple(convolution_spatial_axes.keys()),
+        in_axis_names=tuple(input_axes.keys()),
+        out_axis_names=(
+            tuple(output_axes_after_rename.keys())
+            + tuple(parallel_broadcast_axes.keys())
+        ),
+    )
+
+    # If names overlapped, wrap the core layer so the primed output axes are
+    # renamed back to the requested names.
+    if primed_names is not None and original_names is not None:
+      return inplace_class(
+          sublayers=[
+              core_layer,
+              RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
+          ],
+      )
+
+    return core_layer
+
+  @abc.abstractmethod
+  def _is_transposed(self) -> bool:
+    ...
+
+  def _input_structure(self):
+    known_in_axes = {
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name not in self.out_axis_names
+        and name not in self.spatial_axis_names
+    }
+    return shapecheck.ArraySpec(
+        named_shape={**shapecheck.var("In"), **known_in_axes},
+        dtype=jnp.floating,
+    )
+
+  def _output_structure(self):
+    known_out_axes = {
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name not in self.in_axis_names
+        and name not in self.spatial_axis_names
+    }
+    return shapecheck.ArraySpec(
+        named_shape={**shapecheck.var("Out"), **known_out_axes},
+        dtype=jnp.floating,
+    )
+
+  @property
+  def input_axes(self) -> dict[str, int]:
+    """The axis names and sizes that should appear in the input only."""
+    return {  # pytype: disable=bad-return-type
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name in self.in_axis_names
+    }
+
+  @property
+  def output_axes(self) -> dict[str, int]:
+    """The axis names and sizes that will appear in the output only."""
+    return {  # pytype: disable=bad-return-type
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name in self.out_axis_names
+    }
+
+  @property
+  def parallel_axes(self) -> dict[str, int]:
+    """The axis names and sizes that should appear in both input and output."""
+    return {  # pytype: disable=bad-return-type
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name not in self.spatial_axis_names
+        and name not in self.in_axis_names
+        and name not in self.out_axis_names
+    }
+
+  @property
+  def convolution_spatial_axes(self) -> dict[str, int]:
+    """The spatial axis names and sizes of the convolution kernel.
+
+    These axes appear in both the input and the output. Note that the sizes
+    here refer to the kernel shape, not to the spatial size of the input.
+    """
+    return {  # pytype: disable=bad-return-type
+        name: size
+        for name, size in self.kernel.value.named_shape.items()
+        if name in self.spatial_axis_names
+    }
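
One reviewing aid: with ``padding="SAME"``, the spatial sizes asserted by the strided tests below follow the usual size rules, sketched here:

import math

# Forward conv with SAME padding: out = ceil(in / stride).
assert math.ceil(10 / 2) == 5 and math.ceil(16 / 2) == 8
# Transposed conv with SAME padding: out = in * stride.
assert 10 * 2 == 20 and 16 * 2 == 32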
+
+
+@struct.pytree_dataclass
+class Conv(AbstractGeneralConv):
+  """A general convolution operator, for named arrays.
+
+  Applies an arbitrary contraction to the input `NamedArray` and a weight
+  parameter. This can be used to express an arbitrary linear convolution
+  operator.
+
+  Attributes:
+    kernel: The named array holding the kernel for the convolution operator.
+    strides: The strides of the convolution operator.
+    padding: The padding to apply to the input before the convolution.
+    inputs_dilation: The input dilation of the convolution.
+    kernel_dilation: The kernel dilation of the convolution.
+    spatial_axis_names: The names of the spatial axes over which to apply the
+      convolution operator.
+    in_axis_names: The names of the axes to contract with the input, removing
+      them.
+    out_axis_names: The names of the axes that should not appear in the input
+      and will be inserted into the output.
+  """
+
+  kernel: parameters.ParameterLike[NamedArray]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  spatial_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  in_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  out_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  @classmethod
+  def from_config(
+      cls,
+      name: str,
+      init_base_rng: jax.Array | None,
+      input_axes: dict[str, int],
+      output_axes: dict[str, int],
+      convolution_spatial_axes: dict[str, int],
+      strides: int | Sequence[int] = 1,
+      padding: str | Sequence[tuple[int, int]] = "SAME",
+      inputs_dilation: int | Sequence[int] = 1,
+      kernel_dilation: int | Sequence[int] = 1,
+      parallel_axes: dict[str, int] | None = None,
+      parallel_broadcast_axes: dict[str, int] | None = None,
+      initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+      dtype: jax.typing.DTypeLike = jnp.float32,
+      rename_outputs_if_necessary: bool = True,
+  ) -> Conv | ConvInPlace:
+    """Constructs a ``Conv`` layer from a configuration.
+
+    This can be used when building a new convolution operator at the start of
+    training.
+
+    Note: For the purposes of the initializer, the ``parallel_axes`` and
+    ``parallel_broadcast_axes`` are treated in the same way, without
+    participating in output-dimension variance scaling. However, after
+    initialization, the ``parallel_broadcast_axes`` will be treated like extra
+    output axes (and assumed not to be present in the input).
+
+    Args:
+      name: The name of the layer.
+      init_base_rng: The base RNG to use for initializing model parameters.
+      input_axes: Names and lengths for axes that the linear operator should
+        contract over.
+      output_axes: Names and lengths for new axes that the linear operator
+        should produce. If any axis names overlap with ``input_axes``, the
+        argument ``rename_outputs_if_necessary`` must be True.
+      convolution_spatial_axes: Names and lengths for the spatial axes of the
+        convolution kernel.
+      strides: Strides of the convolution. If an integer, it is broadcast to
+        every spatial dimension.
+      padding: The padding to apply to the input before the convolution. Can
+        be one of the strings 'SAME', 'SAME_LOWER', or 'VALID', or a sequence
+        of n (low, high) integer pairs giving the padding to apply before and
+        after each spatial dimension. 'SAME' and 'SAME_LOWER' add padding to
+        produce the same output size as the input when the stride is one. The
+        padding is split between the two sides equally or almost equally; if
+        the total padding is odd, the extra padding is added at the end for
+        'SAME' and at the beginning for 'SAME_LOWER'.
+      inputs_dilation: Input dilation of the convolution. If an integer, it is
+        broadcast to every spatial dimension.
+      kernel_dilation: Kernel dilation of the convolution. If an integer, it
+        is broadcast to every spatial dimension.
+      parallel_axes: Names and lengths for axes that should be processed in
+        parallel. These axes should appear in both the input and the output,
+        and the resulting convolution operator will apply a different operator
+        to each slice. (This is similar to a grouped convolution.) Must not
+        overlap with any axes named in ``input_axes`` or ``output_axes``.
+      parallel_broadcast_axes: Names and lengths for axes that should be
+        treated like ``parallel_axes`` but will only appear in the output. The
+        input will be implicitly broadcast over these axes. Must not overlap
+        with any axes named in ``input_axes``, ``output_axes`` or
+        ``parallel_axes``.
+      initializer: Function to use to initialize the kernel.
+      dtype: Dtype for the kernel.
+      rename_outputs_if_necessary: If True, and if ``output_axes`` and
+        ``input_axes`` have overlapping names, avoids name conflicts by adding
+        "primed" versions of the overlapping names, and returns an instance of
+        `ConvInPlace` instead of a ``Conv`` layer directly.
+
+    Returns:
+      A ``Conv`` layer with uninitialized kernel, or possibly a
+      `ConvInPlace` layer if ``rename_outputs_if_necessary`` is True and
+      ``input_axes`` overlaps with ``output_axes``.
+    """
+
+    layer = super()._from_config(
+        inplace_class=ConvInPlace,
+        name=name,
+        init_base_rng=init_base_rng,
+        input_axes=input_axes,
+        output_axes=output_axes,
+        convolution_spatial_axes=convolution_spatial_axes,
+        strides=strides,
+        padding=padding,
+        inputs_dilation=inputs_dilation,
+        kernel_dilation=kernel_dilation,
+        parallel_axes=parallel_axes,
+        parallel_broadcast_axes=parallel_broadcast_axes,
+        initializer=initializer,
+        dtype=dtype,
+        rename_outputs_if_necessary=rename_outputs_if_necessary,
+    )
+    if isinstance(layer, AbstractGeneralConv):
+      return cast(Conv, layer)
+    assert isinstance(layer, ConvInPlace)
+    return layer
+
+  def _is_transposed(self):
+    return False
+
+  def treescope_color(self) -> str:
+    return "#79eb75"
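
Mirroring `test_conv_shape` below, a 2D `Conv` whose "foo" axis appears on both sides is built and applied like this (a sketch via the `pz.nn` re-exports added in this change; assumes `from penzai import pz` and `import jax` as in the tests):

layer = pz.nn.Conv.from_config(
    name="block",
    init_base_rng=jax.random.key(0),
    input_axes={"foo": 3},
    output_axes={"foo": 5},
    convolution_spatial_axes={"height": 3, "width": 3},
)
out = layer(pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3}))
assert dict(out.named_shape) == {
    "batch": 1, "height": 10, "width": 15, "foo": 5}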
+
+
+@struct.pytree_dataclass
+class ConvTranspose(AbstractGeneralConv):
+  """A general transposed convolution operator, for named arrays.
+
+  Applies an arbitrary contraction to the input `NamedArray` and a kernel
+  parameter. This can be used to express an arbitrary linear transposed
+  convolution operator.
+
+  Attributes:
+    kernel: The named array holding the kernel for the convolution operator.
+    strides: The strides of the convolution operator.
+    padding: The padding to apply to the input before the convolution.
+    kernel_dilation: The kernel dilation of the convolution.
+    spatial_axis_names: The names of the spatial axes over which to apply the
+      convolution operator.
+    in_axis_names: The names of the axes to contract with the input, removing
+      them.
+    out_axis_names: The names of the axes that should not appear in the input
+      and will be inserted into the output.
+  """
+
+  kernel: parameters.ParameterLike[NamedArray]
+  strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
+  padding: str | Sequence[tuple[int, int]] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  spatial_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  in_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  out_axis_names: tuple[str, ...] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  kernel_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+  inputs_dilation: Sequence[int] = dataclasses.field(
+      metadata={"pytree_node": False}
+  )
+
+  @classmethod
+  def from_config(
+      cls,
+      name: str,
+      init_base_rng: jax.Array | None,
+      input_axes: dict[str, int],
+      output_axes: dict[str, int],
+      convolution_spatial_axes: dict[str, int],
+      strides: int | Sequence[int] = 1,
+      padding: str | Sequence[tuple[int, int]] = "SAME",
+      kernel_dilation: int | Sequence[int] = 1,
+      parallel_axes: dict[str, int] | None = None,
+      parallel_broadcast_axes: dict[str, int] | None = None,
+      initializer: LinearOperatorWeightInitializer = xavier_uniform_initializer,
+      dtype: jax.typing.DTypeLike = jnp.float32,
+      rename_outputs_if_necessary: bool = True,
+  ) -> ConvTranspose | ConvTransposeInPlace:
+    """Constructs a ``ConvTranspose`` layer from a configuration.
+
+    This can be used when building a new transposed convolution operator at
+    the start of training.
+
+    Note: For the purposes of the initializer, the ``parallel_axes`` and
+    ``parallel_broadcast_axes`` are treated in the same way, without
+    participating in output-dimension variance scaling. However, after
+    initialization, the ``parallel_broadcast_axes`` will be treated like extra
+    output axes (and assumed not to be present in the input).
+
+    Args:
+      name: The name of the layer.
+      init_base_rng: The base RNG to use for initializing model parameters.
+      input_axes: Names and lengths for axes that the linear operator should
+        contract over.
+      output_axes: Names and lengths for new axes that the linear operator
+        should produce. If any axis names overlap with ``input_axes``, the
+        argument ``rename_outputs_if_necessary`` must be True.
+      convolution_spatial_axes: Names and lengths for the spatial axes of the
+        convolution kernel.
+      strides: Strides of the convolution. If an integer, it is broadcast to
+        every spatial dimension.
+      padding: The padding to apply to the input before the convolution. Can
+        be one of the strings 'SAME', 'SAME_LOWER', or 'VALID', or a sequence
+        of n (low, high) integer pairs giving the padding to apply before and
+        after each spatial dimension. 'SAME' and 'SAME_LOWER' add padding to
+        produce the same output size as the input when the stride is one. The
+        padding is split between the two sides equally or almost equally; if
+        the total padding is odd, the extra padding is added at the end for
+        'SAME' and at the beginning for 'SAME_LOWER'.
+      kernel_dilation: Kernel dilation of the convolution. If an integer, it
+        is broadcast to every spatial dimension.
+      parallel_axes: Names and lengths for axes that should be processed in
+        parallel. These axes should appear in both the input and the output,
+        and the resulting convolution operator will apply a different operator
+        to each slice. (This is similar to a grouped convolution.) Must not
+        overlap with any axes named in ``input_axes`` or ``output_axes``.
+      parallel_broadcast_axes: Names and lengths for axes that should be
+        treated like ``parallel_axes`` but will only appear in the output. The
+        input will be implicitly broadcast over these axes. Must not overlap
+        with any axes named in ``input_axes``, ``output_axes`` or
+        ``parallel_axes``.
+      initializer: Function to use to initialize the kernel.
+      dtype: Dtype for the kernel.
+      rename_outputs_if_necessary: If True, and if ``output_axes`` and
+        ``input_axes`` have overlapping names, avoids name conflicts by adding
+        "primed" versions of the overlapping names, and returns an instance of
+        `ConvTransposeInPlace` instead of a ``ConvTranspose`` layer directly.
+
+    Returns:
+      A ``ConvTranspose`` layer with uninitialized kernel, or possibly a
+      `ConvTransposeInPlace` layer if ``rename_outputs_if_necessary`` is True
+      and ``input_axes`` overlaps with ``output_axes``.
+ """ + layer = super()._from_config( + inplace_class=ConvTransposeInPlace, + name=name, + init_base_rng=init_base_rng, + input_axes=input_axes, + output_axes=output_axes, + convolution_spatial_axes=convolution_spatial_axes, + strides=strides, + padding=padding, + kernel_dilation=kernel_dilation, + inputs_dilation=1, # not used for transposed convolutions + parallel_axes=parallel_axes, + parallel_broadcast_axes=parallel_broadcast_axes, + initializer=initializer, + dtype=dtype, + rename_outputs_if_necessary=rename_outputs_if_necessary, + ) + if isinstance(layer, AbstractGeneralConv): + return cast(ConvTranspose, layer) + + assert isinstance(layer, ConvTransposeInPlace) + return layer + + def _is_transposed(self): + return True + + def treescope_color(self) -> str: + return "#c7eb75" diff --git a/penzai/pz/nn.py b/penzai/pz/nn.py index fcd79fc..6069c86 100644 --- a/penzai/pz/nn.py +++ b/penzai/pz/nn.py @@ -74,6 +74,10 @@ Linear, LinearOperatorWeightInitializer, LinearInPlace, + Conv, + ConvInPlace, + ConvTranspose, + ConvTransposeInPlace, RenameAxes, contract, variance_scaling_initializer, diff --git a/tests/nn/linear_and_affine_test.py b/tests/nn/linear_and_affine_test.py index 3c62467..cc1487e 100644 --- a/tests/nn/linear_and_affine_test.py +++ b/tests/nn/linear_and_affine_test.py @@ -18,6 +18,7 @@ import chex import jax from penzai import pz +from penzai.toolshed import jit_wrapper class LinearAndAffineTest(absltest.TestCase): @@ -163,6 +164,254 @@ def test_affine(self): ), ) + def test_conv_shape(self): + layer = pz.nn.Conv.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3}, + output_axes={"foo": 5}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes={"baz": 7}, + parallel_broadcast_axes={"qux": 11}, + rename_outputs_if_necessary=True, + ) + result = layer( + pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}), + ) + pz.chk.check_structure( + result, + pz.chk.ArraySpec( + named_shape={ + "batch": 1, + "height": 10, + "width": 15, + "foo": 5, + "baz": 7, + "qux": 11, + } + ), + ) + + def test_strided_conv_shape(self): + layer = pz.nn.Conv.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3}, + output_axes={"foo": 5}, + convolution_spatial_axes={"height": 3, "width": 3}, + strides=(2, 2), + parallel_axes={"baz": 7}, + parallel_broadcast_axes={"qux": 11}, + rename_outputs_if_necessary=True, + ) + result = layer( + pz.nx.ones({"batch": 1, "height": 10, "width": 16, "foo": 3, "baz": 7}), + ) + pz.chk.check_structure( + result, + pz.chk.ArraySpec( + named_shape={ + "batch": 1, + "height": 5, + "width": 8, + "foo": 5, + "baz": 7, + "qux": 11, + } + ), + ) + + def test_conv_jit_wrapper(self): + layer = jit_wrapper.Jitted( + pz.nn.Conv.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3}, + output_axes={"foo": 5}, + convolution_spatial_axes={"height": 3, "width": 3}, + parallel_axes={"baz": 7}, + parallel_broadcast_axes={"qux": 11}, + rename_outputs_if_necessary=True, + ) + ) + layer( + pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}), + ) + + def test_conv_value(self): + inputs = jax.random.normal( + key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3 * 7) + ) + + pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3, 7)).tag( + "batch", "height", "width", "foo", "baz" + ) + + simple_layer = pz.nn.Conv.from_config( + name="test", + init_base_rng=jax.random.key(1), + input_axes={"foo": 3, "baz": 7}, + output_axes={"foo_out": 5, "baz_out": 
+        convolution_spatial_axes={"height": 3, "width": 3},
+        parallel_axes=None,
+        parallel_broadcast_axes=None,
+        rename_outputs_if_necessary=True,
+    )
+
+    pz_outputs = (
+        simple_layer(pz_inputs)
+        .untag("batch", "height", "width", "foo_out", "baz_out")
+        .reshape((1, 10, 15, 5 * 11))
+        .unwrap()
+    )
+
+    # Build the equivalent plain-jax convolution.
+    kernel = (
+        simple_layer.kernel.value.untag(
+            "height",
+            "width",
+            "foo",
+            "baz",
+            "foo_out",
+            "baz_out",
+        )
+        .reshape(3, 3, 3 * 7, 5 * 11)
+        .unwrap()
+    )
+    outputs = jax.lax.conv_general_dilated(
+        inputs,
+        kernel,
+        window_strides=(1, 1),
+        padding="SAME",
+        dimension_numbers=("NHWC", "HWIO", "NHWC"),
+    )
+
+    chex.assert_trees_all_equal(pz_outputs, outputs)
+
+  def test_conv_transpose_shape(self):
+    layer = pz.nn.ConvTranspose.from_config(
+        name="test",
+        init_base_rng=jax.random.key(1),
+        input_axes={"foo": 3},
+        output_axes={"foo": 5},
+        convolution_spatial_axes={"height": 3, "width": 3},
+        parallel_axes={"baz": 7},
+        parallel_broadcast_axes={"qux": 11},
+        rename_outputs_if_necessary=True,
+    )
+    result = layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}),
+    )
+    pz.chk.check_structure(
+        result,
+        pz.chk.ArraySpec(
+            named_shape={
+                "batch": 1,
+                "height": 10,
+                "width": 15,
+                "foo": 5,
+                "baz": 7,
+                "qux": 11,
+            }
+        ),
+    )
+
+  def test_strided_conv_transpose_shape(self):
+    layer = pz.nn.ConvTranspose.from_config(
+        name="test",
+        init_base_rng=jax.random.key(1),
+        input_axes={"foo": 3},
+        output_axes={"foo": 5},
+        convolution_spatial_axes={"height": 3, "width": 3},
+        strides=(2, 2),
+        parallel_axes={"baz": 7},
+        parallel_broadcast_axes={"qux": 11},
+        rename_outputs_if_necessary=True,
+    )
+    result = layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 16, "foo": 3, "baz": 7}),
+    )
+    pz.chk.check_structure(
+        result,
+        pz.chk.ArraySpec(
+            named_shape={
+                "batch": 1,
+                "height": 20,
+                "width": 32,
+                "foo": 5,
+                "baz": 7,
+                "qux": 11,
+            }
+        ),
+    )
+
+  def test_conv_transpose_value(self):
+    inputs = jax.random.normal(
+        key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3 * 7)
+    )
+
+    pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3, 7)).tag(
+        "batch", "height", "width", "foo", "baz"
+    )
+
+    simple_layer = pz.nn.ConvTranspose.from_config(
+        name="test",
+        init_base_rng=jax.random.key(1),
+        input_axes={"foo": 3, "baz": 7},
+        output_axes={"foo_out": 5, "baz_out": 11},
+        convolution_spatial_axes={"height": 3, "width": 3},
+        parallel_axes=None,
+        parallel_broadcast_axes=None,
+        rename_outputs_if_necessary=True,
+    )
+
+    pz_outputs = (
+        simple_layer(pz_inputs)
+        .untag("batch", "height", "width", "foo_out", "baz_out")
+        .reshape((1, 10, 15, 5 * 11))
+        .unwrap()
+    )
+
+    # Build the equivalent plain-jax transposed convolution.
+    kernel = (
+        simple_layer.kernel.value.untag(
+            "height",
+            "width",
+            "foo",
+            "baz",
+            "foo_out",
+            "baz_out",
+        )
+        .reshape(3, 3, 3 * 7, 5 * 11)
+        .unwrap()
+    )
+    outputs = jax.lax.conv_transpose(
+        inputs,
+        kernel,
+        padding="SAME",
+        strides=(1, 1),
+        dimension_numbers=("NHWC", "HWIO", "NHWC"),
+    )
+
+    chex.assert_trees_all_equal(pz_outputs, outputs)
+
+  def test_conv_transpose_jit_wrapper(self):
+    layer = jit_wrapper.Jitted(
+        pz.nn.ConvTranspose.from_config(
+            name="test",
+            init_base_rng=jax.random.key(1),
+            input_axes={"foo": 3},
+            output_axes={"foo": 5},
+            convolution_spatial_axes={"height": 3, "width": 3},
+            parallel_axes={"baz": 7},
+            parallel_broadcast_axes={"qux": 11},
+            rename_outputs_if_necessary=True,
+        )
+    )
+    layer(
+        pz.nx.ones({"batch": 1, "height": 10, "width": 15, "foo": 3, "baz": 7}),
+    )
+
def test_constant_rescale(self): layer = pz.nn.ConstantRescale(3.0) result = layer(pz.nx.ones({"foo": 3})) diff --git a/uv.lock b/uv.lock index 7e1588b..575abf5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11'", @@ -359,7 +360,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -862,7 +863,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1980,7 +1981,6 @@ name = "nvidia-nccl-cu12" version = "2.20.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 }, { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 }, ] @@ -1989,7 +1989,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.6.68" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, ] @@ -2211,6 +2210,7 @@ requires-dist = [ { name = "treescope", specifier = ">=0.1.9" }, { name = "typing-extensions", specifier = ">=4.2" }, ] +provides-extras = ["dev", "docs", "extras", "notebook"] [[package]] name = "pexpect" @@ -3438,19 +3438,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", 
marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -3491,7 +3491,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [