Skip to content

Commit

Permalink
[nnx] Arrays are state
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 27, 2024
1 parent b57e3aa commit 8718110
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 95 deletions.
8 changes: 5 additions & 3 deletions docs/experimental/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`,\n",
"As shown above dynamic state is usually stored in `nnx.Variable`s such as `nnx.Param`,\n",
"and static state (all types not handled by NNX) such as integers or strings \n",
"are stored directly. RNG keys can be requested from the `nnx.Rngs` object\n",
"by calling `rngs.<stream_name>()` where the stream name show match on of the names provided to the `Rngs` constructor (shown below).\n",
"are stored directly. JAX array and Numpy array attributes are also treated as dynamic state,\n",
"although storing them inside `nnx.Variable`s is preferred. Also, RNG keys can be requested from the \n",
"`nnx.Rngs` object by calling `rngs.<stream_name>()` where the stream name should match one of \n",
"the names provided to the `Rngs` constructor (shown below).\n",
"\n",
"To actually initialize a Module is very easy: simply call the constructor. All of the\n",
"parameters of a Module will be created right then and there, and are immediately available\n",
Expand Down
8 changes: 5 additions & 3 deletions docs/experimental/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ class Linear(nnx.Module):
return x @ self.w.value + self.b.value
```

As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`,
As shown above dynamic state is usually stored in `nnx.Variable`s such as `nnx.Param`,
and static state (all types not handled by NNX) such as integers or strings
are stored directly. RNG keys can be requested from the `nnx.Rngs` object
by calling `rngs.<stream_name>()` where the stream name show match on of the names provided to the `Rngs` constructor (shown below).
are stored directly. JAX array and Numpy array attributes are also treated as dynamic state,
although storing them inside `nnx.Variable`s is preferred. Also, RNG keys can be requested from the
`nnx.Rngs` object by calling `rngs.<stream_name>()` where the stream name should match one of
the names provided to the `Rngs` constructor (shown below).

To actually initialize a Module is very easy: simply call the constructor. All of the
parameters of a Module will be created right then and there, and are immediately available
Expand Down
80 changes: 46 additions & 34 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.experimental.nnx.nnx.state import State
from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf
from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable
from flax.typing import Path, PathParts

Expand All @@ -44,6 +44,7 @@

NODE_TYPES: dict[type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'] = {}


class _HashById(tp.Hashable, tp.Generic[A]):
"""A wrapper around a value that uses its id for hashing and equality.
This is used by RefMap to explicitly use object id as the hash for the keys.
Expand Down Expand Up @@ -122,7 +123,6 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
]



def register_graph_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]],
Expand Down Expand Up @@ -414,7 +414,7 @@ def graph_flatten(
x: Node,
) -> tuple[State, GraphDef[Node], tp.Mapping[tp.Any, Index]]:
ref_to_index = RefMap[tp.Any, Index]()
flat_state: dict[Path, Variable[tp.Any]] = {}
flat_state: dict[Path, StateLeaf] = {}
graphdef = _graph_flatten((), ref_to_index, flat_state, x)
assert not isinstance(graphdef, int)
return State.from_flat_path(flat_state), graphdef, ref_to_index
Expand All @@ -423,7 +423,7 @@ def graph_flatten(
def _graph_flatten(
path: PathParts,
ref_to_index: RefMap[tp.Any, Index],
flat_state: dict[Path, Variable[tp.Any]],
flat_state: dict[Path, StateLeaf],
node: Node,
) -> GraphDef[Node] | int:
if not is_node(node):
Expand Down Expand Up @@ -465,6 +465,9 @@ def _graph_flatten(
variables.append(
(key, VariableDef.from_variable(value, variable_index))
)
elif is_state_leaf(value):
str_path = '/'.join((*path, key))
flat_state[str_path] = value
else:
static_fields.append((key, value))

Expand Down Expand Up @@ -505,7 +508,7 @@ def graph_unflatten(

def _graph_unflatten(
graphdef: tp.Union[GraphDef[Node], int],
state: dict[str, Variable[tp.Any] | dict[str, tp.Any]],
state: dict[str, StateLeaf | dict[str, tp.Any]],
index_to_ref: dict[Index, tp.Any],
ref_cache: dict[Index, tp.Any] | None,
) -> Node:
Expand Down Expand Up @@ -536,30 +539,31 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(graphdef.type)

def _get_children():
new_state: dict[str, tp.Any] = {}
children: dict[str, StateLeaf | Node] = {}

for key in graphdef.attributes:
if key in graphdef.static_fields:
new_state[key] = graphdef.static_fields[key]
children[key] = graphdef.static_fields[key]
elif key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty types
if key in graphdef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = graphdef.subgraphs[key]
if isinstance(subgraphdef, int):
# subgraph exists, take it from the cache
new_state[key] = index_to_ref[subgraphdef]
children[key] = index_to_ref[subgraphdef]
else:
# create an empty node and add it to the cache
substate = {}
node = new_state[key] = _graph_unflatten(
node = children[key] = _graph_unflatten(
subgraphdef, substate, index_to_ref, ref_cache
)
elif key in graphdef.variables:
variable_def = graphdef.variables[key]
if isinstance(variable_def, int):
# variable exists, take it from the cache
new_state[key] = index_to_ref[variable_def]
children[key] = index_to_ref[variable_def]
else:
# create an empty variable and add it to the cache
if ref_cache is not None and variable_def.index in ref_cache:
Expand All @@ -574,30 +578,31 @@ def _get_children():
node.copy_from_def(variable_def, EMPTY)
else:
node = variable_def.to_variable(EMPTY)
new_state[key] = node
children[key] = node
index_to_ref[variable_def.index] = node
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
value = state[key]
if key in graphdef.subgraphs:
if isinstance(value, Variable):
if is_state_leaf(value):
raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable.'
)
assert isinstance(value, dict)
subgraphdef = graphdef.subgraphs[key]

if isinstance(subgraphdef, int):
node = index_to_ref[subgraphdef]
else:
node = new_state[key] = _graph_unflatten(
node = children[key] = _graph_unflatten(
subgraphdef, value, index_to_ref, ref_cache
)

elif key in graphdef.variables:
variable_def = graphdef.variables[key]
if isinstance(variable_def, int):
new_state[key] = index_to_ref[variable_def]
children[key] = index_to_ref[variable_def]
else:
if type(value) != variable_def.type:
raise ValueError(
Expand All @@ -616,13 +621,14 @@ def _get_children():
else:
assert isinstance(value, Variable)
variable = value.copy()
new_state[key] = variable
children[key] = variable
index_to_ref[variable_def.index] = variable

elif is_state_leaf(value):
children[key] = value
for new_key in set(state) - set(graphdef.attributes):
new_state[new_key] = state[new_key]
raise ValueError(f'Unknown key: {new_key!r}')

return new_state
return children

if isinstance(node_impl, GraphNodeImpl):
# we create an empty node first and add it to the index
Expand Down Expand Up @@ -682,7 +688,7 @@ def _graph_pop(
if is_node(value):
_graph_pop(value, id_to_index, (*path_parts, name), states, predicates)
continue
elif not isinstance(value, Variable):
elif not is_state_leaf(value):
continue
elif id(value) in id_to_index:
continue
Expand All @@ -695,9 +701,11 @@ def _graph_pop(
raise ValueError(
f'Cannot pop key {name!r} from node of type {type(node).__name__}'
)
state[path] = value.copy()
id_to_index[id(value)] = len(id_to_index)
node_impl.pop_key(node, name)
if isinstance(value, Variable):
value = value.copy()
state[path] = value
break
else:
# NOTE: should we raise an error here?
Expand All @@ -720,9 +728,7 @@ def graph_update_dynamic(
_graph_update_dynamic(node, state.raw_mapping)


def _graph_update_dynamic(
node: tp.Any, state: dict[str, Variable[tp.Any] | dict[str, tp.Any]]
):
def _graph_update_dynamic(node: tp.Any, state: dict[str, tp.Any]):
if not is_node(node):
raise RuntimeError(f'Unsupported type: {type(node)}')

Expand All @@ -746,23 +752,29 @@ def _graph_update_dynamic(

# case 2: subgraph is being updated
if is_node(current_value):
if isinstance(value, Variable):
raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}'
)
if is_state_leaf(value):
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
_graph_update_dynamic(current_value, value)
else:
# case 3: Variable is being updated
# assert isinstance(value, Variable)
# assert isinstance(current_value, Variable)
if not isinstance(value, Variable):
raise ValueError(f'Expected a Variable for attribute {key!r}')
elif isinstance(value, Variable):
# case 3: state leaf is being updated
if not isinstance(current_value, Variable):
raise ValueError(
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
f'{value!r}'
)
current_value.copy_from(value)
elif is_state_leaf(value):
# case 4: state field is being updated
if isinstance(node_impl, PytreeNodeImpl):
raise ValueError(
f'Cannot set key {key!r} on immutable node of '
f'type {type(node).__name__}'
)
node_impl.set_key(node, key, value)
else:
raise ValueError(
f'Unsupported update type: {type(value)} for key {key!r}'
)


class _StaticModuleStatus(enum.Enum):
Expand Down Expand Up @@ -813,7 +825,7 @@ def _graph_update_static(
updates_dict = node_impl.node_dict(updates)
for name, value_updates in updates_dict.items():
# case 1: trying to update a Variable, skip
if isinstance(value_updates, Variable):
if is_state_leaf(value_updates):
continue
elif is_node(value_updates):
# case 2: updating an existing subgraph
Expand Down
9 changes: 0 additions & 9 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import jax
import jax.tree_util as jtu
import numpy as np
import typing_extensions as tpe

from flax.experimental.nnx.nnx import (
Expand Down Expand Up @@ -140,14 +139,6 @@ def _setattr(self, name: str, value: tp.Any) -> None:
'Cannot mutate Module from different trace level'
)

if isinstance(value, (jax.Array, np.ndarray, State)):
raise ValueError(
f"Trying to assign a '{type(value).__name__}' to the Module"
f" attribute '{name}'. This is not supported. Non-hashable "
'objects are not valid static state in JAX. Please wrap '
'the value in a Variable type instead.'
)

object.__setattr__(self, name, value)

def __deepcopy__(self: M, memo=None) -> M:
Expand Down
Loading

0 comments on commit 8718110

Please sign in to comment.