Add deprecation warning to all nnx.State methods #4561

Merged · 2 commits · Feb 22, 2025
32 changes: 16 additions & 16 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs_nnx/guides/checkpointing.md
@@ -133,15 +133,15 @@ When interacting with checkpoint libraries (like Orbax), you may prefer to work

```{code-cell} ipython3
# Save as pure dict
pure_dict_state = state.to_pure_dict()
pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)

# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
```
@@ -181,7 +181,7 @@ restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!

6 changes: 3 additions & 3 deletions docs_nnx/guides/flax_gspmd.ipynb
@@ -415,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -501,8 +501,8 @@
")\n",
"loaded_sharded = checkpointer.restore(path / 'checkpoint_name',\n",
" target=abs_state)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.w2.value)"
"jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)"
]
},
{
4 changes: 2 additions & 2 deletions docs_nnx/guides/flax_gspmd.md
@@ -235,8 +235,8 @@ abs_state = jax.tree.map(
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)
jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)
```

## Compile the training loop
10 changes: 5 additions & 5 deletions docs_nnx/guides/haiku_to_flax.rst
@@ -201,7 +201,7 @@ The dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
nnx.update(model, nnx.merge_state(params, rest))

.. testcode:: Haiku
:hide:
@@ -378,7 +378,7 @@ The parameter structure is as follows:
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -387,7 +387,7 @@ The parameter structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
}


To call those custom methods:
@@ -634,14 +634,14 @@ Now inspect the variable pytree on both sides:
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
}


Top-level Haiku functions vs top-level Flax modules
12 changes: 6 additions & 6 deletions docs_nnx/guides/linen_to_nnx.rst
@@ -184,7 +184,7 @@ Dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
nnx.update(model, nnx.merge_state(params, rest))

.. testcode:: Linen
:hide:
@@ -389,7 +389,7 @@ The variable structure is as follows:

# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -398,7 +398,7 @@ The variable structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
}

To call methods other than ``__call__``:

@@ -531,7 +531,7 @@ Scan-over-layers is a technique where you run an input through a sequence of N r
* Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
* In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap<flax.nnx.vmap>` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan<flax.nnx.scan>` transform to run the model input through them.

For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__.
For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.build/en/guides/transforms.html>`__.
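The `codediff` comparison that follows is collapsed in this view. Below is a rough sketch of the NNX pattern the bullets above describe, with `nnx.vmap` building the stacked `Block` parameters and `nnx.scan` running the input through them; the module, sizes, and exact transform arguments are assumptions to be checked against the Transforms guide, not code taken from this diff:

```python
from flax import nnx
import jax
import jax.numpy as jnp

class Block(nnx.Module):
  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return jax.nn.relu(self.linear(x))

num_layers, dim = 5, 64

# Initialization: "vmap" the constructor so every Block parameter gains a
# leading num_layers axis, one slice per layer.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_blocks(rngs: nnx.Rngs):
  return Block(dim, rngs=rngs)

blocks = create_blocks(nnx.Rngs(0))

# Running: scan the input (the carry) through the stacked layers.
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, block):
  return block(x)

y = forward(jnp.ones((3, dim)), blocks)
```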

.. codediff::
:title: Linen, NNX
@@ -644,14 +644,14 @@ Now inspect the variable pytree on both sides:

# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
}


Using ``TrainState`` in Flax NNX
10 changes: 5 additions & 5 deletions docs_nnx/guides/surgery.ipynb
@@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -95,8 +95,8 @@
"# Variable sharing (weight-tying).\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n",
"assert hasattr(nnx.state(model), 'linear2')\n",
"assert hasattr(nnx.state(model)['linear2'], 'bias')\n",
"assert 'linear2' in nnx.state(model)\n",
"assert 'bias' in nnx.state(model)['linear2']\n",
"assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n",
"\n",
"# Monkey-patching.\n",
@@ -256,7 +256,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -301,7 +301,7 @@
"# Fit it into the model state.\n",
"abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graph_def, state = nnx.split(abs_model)\n",
"state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
"nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))\n",
"restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
6 changes: 3 additions & 3 deletions docs_nnx/guides/surgery.md
@@ -78,8 +78,8 @@ np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))
# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert 'linear2' in nnx.state(model)
assert 'bias' in nnx.state(model)['linear2']
assert not hasattr(nnx.state(model)['linear2'], 'kernel')

# Monkey-patching.
@@ -172,7 +172,7 @@ raw_dict['layer2'] = raw_dict.pop('linear2')
# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
68 changes: 32 additions & 36 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs_nnx/nnx_basics.md
@@ -14,10 +14,6 @@ Flax NNX is a new simplified API that is designed to make it easier to create, i

To begin, install Flax with `pip` and import necessary dependencies:

## Setup

Install Flax with `pip` and impost necessary dependencies:

```{code-cell} ipython3
:tags: [skip-execution]

@@ -95,7 +91,7 @@ to handle them, as demonstrated in later sections of this guide.

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.
The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:

```{code-cell} ipython3
class MLP(nnx.Module):
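The `nnx_basics.md` hunk above is truncated right after `class MLP(nnx.Module):`. A minimal sketch of the kind of module the updated paragraph describes, two `nnx.Linear` layers plus `nnx.Dropout` and `nnx.BatchNorm`; the layer sizes and dropout rate here are illustrative assumptions rather than values from the notebook:

```python
from flax import nnx
import jax.numpy as jnp

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(jnp.ones((3, 2)))  # shape (3, 5)
```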
4 changes: 2 additions & 2 deletions examples/gemma/helpers.py
@@ -79,7 +79,7 @@ def assign_val_fn(

mdl: M = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = dict(state.flat_state())
state = dict(nnx.to_flat_state(state))
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
@@ -88,6 +88,6 @@ def assign_val_fn(
f' exist (original path={path}).'
)
state = assign_val_fn(state, mapped_path, val)
state = nnx.State.from_flat_path(state)
state = nnx.from_flat_state(state)

return nnx.merge(graph_def, state)
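For context, the helper above switches from `state.flat_state()` / `nnx.State.from_flat_path` to the new `nnx.to_flat_state` / `nnx.from_flat_state` functions. A small sketch of that round trip in isolation; the toy model and the exact path keys are assumptions for illustration:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# Flatten to a path -> VariableState mapping, e.g. {('kernel',): ..., ('bias',): ...}.
flat = dict(nnx.to_flat_state(state))
# ... perform surgery on individual entries keyed by path tuples ...
state = nnx.from_flat_state(flat)  # rebuild the nested state

model = nnx.merge(graphdef, state)
```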
2 changes: 1 addition & 1 deletion examples/gemma/sampler.py
@@ -136,7 +136,7 @@ def __init__(
@property
def dtype(self) -> jnp.dtype:
params_state = nnx.state(self.transformer, nnx.Param)
return jax.tree_util.tree_leaves(params_state.flat_state())[0].dtype
return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype

def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
"""Performs a single sampling step."""
4 changes: 2 additions & 2 deletions examples/lm1b_nnx/models_test.py
@@ -79,7 +79,7 @@ def transfer_params(
params_linen: dict[str, Any],
):
rules = dataclasses.asdict(config.axis_rules)
flat_params_nnx = dict(params_nnx.flat_state())
flat_params_nnx = dict(nnx.to_flat_state(params_nnx))
flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/')

def apply_rules(names: tuple[str, ...]):
@@ -163,7 +163,7 @@ def transfer_cache(
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
flat_cache_nnx = dict(cache_nnx.flat_state())
flat_cache_nnx = dict(nnx.to_flat_state(cache_nnx))
flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/')

def copy_var(nnx_name: str, linen_name: str):
6 changes: 3 additions & 3 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -84,7 +84,7 @@ def init_optimizer_state(variable: nnx.Variable):

self.lr = lr
self.params = params
self.momentum = jax.tree.map(init_optimizer_state, self.params)
self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
self.decay = decay

def update(self, grads: nnx.State):
@@ -117,7 +117,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
else:
raise ValueError(f'Unknown path: {path}')

named_shardings = state.map(get_named_shardings)
named_shardings = nnx.map_state(get_named_shardings, state)
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
nnx.update(optimizer, sharded_state)
return model, optimizer
@@ -126,7 +126,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
model, optimizer = create_model()

jax.debug.visualize_array_sharding(model.w1.value)
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
jax.debug.visualize_array_sharding(optimizer.momentum['w1'].value)


@nnx.jit
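`nnx.map_state(fn, state)` above replaces the deprecated `state.map(fn)`; the mapped function receives `(path, value)` pairs, just like `get_named_shardings`. A tiny sketch of the calling convention, using a purely illustrative inspection function:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state = nnx.state(model, nnx.Param)

def show(path, value):
  # Print each parameter's path and shape; returning the value unchanged
  # leaves the resulting State identical to the input.
  print(path, value.value.shape)
  return value

_ = nnx.map_state(show, state)
```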
8 changes: 8 additions & 0 deletions flax/nnx/__init__.py
@@ -121,6 +121,14 @@
from .spmd import with_partitioning as with_partitioning
from .spmd import with_sharding_constraint as with_sharding_constraint
from .statelib import State as State
from .statelib import to_flat_state as to_flat_state
from .statelib import from_flat_state as from_flat_state
from .statelib import to_pure_dict as to_pure_dict
from .statelib import replace_by_pure_dict as replace_by_pure_dict
from .statelib import filter_state as filter_state
from .statelib import merge_state as merge_state
from .statelib import split_state as split_state
from .statelib import map_state as map_state
from .training import metrics as metrics
from .variablelib import Param as Param
# this needs to be imported before optimizer to prevent circular import
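These `statelib` exports are the new module-level replacements for the deprecated `nnx.State` methods. A sketch of the old-to-new mapping as it appears across this diff; the `filter_state` and `split_state` rows are inferred from the export names only, since no hunk here shows them in use:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# state.to_pure_dict()              ->  nnx.to_pure_dict(state)
# state.replace_by_pure_dict(d)     ->  nnx.replace_by_pure_dict(state, d)
# state.flat_state()                ->  nnx.to_flat_state(state)
# nnx.State.from_flat_path(flat)    ->  nnx.from_flat_state(flat)
# nnx.GraphState.merge(a, b)        ->  nnx.merge_state(a, b)
# state.map(fn)                     ->  nnx.map_state(fn, state)
# state.filter(...)                 ->  nnx.filter_state(state, ...)   # inferred
# state.split(...)                  ->  nnx.split_state(state, ...)    # inferred
pure = nnx.to_pure_dict(state)
nnx.replace_by_pure_dict(state, pure)
state = nnx.from_flat_state(dict(nnx.to_flat_state(state)))
```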
4 changes: 2 additions & 2 deletions flax/nnx/bridge/module.py
@@ -307,7 +307,7 @@ def _get_variables(self) -> tp.Mapping:
_variables: dict = {}

variable_state: variablelib.VariableState
for path, variable_state in state.flat_state():
for path, variable_state in statelib.to_flat_state(state):
try:
collection = variablelib.variable_name_from_type(variable_state.type)
except ValueError:
@@ -365,7 +365,7 @@ def to_variable(value):
real_variables[col_name] = linen_collection

states = ({},) if not real_variables else real_variables.values()
state = ModuleState.merge(*states)
state = statelib.merge_state(*states, cls=ModuleState)
graph.update(module, state)

if rngs is None:
2 changes: 1 addition & 1 deletion flax/nnx/bridge/wrappers.py
@@ -251,7 +251,7 @@ def __call__(self, *args, **kwargs):
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
states = [State(v) for v in states.values()]
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
nnx_state = nnx.merge_state(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)