Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs_nnx/flip/5310-tree-mode-nnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,25 @@ These new transforms are highly simplified compared to current transforms, they
```py
def transform_wrapper(*args):
if graph: args = to_tree(args)
check_no_aliases(args=args)
variables = check_no_aliases(args=args)

@jax_transform
def transformed_f(*args):
updates, snapshot = updates_and_snapshot(args)
current, prev = snapshot(labeled(args=args))
if graph: args = from_tree(args)
out = f(*args)
if graph: out = to_tree(out)
check_no_aliases(args=updates, out=out)
updates = mask_variable_updates(updates, snapshot)
check_no_aliases(**current, out=out)
updates = get_updates(current, prev)
return out, updates

out, updates = transformed_f(*args)
apply_variable_updates(args, updates)
apply_updates(variables, updates)
if graph: out = from_tree(out)
return out
```

The transformed function tracks input Variable `updates`, applies f, and masks Variable updates (no updates for Variables that didnt change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs, we simply convert objects to a tree representation before passing them to jax, and back to graphs before passing them to the user code.
The transformed function tracks input Variable, applies `f`, and creates the Variable updates (no updates for Variables that didn't change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs, we simply convert objects to a tree representation before passing them to jax, and back to graphs before passing them to the user code.

## Backward Compatibility

Expand Down
Loading
Loading