diff --git a/docs_nnx/flip/5310-tree-mode-nnx.md b/docs_nnx/flip/5310-tree-mode-nnx.md index f06e611de..8925a7b23 100644 --- a/docs_nnx/flip/5310-tree-mode-nnx.md +++ b/docs_nnx/flip/5310-tree-mode-nnx.md @@ -67,25 +67,25 @@ These new transforms are highly simplified compared to current transforms, they ```py def transform_wrapper(*args): if graph: args = to_tree(args) - check_no_aliases(args=args) + variables = check_no_aliases(args=args) @jax_transform def transformed_f(*args): - updates, snapshot = updates_and_snapshot(args) + current, prev = snapshot(labeled(args=args)) if graph: args = from_tree(args) out = f(*args) if graph: out = to_tree(out) - check_no_aliases(args=updates, out=out) - updates = mask_variable_updates(updates, snapshot) + check_no_aliases(**current, out=out) + updates = get_updates(current, prev) return out, updates out, updates = transformed_f(*args) - apply_variable_updates(args, updates) + apply_updates(variables, updates) if graph: out = from_tree(out) return out ``` -The transformed function tracks input Variable `updates`, applies f, and masks Variable updates (no updates for Variables that didn’t change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs, we simply convert objects to a tree representation before passing them to jax, and back to graphs before passing them to the user code. +The transformed function tracks the input Variables, applies `f`, and creates the Variable updates (no updates for Variables that didn't change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus Variable updates. 
The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs, we simply convert objects to a tree representation before passing them to jax, and back to graphs before passing them to the user code. ## Backward Compatibility diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index c464fcf30..cc816075b 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -16,23 +16,60 @@ from collections import namedtuple import dataclasses import functools +import inspect import typing as tp from flax import struct from flax import typing -from flax.nnx import filterlib, graphlib, variablelib +from flax.nnx import filterlib, graphlib, reprlib, variablelib +from flax.nnx.statelib import State from flax.nnx.pytreelib import Pytree from flax.typing import Missing, PathParts import jax A = tp.TypeVar('A') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable) Index = int KeyPath = tuple[tp.Hashable, ...] 
Prefix = tp.Any Leaf = tp.Any +class OrderedDict(reprlib.MappingReprMixin, dict[A, B]): + pass + + +def _ordered_dict_flatten_with_keys(d: OrderedDict): + children = [(jax.tree_util.DictKey(k), v) for k, v in d.items()] + return children, tuple(d.keys()) + +jax.tree_util.register_pytree_with_keys( + OrderedDict, + _ordered_dict_flatten_with_keys, + lambda keys, values: OrderedDict(zip(keys, values)), +) + +_labeled_tuples_cache: dict[tuple[str, ...], type[tp.Any]] = {} + + +def labeled(**kwargs): + keys = tuple(kwargs.keys()) + if keys not in _labeled_tuples_cache: + class LabeledTuple(namedtuple('LabeledTuple', keys)): + def keys(self): + return self._fields + + def __getitem__(self, key): + if isinstance(key, str): + return getattr(self, key) + return super().__getitem__(key) + + _labeled_tuples_cache[keys] = LabeledTuple + return _labeled_tuples_cache[keys](**kwargs) + + class PrefixMapping(abc.ABC): @abc.abstractmethod def map_prefix( @@ -237,10 +274,10 @@ def broadcast_prefix_map( class GraphDefState(struct.PyTreeNode): graphdef: graphlib.GraphDef[tp.Any] = struct.field(pytree_node=False) - state: graphlib.State = struct.field(pytree_node=True) + state: State = struct.field(pytree_node=True) S = tp.TypeVar( - 'S', bound=graphlib.State | graphlib.GraphFlatState | list[tp.Any] + 'S', bound=State | graphlib.GraphFlatState | list[tp.Any] ) class NodeStates(struct.PyTreeNode): @@ -457,7 +494,7 @@ def _to_node_states(leaf): return jax.tree.unflatten(treedef, leaves_out) -def from_tree2(tree: tp.Any, /) -> tp.Any: +def from_tree2(tree: tp.Any, /, recreate_variables: bool = True) -> tp.Any: index_ref = graphlib.IndexMap() def _from_node_states(x): @@ -466,6 +503,7 @@ def _from_node_states(x): state = graphlib._merge_to_flat_state((x.state,)) return graphlib.unflatten( x.graphdef, state, index_ref=index_ref, + recreate_variables=recreate_variables, ) return jax.tree.map( @@ -557,19 +595,15 @@ def clear_non_graph_nodes(tree): class Mask(tp.NamedTuple): pass -def 
mask_at(t: tuple, index: int | None) -> tuple: - if index is None: - return t - return tuple( +def mask_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]: + if index is None or not isinstance(t, tuple): + return None, t + x = t[index] + new_t = tuple( Mask() if i == index else x for i, x in enumerate(t) ) - -def replace_at(t: tuple, index: int, value: tp.Any) -> tuple: - return tuple( - value if i == index else x - for i, x in enumerate(t) - ) + return x, new_t def slice_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]: @@ -578,15 +612,17 @@ def slice_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]: return t[index], t[:index] + t[index + 1 :] -def insert_at(t: tuple, index: int | None, value: tp.Any) -> tuple: +def replace_at(t: tuple, index: int | None, value: tp.Any) -> tuple: if index is None: return t xs = list(t) - xs.insert(index, value) + xs[index] = value return tuple(xs) def find(t: tuple, value: tp.Any) -> int | None: + if not isinstance(t, tuple): + return None return next((i for i, x in enumerate(t) if x == value), None) @@ -602,69 +638,50 @@ def extract( tree: tp.Any, *, is_leaf: tp.Callable[[tp.Any], bool] | None = None, + prefix_leaf: tp.Callable[[tp.Any], bool] | None = None, ) -> tuple[tp.Any, list[tp.Any]]: extracted: list[tp.Any] = [] - def _leaf_fn(path: jax.tree_util.KeyPath, prefix_leaf: tp.Any, leaf: tp.Any): + def _leaf_fn(path, prefix_leaf, leaf): if f(path, prefix_leaf, leaf): idx = len(extracted) extracted.append(leaf) return ExtractIndex(idx) return leaf - full_prefix = jax.tree.broadcast(prefix, tree, is_leaf=is_leaf) - new_tree = jax.tree.map_with_path(_leaf_fn, full_prefix, tree, is_leaf=is_leaf) + new_tree = broadcast_prefix_map( + _leaf_fn, prefix, tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf + ) return new_tree, extracted -def insert( - tree: tp.Any, - extracted: list[tp.Any], - is_leaf: tp.Callable[[tp.Any], bool] | None = None, -) -> tp.Any: - if is_leaf is None: - _is_leaf = lambda x: isinstance(x, 
ExtractIndex) - else: - _is_leaf = lambda x: isinstance(x, ExtractIndex) or is_leaf(x) - - def _leaf_fn(leaf: tp.Any): +def insert(tree: tp.Any, extracted: list[tp.Any]) -> tp.Any: + def _leaf_fn(leaf): if isinstance(leaf, ExtractIndex): return extracted[leaf.index] return leaf - return jax.tree.map(_leaf_fn, tree, is_leaf=_is_leaf) - + return jax.tree.map( + _leaf_fn, tree, is_leaf=lambda x: isinstance(x, ExtractIndex) + ) -def updates_and_snapshot(args: A) -> tuple[A, A]: +def snapshot(args: A) -> tuple[A, A]: is_leaf = lambda x: isinstance(x, variablelib.Variable) - leaves, treedef = jax.tree.flatten(args, is_leaf=is_leaf) - updates_leaves: list[variablelib.Variable | Mask] = [] - snapshot_leaves: list[variablelib.Variable | Mask] = [] - for leaf in leaves: - if isinstance(leaf, variablelib.Variable): - updates_leaves.append(leaf) - # don't snapshot hijax or ref Variables as their updates are automatically - # masked out in mask_variable_updates. However, the leaf is kept in the - # updates to check for aliasing. This avoids a copy operation which has - # significance for ref Variables. 
- if leaf.hijax or leaf.ref: - snapshot_leaves.append(Mask()) - else: - snapshot_leaves.append(leaf.copy()) - else: - updates_leaves.append(Mask()) - snapshot_leaves.append(Mask()) - updates = jax.tree.unflatten(treedef, updates_leaves) - snapshot = jax.tree.unflatten(treedef, snapshot_leaves) - return updates, snapshot + current = jax.tree.map(lambda x: x, args, is_leaf=is_leaf) + snapshot = jax.tree.map(lambda x: x, args) + return current, snapshot +def copy_var_structure(tree: A) -> A: + return jax.tree.map( + lambda x: x, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) + ) def check_no_aliases( - fn_name: str, /, *, check_can_update: tp.Iterable[str] = (), **kwargs -): - Attrs = namedtuple('Attrs', kwargs.keys()) # type: ignore[misc] - container = Attrs(**kwargs) + fn_name: str, /, *, check: tp.Iterable[str] = (), **kwargs +) -> dict[jax.tree_util.KeyPath, variablelib.Variable]: + container = labeled(**kwargs) is_leaf = lambda x: isinstance(x, variablelib.Variable) seen: dict[int, jax.tree_util.KeyPath] = {} + all_variables: dict[jax.tree_util.KeyPath, variablelib.Variable] = {} for path, leaf in jax.tree.leaves_with_path(container, is_leaf=is_leaf): if not isinstance(leaf, variablelib.Variable): continue @@ -672,7 +689,7 @@ def check_no_aliases( assert isinstance(path[0], jax.tree_util.GetAttrKey) kwarg_name = path[0].name - if kwarg_name in check_can_update: + if kwarg_name in check: if not leaf._can_update: path_str = jax.tree_util.keystr(path) raise ValueError( @@ -703,6 +720,8 @@ def check_no_aliases( f' nnx.compat.{fn_name}(...)' ) seen[var_id] = path + all_variables[path] = leaf + return all_variables def check_prefix( @@ -711,8 +730,11 @@ def check_prefix( fn_name: str, graph: bool, graph_updates: bool, + none_leaf: bool = True, ): - def _check(path, leaf): + unique_prefixes: OrderedDict[tp.Any, tp.Any] = OrderedDict() + + def _check_prefix(path, leaf): if isinstance(leaf, variablelib.Variable): raise ValueError( f'Found Variable of type 
{type(leaf).__name__} ' @@ -766,14 +788,23 @@ def _check(path, leaf): raise ValueError(msg) jax.tree.map_with_path( - _check, + _check_prefix, prefix, - is_leaf=lambda x: isinstance(x, variablelib.Variable) + is_leaf=lambda x: x is None + or isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x) or isinstance(x, PrefixMapping) or isinstance(x, TreeState), ) + def _collect_prefix(_, leaf): + unique_prefixes[leaf] = leaf + + jax.tree.map_with_path( + _collect_prefix, prefix, is_leaf=lambda x: x is None and none_leaf + ) + return unique_prefixes + def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool: post_leaves, post_td = jax.tree.flatten(post) @@ -786,53 +817,160 @@ def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> b [PathParts, tp.Any, variablelib.Variable, variablelib.Variable], bool ] -def mask_variable_updates( + +class Updates( + tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]], + reprlib.Representable, +): + __slots__ = ('_keys', '_values') + + _keys: list[jax.tree_util.KeyPath] + _values: list[variablelib.Variable] + + def __init__( + self, + items: tp.Iterable[ + tuple[jax.tree_util.KeyPath, variablelib.Variable] + ] = (), + ): + self._keys, self._values = [], [] + for key, value in items: + self._keys.append(key) + self._values.append(value) + + def append(self, key: jax.tree_util.KeyPath, value: variablelib.Variable): + self._keys.append(key) + self._values.append(value) + + @property + def paths(self) -> list[jax.tree_util.KeyPath]: + return self._keys + + @property + def leaves(self) -> list[variablelib.Variable]: + return self._values + + @tp.overload + def __getitem__( + self, key: int + ) -> tuple[jax.tree_util.KeyPath, variablelib.Variable]: + ... + + @tp.overload + def __getitem__( + self, key: slice + ) -> tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]]: + ... 
+ + @tp.overload # type: ignore[override] + def __getitem__(self, key: tuple[tp.Hashable, ...]) -> variablelib.Variable: + ... + + def __getitem__( + self, key: int | slice | jax.tree_util.KeyPath + ): + if isinstance(key, int): + return self._keys[key], self._values[key] + elif isinstance(key, slice): + raise NotImplementedError('Slicing is not supported for Updates.') + idx = self._keys.index(key) + return self._values[idx] + + def __len__(self): + return len(self._keys) + + def __iter__(self): + return iter(zip(self._keys, self._values)) + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self), kv_sep=': ', start='({', end='})') + for path, value in self: + yield reprlib.Attr( + jax.tree_util.keystr(path), + value, + use_raw_key=True, + ) + + +def _updates_flatten_with_keys(x: Updates): + key_children = [ + (jax.tree_util.FlattenedIndexKey(i), v) + for i, v in enumerate(x._values) + ] + return key_children, x._keys + + +def _updates_flatten(x: Updates): + return x._values, x._keys + + +def _updates_unflatten(keys, values) -> Updates: + updates = object.__new__(Updates) + updates._keys = keys + updates._values = list(values) + return updates + + +jax.tree_util.register_pytree_with_keys( + Updates, + _updates_flatten_with_keys, + _updates_unflatten, + flatten_func=_updates_flatten, +) + +def get_updates( current_tree: A, snapshot_tree: A, *, - prefix: tp.Any = Missing, + prefix: tp.Any = None, keep_fn: KeepFn | None = None, -) -> A: + known_prefixes: tp.Iterable[tp.Any] = (None,), +): if keep_fn is None: keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap) + updates = OrderedDict((pfx, Updates()) for pfx in known_prefixes) + def _mask_updates(path, prefix_leaf, current, snapshot): if isinstance(current, variablelib.Variable): if current.hijax or current.ref: - return Mask() + return if keep_fn(path, prefix_leaf, current, snapshot): - return current - return Mask() + updates[prefix_leaf].append(path, current) + prefix_leaf = lambda x: x is 
None is_leaf = lambda x: isinstance(x, variablelib.Variable) - if prefix is Missing: - return jax.tree.map_with_path( - lambda path, cur, snap: _mask_updates(path, None, cur, snap), - current_tree, snapshot_tree, is_leaf=is_leaf - ) - return broadcast_prefix_map( + broadcast_prefix_map( _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf, ) + return updates -def apply_variable_updates(args_tree: A, updates_tree: A): - is_leaf = lambda x: isinstance(x, variablelib.Variable) or isinstance(x, Mask) - args_leaves = jax.tree.leaves(args_tree, is_leaf=is_leaf) - _, treedef = jax.tree.flatten(args_tree, is_leaf=is_leaf) - updates_leaves = treedef.flatten_up_to(updates_tree) - for variable, update in zip(args_leaves, updates_leaves, strict=True): - if isinstance(update, variablelib.Variable): - assert isinstance(variable, variablelib.Variable) - variable.update_from_state(update) +def apply_updates( + variables: dict[jax.tree_util.KeyPath, variablelib.Variable], + updates: OrderedDict[tp.Any, Updates], +): + for _, flat_state in updates.items(): + for path, update in flat_state: + if path in variables: + variable = variables[path] + assert isinstance(variable, variablelib.Variable) + variable.update_from_state(update) + else: + path_str = jax.tree_util.keystr(path) + raise RuntimeError( + f'Variable not found at path {path_str}. This is a bug in NNX, ' + f'please report it. 
Variable: {update}' + ) -def treemap_copy_args(f): +def treemap_copy_args(f: F) -> F: @functools.wraps(f) def wrapper(*args, **kwargs): args, kwargs = jax.tree.map(lambda x: x, (args, kwargs)) return f(*args, **kwargs) - return wrapper + return wrapper # type: ignore[return-value] def check_same_variables(inputs, outputs, transform_name: str = ''): @@ -971,3 +1109,31 @@ def _apply_prefix(jax_path, leaf): return prefix_fn(path, leaf) return jax.tree.map_with_path(_apply_prefix, node, is_leaf=is_leaf) + +def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]): + combined: OrderedDict[tp.Any, tp.Any] = OrderedDict() + for updates in all_updates.values(): + combined.update(updates) + return jax.tree.map_with_path( + lambda path, _: combined.get(path, None), tree, + is_leaf=lambda x: x is None + ) + +def filter_kwargs(f, **kwargs): + sig = inspect.signature(f) + has_var_keyword = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + if has_var_keyword: + return kwargs + named_params = { + name + for name, p in sig.parameters.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + filtered_kwargs = {k: v for k, v in kwargs.items() if k in named_params} + return filtered_kwargs diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 25c0e0ada..201b8026c 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -307,6 +307,7 @@ def dataclass( kw_only=kw_only, slots=slots, weakref_slot=weakref_slot, + repr=False, ) def _collect_stats( diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 8eac69412..222f9e439 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -25,6 +25,7 @@ graphlib, variablelib, ) +from flax.nnx.extract import labeled from flax.nnx.statelib import State import jax @@ -77,20 +78,14 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args, **kwargs): - 
updates, snapshot = extract.updates_and_snapshot((args, kwargs)) + current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases( - 'grad', - args=updates[0], - kwargs=updates[1], - out=out, - check_can_update=['out'], - ) - updates = extract.mask_variable_updates(updates, snapshot) + extract.check_no_aliases('grad', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) if self.has_aux: loss, aux = out @@ -174,7 +169,7 @@ def tree_grad_wrapper(*args, **kwargs): (args, kwargs), prefix=(args_prefix, False), ) - extract.check_no_aliases('grad', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('grad', args=args, kwargs=kwargs) fn_out = gradded_fn(*args, **kwargs) @@ -197,7 +192,7 @@ def tree_grad_wrapper(*args, **kwargs): if graph: grads = extract.from_tree2(grads) result = grads - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) return result return tree_grad_wrapper @@ -569,17 +564,15 @@ def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args - def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + def __call__(self, *primals): + current, snapshot = extract.snapshot(labeled(primals=primals)) if self.graph: - args = extract.from_tree2(args) - out = self.f(*args) + primals = extract.from_tree2(primals) + out = self.f(*primals) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases( - 'vjp', args=updates, out=out, check_can_update=['out'] - ) - updates = extract.mask_variable_updates(updates, snapshot) + extract.check_no_aliases('vjp', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) if self.has_aux: primals_out, aux = out return primals_out, (updates, aux) @@ -699,7 +692,7 
@@ def vjp( if graph: primals = extract.to_tree2(primals) - extract.check_no_aliases('vjp', primals=primals) + variables = extract.check_no_aliases('vjp', primals=primals) primals_out, vjp_fn, aux = jax.vjp( SimpleVjpFn(f_unbound, has_aux=has_aux, graph=graph), *primals, @@ -715,7 +708,7 @@ def vjp( raw_vjp_fn = vjp_fn def vjp_fn(g): return extract.from_tree2(raw_vjp_fn(g)) - extract.apply_variable_updates(primals, updates) + extract.apply_updates(variables, updates) if has_aux: return primals_out, vjp_fn, user_aux else: @@ -737,17 +730,15 @@ def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args - def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + def __call__(self, *primals): + current, snapshot = extract.snapshot(labeled(primals=primals)) if self.graph: - args = extract.from_tree2(args) - out = self.f(*args) + primals = extract.from_tree2(primals) + out = self.f(*primals) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases( - 'jvp', args=updates, out=out, check_can_update=['out'] - ) - updates = extract.mask_variable_updates(updates, snapshot) + extract.check_no_aliases('jvp', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) if self.has_aux: primals_out, aux = out return (primals_out, updates), aux @@ -876,7 +867,7 @@ def jvp( if graph: primals = extract.to_tree2(primals) tangents = extract.to_tree2(tangents) - extract.check_no_aliases('jvp', primals=primals) + variables = extract.check_no_aliases('jvp', primals=primals) extract.check_no_aliases('jvp', tangents=tangents) if has_aux: (primals_out, updates), (tangent_out, _updates_tangent), aux = jax.jvp( @@ -894,7 +885,7 @@ def jvp( if graph: primals_out = extract.from_tree2(primals_out) tangent_out = extract.from_tree2(tangent_out) - extract.apply_variable_updates(primals, updates) + extract.apply_updates(variables, updates) if has_aux: return primals_out, tangent_out, aux 
else: @@ -910,41 +901,20 @@ def jvp( class SimpleCustomVjpFn: f: tp.Callable[..., tp.Any] graph: bool - nondiff_argnums: tuple[int, ...] def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + current, snapshot = extract.snapshot(labeled(args=args)) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases( - 'custom_vjp', args=updates, out=out, check_can_update=['out'] - ) - diff_prefix = tuple( - i not in self.nondiff_argnums for i in range(len(args)) - ) - def keep_fn(path, diff_arg, cur, snap): - assert isinstance(diff_arg, bool) - changed = extract.variable_changed(cur, snap) - if diff_arg and changed: - raise ValueError( - f'Variables in differentiable argument were mutated inside ' - f'custom_vjp at {jax.tree_util.keystr(path)}.\n' - f'This is not supported when ' - f'graph_updates=False because the gradient for the Variable ' - f'updates would be silently dropped. Move the Variable mutation ' - f'to a non-differentiable argument, or use graph_updates=True.' 
- ) - return changed - updates = extract.mask_variable_updates( - updates, snapshot, prefix=diff_prefix, keep_fn=keep_fn, - ) + extract.check_no_aliases('custom_vjp', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) return out, updates @@ -958,18 +928,20 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + current, snapshot = extract.snapshot(labeled(args=args)) if self.graph: args = extract.from_tree2(args) out, residual = self.fwd(*args) if self.graph: out = extract.to_tree2(out) residual = extract.to_tree2(residual) - extract.check_no_aliases( - 'custom_vjp', args=updates, out=out, check_can_update=['out'] + extract.check_no_aliases('custom_vjp', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) + masked_args = jax.tree.map( + lambda _: None, current, + is_leaf=lambda x: isinstance(x, variablelib.Variable), ) - updates = extract.mask_variable_updates(updates, snapshot) - return (out, updates), residual + return (out, updates), (residual, masked_args) @dataclasses.dataclass(eq=False) @@ -982,12 +954,15 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args): - *nondiff, residual, (out_g, _updates_g) = args + *nondiff, (residual, masked_args), (out_g, updates_g) = args + updates_g = extract.to_masked(masked_args, updates_g).args if self.graph: nondiff = extract.from_tree2(nondiff) residual = extract.from_tree2(residual) out_g = extract.from_tree2(out_g) - result = self.bwd(*nondiff, residual, out_g) + updates_g = extract.from_tree2(updates_g, recreate_variables=False) + kwargs = extract.filter_kwargs(self.bwd, updates_g=updates_g) + result = self.bwd(*nondiff, residual, out_g, **kwargs) if self.graph: result = extract.to_tree2(result) return result @@ -1005,7 +980,7 @@ def __init__( self.nondiff_argnums = nondiff_argnums self.graph = graph self.custom_vjp_fn = jax.custom_vjp( - 
fun=SimpleCustomVjpFn(fun, graph=graph, nondiff_argnums=nondiff_argnums), + fun=SimpleCustomVjpFn(fun, graph=graph), nondiff_argnums=nondiff_argnums, ) @@ -1015,15 +990,13 @@ def __call__( args = resolve_kwargs(self.fun, args, kwargs) del kwargs if self.graph: - prefix = tuple( - i not in self.nondiff_argnums for i in range(len(args)) - ) - args = extract.to_tree2(args, prefix=prefix) - extract.check_no_aliases('custom_vjp', args=args) + extract.check_no_aliases('custom_vjp', args=args) + args = extract.to_tree2(args) + variables = extract.check_no_aliases('custom_vjp', args=args) (out, updates) = self.custom_vjp_fn(*args) if self.graph: out = extract.from_tree2(out) - extract.apply_variable_updates(args, updates) + extract.apply_updates(variables, updates) return out def defvjp( @@ -1493,10 +1466,7 @@ def custom_vjp( ``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and tangent types are the same as the input types, this means the tangent for a Module is a Module instance with gradient values set on its attributes. - This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally, - Variables in differentiable arguments cannot be mutated inside ``f``. If - mutations are needed, pass the relevant Variables through a non-differentiable - argument instead. + This mode does not support ``DiffState`` in ``nondiff_argnums``. Example:: @@ -1516,6 +1486,45 @@ def custom_vjp( ... >>> f.defvjp(f_fwd, f_bwd) + **updates_g** + + When Variables are mutated inside ``f`` or ``f_fwd``, NNX tracks these + updates and propagates their gradients. The ``bwd`` function can optionally + receive these gradients by declaring an ``updates_g`` keyword argument. + ``updates_g`` is a pytree with the same structure as the input ``args``, + where mutated Variables have their gradient values and all other leaves are + ``None``. + + Example:: + + >>> class Bar(nnx.Module): + ... def __init__(self, x, y, z): + ... self.x = nnx.Param(x) + ... 
self.y = nnx.Param(y) + ... self.z = nnx.BatchStat(z) + ... + >>> @nnx.custom_vjp(graph_updates=False) + ... def f(m: Bar): + ... m.z[...] *= 2.0 # mutation tracked as an update + ... return jnp.sin(m.x) * m.y + ... + >>> def f_fwd(m: Bar): + ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) + ... + >>> def f_bwd(res, g, *, updates_g): + ... cos_x, sin_x, m = res + ... # updates_g is a tuple matching args: (m_updates_g,) + ... # m_updates_g.z contains the gradient for the z mutation + ... m_g = nnx.clone(m) + ... m_g.x[...] = cos_x * g * m.y + ... m_g.y[...] = sin_x * g + ... return (m_g,) + ... + >>> f.defvjp(f_fwd, f_bwd) + + If ``bwd`` does not declare ``updates_g``, the update gradients are + silently discarded. + Args: fun: Callable base function. nondiff_argnums: Tuple of integers or DiffState objects specifying the @@ -1575,14 +1584,14 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args, **kwargs): - updates, snapshot = extract.updates_and_snapshot((args, kwargs)) + current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('remat', args=updates[0], kwargs=updates[1], out=out) - updates = extract.mask_variable_updates(updates, snapshot) + extract.check_no_aliases('remat', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) return out, updates @tp.overload @@ -1675,11 +1684,11 @@ def remat( def simple_remat_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2((args, kwargs)) - extract.check_no_aliases('remat', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('remat', args=args, kwargs=kwargs) out, updates = checkpointed_fn(*args, **kwargs) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) return out 
return simple_remat_wrapper # type: ignore[return-value] diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 0f0e0787f..52ec0205a 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -30,6 +30,7 @@ statelib, variablelib, ) +from flax.nnx.extract import labeled from flax.nnx.transforms.transforms import ( _resolve_bound_callable, _raise_bound_method_error, @@ -148,6 +149,12 @@ def __call__(self, *pure_args, **pure_kwargs): @tp.overload + + + + + + def jit( *, in_shardings: tp.Any = None, @@ -166,6 +173,12 @@ def jit( @tp.overload + + + + + + def jit( fun: tp.Callable[P, R], *, @@ -367,20 +380,24 @@ def jit( if was_bound: _raise_bound_method_error('jit') - if in_shardings is not None: - extract.check_prefix( + update_shardings = extract.check_prefix( in_shardings, 'in_shardings', 'jit', graph, graph_updates - ) - if out_shardings is not None: - extract.check_prefix( + ) + update_shardings[None] = None # kwargs sharding + extract.check_prefix( out_shardings, 'out_shardings', 'jit', graph, graph_updates - ) + ) wrapped_cls: tp.Any if graph and graph_updates: wrapped_cls = JitWrapped else: - wrapped_cls = functools.partial(SimpleJitWrapped, graph=graph) + wrapped_cls = functools.partial( + SimpleJitWrapped, + partial_args=(), + graph=graph, + update_shardings=update_shardings, + ) return wrapped_cls( fun_unbound, in_shardings=in_shardings, @@ -430,42 +447,40 @@ def _flatten_to_partial_state( @dataclasses.dataclass(eq=False) class SimpleJitFn: f: tp.Callable[..., tp.Any] + in_shardings: tp.Any out_shardings: tp.Any donate_argnums: frozenset[int] donate_argnames: frozenset[str] graph: bool + update_shardings: tuple[tp.Any, ...] 
def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): - updates, snapshot = extract.updates_and_snapshot((args, kwargs)) - args_updates, kwargs_updates = updates - args_snapshot, kwargs_snapshot = snapshot + current, snapshot = extract.snapshot( + labeled(args=args, kwargs=kwargs) + ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) - extract.check_no_aliases( - 'jit', - args=args_updates, - kwargs=kwargs_updates, - out=out, - check_can_update=['out'], + extract.check_no_aliases('jit', **current, out=out, check=['out']) + def keep_fn(jax_path, prefix, c, s): + if extract.variable_changed(c, s): + return True + arg_type, arg_key, *_ = graphlib.jax_to_nnx_path(jax_path) + if arg_type == 'args': + return arg_key in self.donate_argnums + else: # arg_type == 'kwargs': + return arg_key in self.donate_argnames + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_shardings, kwargs=None), + known_prefixes=self.update_shardings, keep_fn=keep_fn ) - def donated_arg(jax_path, prefix, c, s): - path = graphlib.jax_to_nnx_path(jax_path) - return path[0] in self.donate_argnums or extract.variable_changed(c, s) - args_updates = extract.mask_variable_updates( - args_updates, args_snapshot, keep_fn=donated_arg) - def donated_kwarg(jax_path, prefix, c, s): - path = graphlib.jax_to_nnx_path(jax_path) - return path[0] in self.donate_argnames or extract.variable_changed(c, s) - kwargs_updates = extract.mask_variable_updates( - kwargs_updates, kwargs_snapshot, keep_fn=donated_kwarg) - return out, (args_updates, kwargs_updates) + return out, updates class SimpleJitWrapped(tp.Generic[P, R]): @@ -475,16 +490,17 @@ def __init__( fun: tp.Callable[P, R], in_shardings: tp.Any, out_shardings: tp.Any, - static_argnums: int | tp.Sequence[int] | None = None, - 
static_argnames: str | tp.Iterable[str] | None = None, - donate_argnums: int | tp.Sequence[int] | None = None, - donate_argnames: str | tp.Iterable[str] | None = None, - keep_unused: bool = False, - device: tp.Optional[jax.Device] = None, - backend: tp.Optional[str] = None, - inline: bool = False, - partial_args: tuple[PartialState, ...] = (), - graph: bool = True, + static_argnums: int | tp.Sequence[int] | None, + static_argnames: str | tp.Iterable[str] | None, + donate_argnums: int | tp.Sequence[int] | None, + donate_argnames: str | tp.Iterable[str] | None, + keep_unused: bool, + device: tp.Optional[jax.Device], + backend: tp.Optional[str], + inline: bool, + partial_args: tuple[PartialState, ...], + graph: bool, + update_shardings: extract.OrderedDict, ): functools.update_wrapper(self, fun) self.fun: tp.Callable[P, R] = fun @@ -492,10 +508,8 @@ def __init__( self.partial_args = partial_args self.graph = graph - if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and ( - static_argnums or static_argnames - ): - resolved = _resolve_argnums(fun, static_argnums, static_argnames) + resolved = _resolve_argnums(fun, static_argnums, static_argnames) + if isinstance(in_shardings, (tuple, list)) and resolved: expanded = list(in_shardings) for i in sorted(resolved): expanded.insert(i, None) @@ -503,22 +517,6 @@ def __init__( else: self.in_shardings = in_shardings - jit_out_shardings: tp.Any - if in_shardings is not None or out_shardings is not None: - if isinstance(in_shardings, (tuple, list)) and ( - static_argnums or static_argnames - ): - resolved = _resolve_argnums(fun, static_argnums, static_argnames) - expanded = list(in_shardings) - for i in sorted(resolved): - expanded.insert(i, None) - out_in_shardings = tuple(expanded) - else: - out_in_shardings = in_shardings - jit_out_shardings = (out_shardings, (out_in_shardings, None)) - else: - jit_out_shardings = None - donate_argnums_set = frozenset( (donate_argnums,) if isinstance(donate_argnums, int) 
else donate_argnums or () @@ -528,9 +526,17 @@ def __init__( else donate_argnames or () ) self.jitted_fn = jax.jit( - SimpleJitFn(fun, out_shardings, donate_argnums_set, donate_argnames_set, graph), + SimpleJitFn( + fun, + self.in_shardings, + out_shardings, + donate_argnums_set, + donate_argnames_set, + graph, + tuple(update_shardings), + ), in_shardings=in_shardings, - out_shardings=jit_out_shardings, + out_shardings=(out_shardings, update_shardings), static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, @@ -560,10 +566,9 @@ def _maybe_from_tree(self, out): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: args = (*self.partial_args, *args) # type: ignore[assignment] args, kwargs = self._maybe_to_tree(args, kwargs) - if not self.graph: # skip check for graph mode - extract.check_no_aliases('jit', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('jit', args=args, kwargs=kwargs) out, updates = self.jitted_fn(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) return self._maybe_from_tree(out) def __get__(self, obj, objtype=None): @@ -599,6 +604,8 @@ def jit_partial( *partial_args: tp.Any, in_shardings: tp.Any = None, out_shardings: tp.Any = None, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, @@ -674,6 +681,10 @@ def jit_partial( raise ValueError( '`graph_updates` not supported by `jit_partial`' ) + update_shardings = extract.check_prefix( + in_shardings, 'in_shardings', 'jit_partial', graph, graph_updates + ) + update_shardings[None] = None # kwargs sharding if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)): raise ValueError( '`in_shardings` cannot contain `StateSharding` objects ' @@ -734,6 +745,8 @@ def 
_unflatten(arg): wrapped_fun, in_shardings=jit_in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, @@ -742,6 +755,7 @@ def _unflatten(arg): inline=inline, partial_args=flat_partial_args, graph=graph, + update_shardings=update_shardings, ) @@ -1164,10 +1178,9 @@ def call(*args, **kwargs): def __call__(self, *args, **kwargs): args = (*self.jit_wrapped.partial_args, *args) args, kwargs = self.jit_wrapped._maybe_to_tree(args, kwargs) - if not self.jit_wrapped.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('jit', args=args, kwargs=kwargs) out, updates = self.compiled(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) return self.jit_wrapped._maybe_from_tree(out) @property @@ -1265,23 +1278,26 @@ def lower( class SimpleShardMapFn: f: tp.Callable[..., tp.Any] graph: bool + in_specs: tp.Any out_specs: tp.Any + update_specs: tuple[tp.Any, ...] 
def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + current, snapshot = extract.snapshot(labeled(args=args)) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_specs) - extract.check_no_aliases( - 'shard_map', args=updates, out=out, check_can_update=['out'] + extract.check_no_aliases('shard_map', **current, out=out, check=['out']) + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_specs), + known_prefixes=self.update_specs ) - updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -1535,9 +1551,10 @@ def f(m, x): if was_bound: _raise_bound_method_error('shard_map') - extract.check_prefix( - in_specs, 'in_specs', 'shard_map', graph, graph_updates + update_specs = extract.check_prefix( + in_specs, 'in_specs', 'shard_map', graph, graph_updates, none_leaf=False ) + assert None not in update_specs extract.check_prefix( out_specs, 'out_specs', 'shard_map', graph, graph_updates ) @@ -1545,10 +1562,16 @@ def f(m, x): if not (graph and graph_updates): shard_map_fn = jax.shard_map( - SimpleShardMapFn(f_unbound, graph=graph, out_specs=out_specs), + SimpleShardMapFn( + f_unbound, + graph=graph, + in_specs=in_specs, + out_specs=out_specs, + update_specs=tuple(update_specs), + ), mesh=mesh, in_specs=in_specs, - out_specs=(out_specs, in_specs), + out_specs=(out_specs, update_specs), axis_names=axis_names, check_vma=check_vma, ) @@ -1561,9 +1584,9 @@ def shard_map_wrapper(*args, **kwargs): prefix=in_specs, check_aliasing=in_specs is not None, ) - extract.check_no_aliases('shard_map', args=args) + variables = extract.check_no_aliases('shard_map', args=args) out, updates = shard_map_fn(*args, **kwargs) - extract.apply_variable_updates(args, updates) + extract.apply_updates(variables, updates) if graph: out = 
extract.from_tree2(out) return out diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index e39e8a5a2..a13516600 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -23,6 +23,7 @@ from flax import typing from flax.core.frozen_dict import FrozenDict from flax.nnx import extract, filterlib, graphlib, spmd, variablelib +from flax.nnx.extract import labeled from flax.nnx import statelib from flax.nnx.module import Module from flax.nnx.statelib import State @@ -127,26 +128,26 @@ def transform_metadata( extract.check_prefix(out_axes, 'out_axes', 'transform_metadata', graph, True) @functools.wraps(f) - def wrapper(*in_args, **in_kwargs): - in_args = resolve_kwargs(f, in_args, in_kwargs) + def wrapper(*args, **kwargs): + args = resolve_kwargs(f, args, kwargs) if graph: - in_args = extract.to_tree2(in_args, prefix=in_axes) - extract.check_no_aliases('transform_metadata', args=in_args) - args = graphlib.clone(in_args, graph=graph) + args = extract.to_tree2(args, prefix=in_axes) + variables = extract.check_no_aliases('transform_metadata', args=args) + args = graphlib.clone(args, graph=graph) _apply_axis_fn(args, in_axes, metadata, spmd.remove_axis) - updates, snapshot = extract.updates_and_snapshot(args) + current, snapshot = extract.snapshot(labeled(args=args)) if graph: args = extract.from_tree2(args) out = f(*args) if graph: out = extract.to_tree2(out, prefix=out_axes) extract.check_no_aliases( - 'transform_metadata', args=updates, out=out, check_can_update=['out'] + 'transform_metadata', **current, out=out, check=['out'] ) _apply_axis_fn(args, in_axes, metadata, spmd.add_axis) _apply_axis_fn(out, out_axes, metadata, spmd.add_axis) - updates = extract.mask_variable_updates(updates, snapshot) - extract.apply_variable_updates(in_args, updates) + updates = extract.get_updates(current, snapshot) + extract.apply_updates(variables, updates) if graph: out = extract.from_tree2(out) return out @@ -269,27 +270,29 @@ 
def _vmap_split_fn(ctx: graphlib.SplitContext, path, prefix, x): class SimpleVmapFn: f: tp.Callable[..., tp.Any] graph: bool + in_axes: tp.Any out_axes: tp.Any + update_axes: tuple[tp.Any, ...] def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) - @extract.treemap_copy_args def __call__(self, *args, **kwargs): - updates, snapshot = extract.updates_and_snapshot((args, kwargs)) + current, snapshot = extract.snapshot( + labeled(args=args, kwargs=kwargs) + ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) extract.check_no_aliases( - 'vmap', - args=updates[0], - kwargs=updates[1], - out=out, - check_can_update=['out'], + 'vmap', **current, out=out, check=['out'], + ) + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_axes, kwargs=0), + known_prefixes=self.update_axes ) - updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -297,27 +300,30 @@ def __call__(self, *args, **kwargs): class SimplePmapFn: f: tp.Callable[..., tp.Any] graph: bool + in_axes: tp.Any out_axes: tp.Any + update_axes: tuple[tp.Any, ...] 
def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): - updates, snapshot = extract.updates_and_snapshot((args, kwargs)) + current, snapshot = extract.snapshot( + labeled(args=args, kwargs=kwargs) + ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) extract.check_no_aliases( - 'pmap', - args=updates[0], - kwargs=updates[1], - out=out, - check_can_update=['out'], + 'pmap', **current, out=out, check=['out'], + ) + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_axes, kwargs=0), + known_prefixes=self.update_axes ) - updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -510,15 +516,22 @@ def vmap( if was_bound: _raise_bound_method_error('vmap') - extract.check_prefix(in_axes, 'in_axes', 'vmap', graph, graph_updates) + update_axes = extract.check_prefix(in_axes, 'in_axes', 'vmap', graph, graph_updates) + update_axes[0] = 0 # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'vmap', graph, graph_updates) if not (graph and graph_updates): vmapped_fn = jax.vmap( - SimpleVmapFn(f_unbound, graph=graph, out_axes=out_axes), + SimpleVmapFn( + f_unbound, + graph=graph, + in_axes=in_axes, + out_axes=out_axes, + update_axes=tuple(update_axes), + ), in_axes=in_axes, - out_axes=(out_axes, (in_axes, 0)), + out_axes=(out_axes, update_axes), axis_name=axis_name, axis_size=axis_size, spmd_axis_name=spmd_axis_name, @@ -529,21 +542,17 @@ def simple_vmap_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2( (args, kwargs), - prefix=(in_axes, None) - if in_axes is not None - else None, - check_aliasing=in_axes is not None, + prefix=(in_axes, 0), ) - extract.check_no_aliases('vmap', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('vmap', args=args, kwargs=kwargs) out, updates = vmapped_fn(*args, 
**kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) if graph: out = extract.from_tree2(out) return out return simple_vmap_wrapper # type: ignore[return-value] - jax_in_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) @@ -761,16 +770,23 @@ def pmap( if was_bound: _raise_bound_method_error('pmap') - extract.check_prefix(in_axes, 'in_axes', 'pmap', graph, graph_updates) + update_axes = extract.check_prefix(in_axes, 'in_axes', 'pmap', graph, graph_updates) + update_axes[0] = 0 # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'pmap', graph, graph_updates) if not (graph and graph_updates): pmapped_fn = jax.pmap( - SimplePmapFn(f_unbound, graph=graph, out_axes=out_axes), + SimplePmapFn( + f_unbound, + graph=graph, + in_axes=in_axes, + out_axes=out_axes, + update_axes=tuple(update_axes), + ), axis_name=axis_name, in_axes=in_axes, - out_axes=(out_axes, (in_axes, 0)), + out_axes=(out_axes, update_axes), static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, @@ -783,14 +799,11 @@ def simple_pmap_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2( (args, kwargs), - prefix=(in_axes, None) - if in_axes is not None - else None, - check_aliasing=in_axes is not None, + prefix=(in_axes, 0), ) - extract.check_no_aliases('pmap', args=args, kwargs=kwargs) + variables = extract.check_no_aliases('pmap', args=args, kwargs=kwargs) out, updates = pmapped_fn(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_updates(variables, updates) if graph: out = extract.from_tree2(out) return out @@ -1394,61 +1407,50 @@ class SimpleScanFn: in_axes: tp.Any out_axes: tp.Any out_is_tuple: bool - carry_arg_index: int | None - carry_out_index: int | None + carry_idx: int | None + carry_out_idx: int | None + update_axes: tuple[tp.Any, ...] 
def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args - def __call__(self, full_carry: tp.Any, x_args: tp.Any): + def __call__(self, full_carry: tp.Any, args: tp.Any): carry, broadcasts = full_carry - updates, snapshot = extract.updates_and_snapshot(x_args) - x_args = extract.insert( - x_args, - broadcasts, - is_leaf=lambda x: isinstance(x, variablelib.Variable), - ) - + carry_in = extract.copy_var_structure(carry) + args = extract.insert(args, broadcasts) + current, snapshot = extract.snapshot(labeled(args=args)) if self.graph: - x_args = extract.from_tree2(x_args) + args = extract.from_tree2(args) carry = extract.from_tree2(carry) - - # Reconstruct full args - if self.carry_arg_index is not None: - args = extract.insert_at(x_args, self.carry_arg_index, carry) - else: - args = x_args + args = extract.replace_at(args, self.carry_idx, carry) out = self.f(*args) - if self.graph: - # check consistent aliasing, temporarily convert `out` to tree - # to check aliasing, but the real tree convertion is done later - check_out = extract.to_tree2(out, prefix=self.out_axes) - else: - check_out = out - - extract.check_no_aliases( - 'scan', args=updates, out=check_out, check_can_update=['out'] - ) - updates = extract.mask_variable_updates(updates, snapshot) - if self.carry_arg_index is not None: # has carry - if self.out_is_tuple: - carry_out, ys = extract.slice_at(out, self.carry_out_index) - else: - carry_out = out - ys = None - extract.check_same_variables(carry, carry_out, 'scan') - else: - ys = out + if self.carry_idx is None: # has carry carry_out = None + ys = out + elif self.out_is_tuple: + carry_out, ys = extract.mask_at(out, self.carry_out_idx) + else: + carry_out = out + ys = None if self.graph: - # convert the carry to tree separately to ensure a consistent - # graph structure for the carry in and carry out + ys = extract.to_tree2(ys, prefix=self.out_axes) carry_out = extract.to_tree2(carry_out) - ys = 
extract.to_tree2(ys) + + extract.check_same_variables(carry_in, carry_out, 'scan') + extract.check_no_aliases( + 'scan', **current, carry=carry_out, out=ys, check=['out'] + ) + updates = extract.get_updates( + current, snapshot, + prefix=labeled(args=self.in_axes), + known_prefixes=self.update_axes, + keep_fn=lambda _, prefix, cur, snap: isinstance(prefix, int) + and extract.variable_changed(cur, snap), + ) return (carry_out, broadcasts), (ys, updates) @@ -1624,8 +1626,8 @@ def forward(x, model): if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() - extract.check_prefix(in_axes, 'in_axes', 'scan', graph, graph_updates) extract.check_prefix(out_axes, 'out_axes', 'scan', graph, graph_updates) + updates_axes = extract.check_prefix(in_axes, 'in_axes', 'scan', graph, graph_updates) _check_out_axes(out_axes) if not graph or not graph_updates: @@ -1634,6 +1636,7 @@ def forward(x, model): in_axes=in_axes, out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, + updates_axes=updates_axes, ) return _graph_updates_scan( @@ -1659,108 +1662,84 @@ def _simple_scan( f, f_unbound, *, graph, in_axes, out_axes, length, reverse, unroll, _split_transpose, + updates_axes: extract.OrderedDict, ): _validate_scan_axes(in_axes, out_axes) + # None and Carry aren't valid update axes + updates_axes.pop(None, None) + updates_axes.pop(Carry, None) out_is_tuple = isinstance(out_axes, tuple) was_carry = in_axes is Carry if in_axes is Carry: in_axes = (Carry,) - if isinstance(in_axes, tuple): - carry_arg_index = extract.find(in_axes, Carry) - _, sliced_in_axes = extract.slice_at(in_axes, carry_arg_index) - else: - carry_arg_index = None - sliced_in_axes = in_axes - - if isinstance(out_axes, tuple): - carry_out_index = extract.find(out_axes, Carry) - _, sliced_out_axes = extract.slice_at(out_axes, carry_out_index) - else: - carry_out_index = None - sliced_out_axes = out_axes + carry_idx = extract.find(in_axes, Carry) + 
carry_out_idx = extract.find(out_axes, Carry) simple_scan_fn = SimpleScanFn( f_unbound, graph=graph, in_axes=in_axes, out_axes=out_axes, out_is_tuple=out_is_tuple, - carry_arg_index=carry_arg_index, - carry_out_index=carry_out_index, + carry_idx=carry_idx, + carry_out_idx=carry_out_idx, + update_axes=tuple(updates_axes), ) @functools.wraps(f) def simple_scan_wrapper(*args): - args = resolve_kwargs(f, args, {}) if was_carry and len(args) != 1: raise ValueError( 'When in_axes=Carry, the function must take exactly one argument, ' f'got {len(args)} arguments.' ) - if graph: - # check consistent aliasing, temporarily convert args to tree - # to check aliasing, but the real tree convertion is done later - check_args = extract.to_tree2(args, prefix=in_axes) - else: - check_args = args - - extract.check_no_aliases('scan', args=check_args) - carry, x_args = extract.slice_at(args, carry_arg_index) + if graph: # check consistent aliasing + extract.to_tree2(args, prefix=in_axes) + carry, args = extract.mask_at(args, carry_idx) if graph: - # convert the carry to tree separately to ensure a consistent - # graph structure for the carry in and carry out + args = extract.to_tree2(args, prefix=in_axes) carry = extract.to_tree2(carry) - x_args = extract.to_tree2(x_args) - - def extract_broadcasts(path, prefix_leaf, leaf): - return leaf is not None and ( - prefix_leaf is None - or ( - isinstance(prefix_leaf, variablelib.Variable) - and prefix_leaf.get_value() is None - ) - ) - x_args, broadcasts = extract.extract( - extract_broadcasts, sliced_in_axes, x_args, - is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable), + variables = extract.check_no_aliases('scan', args=args, carry=carry) + args, broadcasts = extract.extract( + lambda _, axes, x: axes is None, + in_axes, args, + is_leaf=lambda x: isinstance(x, variablelib.Variable), + prefix_leaf=lambda x: x is None, ) - - x_args_transposed = _move_axis( + args_t = _move_axis( lambda ax, leaf: jnp.moveaxis(leaf, ax, 0), - 
sliced_in_axes, x_args, + in_axes, args, ) - (carry_out, final_broadcasts), (ys, updates) = jax.lax.scan( simple_scan_fn, (carry, broadcasts), - x_args_transposed, + args_t, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, ) - ys, updates = _move_axis( lambda ax, leaf: jnp.moveaxis(leaf, 0, ax), - (sliced_out_axes, sliced_in_axes), + (out_axes, updates_axes), (ys, updates), ) - - extract.apply_variable_updates(x_args, updates) - extract.apply_variable_updates(broadcasts, final_broadcasts) carry = extract.update_carry_variables(carry, carry_out) + extract.apply_updates(variables, updates) + for broadcast, update in zip(broadcasts, final_broadcasts, strict=True): + if isinstance(broadcast, variablelib.Variable): + broadcast.update_from_state(update) if graph: ys = extract.from_tree2(ys) carry = extract.from_tree2(carry) - if carry_arg_index is not None: - if out_is_tuple: - out = extract.insert_at(ys, carry_out_index, carry) - else: - out = carry - else: + if carry_idx is None: out = ys + elif out_is_tuple: + out = extract.replace_at(ys, carry_out_idx, carry) + else: + out = carry return out @@ -1907,13 +1886,13 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, val): - val_variables, _ = extract.updates_and_snapshot(val) + val_in = extract.copy_var_structure(val) if self.graph: val = extract.from_tree2(val) out = self.f(val) if self.graph: out = extract.to_tree2(out) - extract.check_same_variables(val_variables, out, 'while_loop') + extract.check_same_variables(val_in, out, 'while_loop') return out @@ -2104,13 +2083,13 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, i, val): - val_variables, _ = extract.updates_and_snapshot(val) + val_in = extract.copy_var_structure(val) if self.graph: val = extract.from_tree2(val) out = self.f(i, val) if self.graph: out = extract.to_tree2(out) - extract.check_same_variables(val_variables, out, 'fori_loop') + extract.check_same_variables(val_in, out, 
'fori_loop') return out diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index c298e812c..e919b6b66 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -27,6 +27,7 @@ graphlib, variablelib, ) +from flax.nnx.extract import labeled from flax.nnx.module import Module from flax.nnx.proxy_caller import ( CallableProxy, @@ -363,14 +364,14 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) + current, snapshot = extract.snapshot(labeled(args=args)) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('checkify', args=updates, out=out) - updates = extract.mask_variable_updates(updates, snapshot) + extract.check_no_aliases('checkify', **current, out=out, check=['out']) + updates = extract.get_updates(current, snapshot) return out, updates def checkify( @@ -434,11 +435,11 @@ def checkify( def simple_checkify_wrapper(*args): if graph: args = extract.to_tree2(args) - extract.check_no_aliases('checkify', args=args) + variables = extract.check_no_aliases('checkify', args=args) error, (out, updates) = checkify_fn(*args) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(args, updates) + extract.apply_updates(variables, updates) return error, out return simple_checkify_wrapper # type: ignore @@ -570,14 +571,17 @@ def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args - def __call__(self, *args): - updates, _snapshot = extract.updates_and_snapshot(args) + def __call__(self, *operands): + current, snapshot = extract.snapshot(labeled(operands=operands)) if self.graph: - args = extract.from_tree2(args) - out = self.f(*args) + operands = extract.from_tree2(operands) + out = self.f(*operands) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('switch', 
args=updates, out=out) + extract.check_no_aliases('switch', **current, out=out, check=['out']) + updates = extract.get_updates( + current, snapshot, keep_fn=lambda *_: True + ) return out, updates @@ -615,7 +619,7 @@ def cond( if not graph or not graph_updates: if graph: operands = extract.to_tree2(operands) - extract.check_no_aliases('cond', operands=operands) + variables = extract.check_no_aliases('cond', operands=operands) out, updates = jax.lax.cond( pred, SimpleCondFn(true_fun, graph=graph), @@ -624,7 +628,7 @@ def cond( ) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(operands, updates) + extract.apply_updates(variables, updates) return out @general.split_inputs(ctxtag='cond') @@ -670,7 +674,7 @@ def switch( if not graph or not graph_updates: if graph: operands = extract.to_tree2(operands) - extract.check_no_aliases('switch', operands=operands) + variables = extract.check_no_aliases('switch', operands=operands) out, updates = jax.lax.switch( index, [SimpleCondFn(f, graph=graph) for f in branches], @@ -678,7 +682,7 @@ def switch( ) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(operands, updates) + extract.apply_updates(variables, updates) return out @general.split_inputs(ctxtag='switch') diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index f1f5867e8..a1795caf0 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1220,7 +1220,7 @@ def f(a, b): return a[...] + b[...] 
mesh = jax.sharding.Mesh(jax.devices(), ('x',)) - with mesh: + with jax.set_mesh(mesh): with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) @@ -2309,16 +2309,23 @@ def f_bwd(res, g): self.assertEqual(m.z, 1) @parameterized.parameters(True, False) + @nnx.set_graph_updates(False) def test_jax_example_functional(self, graph): - @dataclasses.dataclass + @nnx.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] - z: int + z: nnx.Variable[jax.Array] - @nnx.custom_vjp(graph=graph, graph_updates=False) + m = Foo( + nnx.Param(jnp.array(1.0)), + nnx.Param(jnp.array(2.0)), + nnx.Variable(jnp.array(0)), + ) + + @nnx.custom_vjp(graph=graph) def f(m: Foo): - m.z += 1 + m.z[...] += 1 return jnp.sin(m.x) * m.y # type: ignore def f_fwd(m: Foo): @@ -2326,23 +2333,35 @@ def f_fwd(m: Foo): res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore return y, res - def f_bwd(res, g): + def f_bwd(res, g, updates_g): + (m_up_g,) = updates_g + self.assertIsInstance(updates_g, tuple) + self.assertLen(jax.tree.leaves(m_up_g), 1) + self.assertIsNone(m_up_g.x) + self.assertIsNone(m_up_g.y) + self.assertIsInstance(m_up_g.z, nnx.Variable) + cos_x, sin_x, m = res - out_g = g m_g = nnx.clone(m) - m_g.x[...] = cos_x * out_g * m.y - m_g.y[...] = sin_x * out_g + m_g.x[...] = cos_x * g * m.y + m_g.y[...] = sin_x * g return (m_g,) f.defvjp(f_fwd, f_bwd) - m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + params, nondiff = nnx.unpack(m, nnx.Param, ...) 
- grads = nnx.grad(f, graph=graph, graph_updates=False)(m) - self.assertIsInstance(grads, Foo) + @nnx.jit(graph=graph) + @nnx.grad(graph=graph) + def grad_fn(params, nondiff): + m = nnx.merge(params, nondiff) + return f(m) + + grads = grad_fn(params, nondiff) + self.assertIsInstance(grads, nnx.GraphState) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore - self.assertEqual(m.z, 0) + self.assertEqual(m.z, 1) def test_diff_state(self): @dataclasses.dataclass @@ -2998,10 +3017,11 @@ def f_bwd(v_nondiff, res, g): f.defvjp(f_fwd, f_bwd) - with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): + with self.assertRaisesRegex(ValueError, 'Duplicate'): f(v, v) - def test_custom_vjp_diff_arg_mutation_error(self): + def test_custom_vjp_diff_arg_mutation(self): + n = 0 @nnx.custom_vjp(graph=True, graph_updates=False) def f(m): m.x[...] += 1 @@ -3010,7 +3030,12 @@ def f(m): def f_fwd(m): return f(m), (m,) - def f_bwd(res, g): + def f_bwd(res, g, updates_g): + nonlocal n + n += 1 + (m_up_g,) = updates_g + self.assertIsInstance(m_up_g.x, nnx.Param) + self.assertIsNone(m_up_g.y) (m,) = res m_g = nnx.clone(m) m_g.x[...] = g * m.y[...] 
@@ -3025,10 +3050,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0))) - with self.assertRaisesRegex( - ValueError, 'Variables in differentiable argument' - ): - f(m) + + g = nnx.grad(f)(m) + self.assertEqual(n, 1) class TestVjpJvp(parameterized.TestCase): @@ -3361,8 +3385,8 @@ def test_nested_carry_rejected(self): )({'a': jnp.array(1.0)}) @parameterized.parameters(True, False) - def test_broadcast_out_axes_rejected(self, graph): - with self.assertRaises(ValueError): + def test_broadcast_out_axes_rejected1(self, graph): + with self.assertRaisesRegex(ValueError, 'Cannot broadcast output state'): nnx.scan( lambda c, x: (c, x), in_axes=(nnx.Carry, 0), @@ -3408,7 +3432,7 @@ def test_nested_carry_in_out_axes_rejected(self): )(jnp.array(0.0), jnp.arange(3.0)) def test_carry_in_in_axes_only_rejected(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, 'If one of in_axes or out_axes has Carry'): nnx.scan( lambda c, x: (c + x,), in_axes=(nnx.Carry, 0), @@ -3417,7 +3441,7 @@ def test_carry_in_in_axes_only_rejected(self): )(jnp.array(0.0), jnp.arange(3.0)) def test_carry_in_out_axes_only_rejected(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, 'If one of in_axes or out_axes has Carry'): nnx.scan( lambda x: x, in_axes=(0,), @@ -3427,7 +3451,7 @@ def test_carry_in_out_axes_only_rejected(self): def test_non_tuple_carry_only(self): def f(carry): - return carry + 1.0 + return carry + 1 result = nnx.scan( f, @@ -3435,8 +3459,8 @@ def f(carry): out_axes=nnx.Carry, length=5, graph=False, - )(jnp.array(0.0)) - np.testing.assert_allclose(result, 5.0) + )(jnp.array(0)) + self.assertEqual(result, 5) def test_non_tuple_scan_only(self): def f(x): @@ -3508,6 +3532,30 @@ def stack_forward(params, x): assert y.shape == (5, 1, 3) assert count[...] 
== 5 + @parameterized.parameters(True, False) + def test_variables_broadcast_in_scan(self, graph): + w = nnx.Param(jax.random.normal(jax.random.key(0), (3, 3))) + b = nnx.Param(jnp.zeros((3,))) + count = nnx.BatchStat(0) + + def block_forward(w, b, x): + return nnx.gelu(x @ w + b[None]) + + @nnx.scan( + in_axes=(None, None, None, 0), out_axes=0, + graph=graph, graph_updates=False + ) + def stack_forward(w, b, count, x): + y = block_forward(w, b, x) + count[...] += 1 + return y + + x = jnp.ones((5, 1, 3)) + y = stack_forward(w, b, count, x) + + assert y.shape == (5, 1, 3) + assert count[...] == 5 + def test_basic_no_carry(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): @@ -3763,6 +3811,25 @@ def loop(foo: Foo) -> tuple[Foo, jax.Array]: foo = Foo() foo2, cs = loop(foo) self.assertIs(foo2.c, foo.c) + self.assertEqual(foo.c[...], 5) + np.testing.assert_allclose(cs, jnp.arange(1, 6)) + + @parameterized.parameters(True, False) + def test_only_carry_functional(self, graph): + class Foo(nnx.Module): + def __init__(self): + self.c = nnx.BatchStat(jnp.array(0)) + + @nnx.scan( + in_axes=None, out_axes=0, length=5, graph=graph, graph_updates=False + ) + def loop(foo: Foo) -> tuple[Foo, jax.Array]: + foo.c[...] += 1 + return foo.c[...] + + foo = Foo() + cs = loop(foo) + self.assertEqual(foo.c[...], 5) np.testing.assert_allclose(cs, jnp.arange(1, 6)) def test_out_axes(self): @@ -4856,7 +4923,7 @@ def _step2(self, state: tuple[CarryAsPytree, jax.Array, CarryAsPytree]): state[2].data = new_data2 return (state[0], out, state[2]) - @nnx.jit(static_argnames=("method"), graph=graph, graph_updates=False) + @nnx.jit(static_argnames=("method",), graph=graph, graph_updates=False) def __call__(self, state, method): state_axes = nnx.prefix( self, {nnx.Intermediate: 0, ...: None}, graph=graph