Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Arrays are state #3791

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/experimental/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`,\n",
"As shown above dynamic state is usually stored in `nnx.Variable`s such as `nnx.Param`,\n",
"and static state (all types not handled by NNX) such as integers or strings \n",
"are stored directly. RNG keys can be requested from the `nnx.Rngs` object\n",
"by calling `rngs.<stream_name>()` where the stream name show match on of the names provided to the `Rngs` constructor (shown below).\n",
"are stored directly. JAX array and Numpy array attributes are also treated as dynamic state,\n",
"although storing them inside `nnx.Variable`s is preferred. Also, RNG keys can be requested from the \n",
"`nnx.Rngs` object by calling `rngs.<stream_name>()` where the stream name show match on of \n",
"the names provided to the `Rngs` constructor (shown below).\n",
"\n",
"To actually initialize a Module is very easy: simply call the constructor. All of the\n",
"parameters of a Module will be created right then and there, and are immediately available\n",
Expand Down
8 changes: 5 additions & 3 deletions docs/experimental/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ class Linear(nnx.Module):
return x @ self.w.value + self.b.value
```

As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`,
As shown above dynamic state is usually stored in `nnx.Variable`s such as `nnx.Param`,
and static state (all types not handled by NNX) such as integers or strings
are stored directly. RNG keys can be requested from the `nnx.Rngs` object
by calling `rngs.<stream_name>()`, where the stream name should match one of the names provided to the `Rngs` constructor (shown below).
are stored directly. JAX array and NumPy array attributes are also treated as dynamic state,
although storing them inside `nnx.Variable`s is preferred. Also, RNG keys can be requested from the
`nnx.Rngs` object by calling `rngs.<stream_name>()`, where the stream name should match one of
the names provided to the `Rngs` constructor (shown below).

To actually initialize a Module is very easy: simply call the constructor. All of the
parameters of a Module will be created right then and there, and are immediately available
Expand Down
80 changes: 46 additions & 34 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.experimental.nnx.nnx.state import State
from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf
from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable
from flax.typing import Path, PathParts

Expand All @@ -44,6 +44,7 @@

NODE_TYPES: dict[type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'] = {}


class _HashById(tp.Hashable, tp.Generic[A]):
"""A wrapper around a value that uses its id for hashing and equality.
This is used by RefMap to explicitly use object id as the hash for the keys.
Expand Down Expand Up @@ -122,7 +123,6 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
]



def register_graph_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]],
Expand Down Expand Up @@ -414,7 +414,7 @@ def graph_flatten(
x: Node,
) -> tuple[State, GraphDef[Node], tp.Mapping[tp.Any, Index]]:
ref_to_index = RefMap[tp.Any, Index]()
flat_state: dict[Path, Variable[tp.Any]] = {}
flat_state: dict[Path, StateLeaf] = {}
graphdef = _graph_flatten((), ref_to_index, flat_state, x)
assert not isinstance(graphdef, int)
return State.from_flat_path(flat_state), graphdef, ref_to_index
Expand All @@ -423,7 +423,7 @@ def graph_flatten(
def _graph_flatten(
path: PathParts,
ref_to_index: RefMap[tp.Any, Index],
flat_state: dict[Path, Variable[tp.Any]],
flat_state: dict[Path, StateLeaf],
node: Node,
) -> GraphDef[Node] | int:
if not is_node(node):
Expand Down Expand Up @@ -465,6 +465,9 @@ def _graph_flatten(
variables.append(
(key, VariableDef.from_variable(value, variable_index))
)
elif is_state_leaf(value):
str_path = '/'.join((*path, key))
flat_state[str_path] = value
else:
static_fields.append((key, value))

Expand Down Expand Up @@ -505,7 +508,7 @@ def graph_unflatten(

def _graph_unflatten(
graphdef: tp.Union[GraphDef[Node], int],
state: dict[str, Variable[tp.Any] | dict[str, tp.Any]],
state: dict[str, StateLeaf | dict[str, tp.Any]],
index_to_ref: dict[Index, tp.Any],
ref_cache: dict[Index, tp.Any] | None,
) -> Node:
Expand Down Expand Up @@ -536,30 +539,31 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(graphdef.type)

def _get_children():
new_state: dict[str, tp.Any] = {}
children: dict[str, StateLeaf | Node] = {}

for key in graphdef.attributes:
if key in graphdef.static_fields:
new_state[key] = graphdef.static_fields[key]
children[key] = graphdef.static_fields[key]
elif key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty types
if key in graphdef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = graphdef.subgraphs[key]
if isinstance(subgraphdef, int):
# subgraph exists, take it from the cache
new_state[key] = index_to_ref[subgraphdef]
children[key] = index_to_ref[subgraphdef]
else:
# create an empty node and add it to the cache
substate = {}
node = new_state[key] = _graph_unflatten(
node = children[key] = _graph_unflatten(
subgraphdef, substate, index_to_ref, ref_cache
)
elif key in graphdef.variables:
variable_def = graphdef.variables[key]
if isinstance(variable_def, int):
# variable exists, take it from the cache
new_state[key] = index_to_ref[variable_def]
children[key] = index_to_ref[variable_def]
else:
# create an empty variable and add it to the cache
if ref_cache is not None and variable_def.index in ref_cache:
Expand All @@ -574,30 +578,31 @@ def _get_children():
node.copy_from_def(variable_def, EMPTY)
else:
node = variable_def.to_variable(EMPTY)
new_state[key] = node
children[key] = node
index_to_ref[variable_def.index] = node
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
value = state[key]
if key in graphdef.subgraphs:
if isinstance(value, Variable):
if is_state_leaf(value):
raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable.'
)
assert isinstance(value, dict)
subgraphdef = graphdef.subgraphs[key]

if isinstance(subgraphdef, int):
node = index_to_ref[subgraphdef]
else:
node = new_state[key] = _graph_unflatten(
node = children[key] = _graph_unflatten(
subgraphdef, value, index_to_ref, ref_cache
)

elif key in graphdef.variables:
variable_def = graphdef.variables[key]
if isinstance(variable_def, int):
new_state[key] = index_to_ref[variable_def]
children[key] = index_to_ref[variable_def]
else:
if type(value) != variable_def.type:
raise ValueError(
Expand All @@ -616,13 +621,14 @@ def _get_children():
else:
assert isinstance(value, Variable)
variable = value.copy()
new_state[key] = variable
children[key] = variable
index_to_ref[variable_def.index] = variable

elif is_state_leaf(value):
children[key] = value
for new_key in set(state) - set(graphdef.attributes):
new_state[new_key] = state[new_key]
raise ValueError(f'Unknown key: {new_key!r}')

return new_state
return children

if isinstance(node_impl, GraphNodeImpl):
# we create an empty node first and add it to the index
Expand Down Expand Up @@ -682,7 +688,7 @@ def _graph_pop(
if is_node(value):
_graph_pop(value, id_to_index, (*path_parts, name), states, predicates)
continue
elif not isinstance(value, Variable):
elif not is_state_leaf(value):
continue
elif id(value) in id_to_index:
continue
Expand All @@ -695,9 +701,11 @@ def _graph_pop(
raise ValueError(
f'Cannot pop key {name!r} from node of type {type(node).__name__}'
)
state[path] = value.copy()
id_to_index[id(value)] = len(id_to_index)
node_impl.pop_key(node, name)
if isinstance(value, Variable):
value = value.copy()
state[path] = value
break
else:
# NOTE: should we raise an error here?
Expand All @@ -720,9 +728,7 @@ def graph_update_dynamic(
_graph_update_dynamic(node, state.raw_mapping)


def _graph_update_dynamic(
node: tp.Any, state: dict[str, Variable[tp.Any] | dict[str, tp.Any]]
):
def _graph_update_dynamic(node: tp.Any, state: dict[str, tp.Any]):
if not is_node(node):
raise RuntimeError(f'Unsupported type: {type(node)}')

Expand All @@ -746,23 +752,29 @@ def _graph_update_dynamic(

# case 2: subgraph is being updated
if is_node(current_value):
if isinstance(value, Variable):
raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}'
)
if is_state_leaf(value):
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
_graph_update_dynamic(current_value, value)
else:
# case 3: Variable is being updated
# assert isinstance(value, Variable)
# assert isinstance(current_value, Variable)
if not isinstance(value, Variable):
raise ValueError(f'Expected a Variable for attribute {key!r}')
elif isinstance(value, Variable):
# case 3: state leaf is being updated
if not isinstance(current_value, Variable):
raise ValueError(
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
f'{value!r}'
)
current_value.copy_from(value)
elif is_state_leaf(value):
# case 4: state field is being updated
if isinstance(node_impl, PytreeNodeImpl):
raise ValueError(
f'Cannot set key {key!r} on immutable node of '
f'type {type(node).__name__}'
)
node_impl.set_key(node, key, value)
else:
raise ValueError(
f'Unsupported update type: {type(value)} for key {key!r}'
)


class _StaticModuleStatus(enum.Enum):
Expand Down Expand Up @@ -813,7 +825,7 @@ def _graph_update_static(
updates_dict = node_impl.node_dict(updates)
for name, value_updates in updates_dict.items():
# case 1: trying to update a Variable, skip
if isinstance(value_updates, Variable):
if is_state_leaf(value_updates):
continue
elif is_node(value_updates):
# case 2: updating an existing subgraph
Expand Down
9 changes: 0 additions & 9 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import jax
import jax.tree_util as jtu
import numpy as np
import typing_extensions as tpe

from flax.experimental.nnx.nnx import (
Expand Down Expand Up @@ -140,14 +139,6 @@ def _setattr(self, name: str, value: tp.Any) -> None:
'Cannot mutate Module from different trace level'
)

if isinstance(value, (jax.Array, np.ndarray, State)):
raise ValueError(
f"Trying to assign a '{type(value).__name__}' to the Module"
f" attribute '{name}'. This is not supported. Non-hashable "
'objects are not valid static state in JAX. Please wrap '
'the value in a Variable type instead.'
)

object.__setattr__(self, name, value)

def __deepcopy__(self: M, memo=None) -> M:
Expand Down
Loading
Loading