Commit

Add NNXToLinen

IvyZX committed Aug 15, 2024
1 parent e637bb3 commit fe1422e
Showing 5 changed files with 265 additions and 19 deletions.
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -115,6 +115,7 @@
from .nnx.training import metrics as metrics
from .nnx.variables import (
Param as Param,
register_variable_name_type_pair as register_variable_name_type_pair,
) # this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
from .nnx.training.metrics import Metric as Metric
7 changes: 4 additions & 3 deletions flax/nnx/nnx/bridge/__init__.py
@@ -18,7 +18,8 @@
from .module import Scope as Scope
from .module import compact as compact
from .wrappers import functional as functional
from .wrappers import LinenToNNX as LinenToNNX
from .wrappers import Functional as Functional
from .wrappers import NNXToLinen as NNXToLinen
from .wrappers import lazy_init as lazy_init
from .wrappers import ToNNX as ToNNX
from .wrappers import lazy_init as lazy_init
from .wrappers import ToLinen as ToLinen
from .wrappers import to_linen as to_linen
124 changes: 118 additions & 6 deletions flax/nnx/nnx/bridge/wrappers.py
@@ -83,7 +83,35 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
return fn


class LinenToNNX(Module):
class ToNNX(Module):
"""A wrapper to turn any Linen module into an NNX module.
The result NNX module can be used standalone with all NNX APIs, or as a submodule of
another NNX module.
Since Linen module initialization requires a sample input, you need to call `lazy_init`
with an argument to initialize the variables.
Example::
>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen, initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply, but using NNX's direct call method
>>> y = model(x)
>>> nnx.state(model).params.kernel.value.shape
(32, 64)
Args:
module: The Linen Module instance.
rngs: The `nnx.Rngs` instance being passed to any NNX module.
Returns:
A stateful NNX module that behaves the same as the wrapped Linen module.
"""
def __init__(
self,
module: linen.Module,
@@ -139,11 +167,95 @@ def nn_var_to_nnx_state(kp, v):
return out
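
For context, a condensed sketch of using the ToNNX wrapper above as a submodule of an NNX module, mirroring the submodule test later in this commit (the NNXOuter name and the lazy_init call on the outer module are illustrative, not part of the diff):

from flax import linen as nn, nnx
import jax

class NNXOuter(nnx.Module):
  def __init__(self, dout: int, *, rngs: nnx.Rngs):
    # Wrap a Linen layer so it can live inside an NNX module.
    self.dense = nnx.bridge.ToNNX(nn.Dense(dout, use_bias=False), rngs=rngs)
    self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, dout)))

  def __call__(self, x):
    return self.dense(x) + self.b

x = jax.numpy.ones((1, 32))
model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.bridge.lazy_init(model, x)  # Linen submodules need a sample input to create their variables
y = model(x)                    # shape (1, 64)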


class NNXToLinen(linen.Module):
module: Module

def setup(self):
...


def linen_rngs_dict(linen_module: linen.Module) -> tp.Mapping[str, jax.Array]:
"""Given a Linen module, return a dict with one fresh RNG key for each of its active RNG collections."""
assert linen_module.scope is not None, 'linen_rngs_dict() must be called inside a Linen module.'
return {name: linen_module.make_rng(name)
for name in linen_module.scope.rngs.keys()}


class ToLinen(linen.Module):
"""A wrapper to turn any NNX module into a Linen module.

The resulting Linen module can be used standalone with all Linen APIs, or as a submodule of
another Linen module.

Since NNX modules are stateful and own their state, we only create the module once during
init time, and then track its state and static data as separate Linen variables.

Example::

>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].value.shape
(32, 64)
>>> # The static GraphDef of the underlying NNX module
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.nnx.graph.GraphDef'>

Args:
nnx_class: The NNX Module class (not an instance!).
args: The arguments that would normally be passed in to create the NNX module.
kwargs: The keyword arguments that would normally be passed in to create the NNX module.
skip_rng: True if this NNX module doesn't need an `rngs` arg during initialization (not common).

Returns:
A stateful Linen module that behaves the same as the wrapped NNX module.
"""
nnx_class: tp.Callable[..., Module]
args: tp.Sequence = ()
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
skip_rng: bool = False

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = variableslib.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = variableslib.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
self.put_variable(collection, k, v)

@linen.compact
def __call__(self, *args, **kwargs):
...
# init codepath
if self.is_initializing():
module_kwargs = dict(self.kwargs)
if not self.skip_rng:
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
self.update_variables(module)
return module(*args, **kwargs)

# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out


def to_linen(nnx_class: tp.Callable[..., Module], *args, **kwargs):
"""Shortcut of `ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs)
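
A minimal end-to-end sketch of the two codepaths in `ToLinen.__call__` above, following the docstring example (shapes and printed keys assume the same Linear setup):

from flax import nnx
import jax

# to_linen defers construction of the NNX module to Linen's init.
model = nnx.bridge.to_linen(nnx.Linear, 32, out_features=64)
x = jax.numpy.ones((1, 32))

# Init codepath: builds the NNX module, then stores its GraphDef ('nnx') and
# each variable type in its own Linen collection (here just 'params').
y, variables = model.init_with_output(jax.random.key(0), x)
print(variables.keys())  # dict_keys(['nnx', 'params'])

# Apply codepath: merges GraphDef + collections back into an NNX module,
# reseeds it with the RNGs passed to apply, and calls it.
y2 = model.apply(variables, x)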
41 changes: 37 additions & 4 deletions flax/nnx/nnx/variables.py
@@ -999,14 +999,47 @@ def wrapper(*args):
return wrapper # type: ignore


### Variable type <-> name mapping ###
# Assumption: the mapping is 1-1 and unique.

def variable_type(name: str) -> tp.Type[Variable[tp.Any]]:
"""Given a Linen-style collection name, get or create its corresponding NNX Variable type."""
if name not in VariableTypeCache:
VariableTypeCache[name] = type(name, (Variable,), {})
return VariableTypeCache[name]


def variable_type_name(typ: tp.Type[Variable[tp.Any]]) -> str:
"""Given an NNX Variable type, get or create its Linen-style collection name.
Should output the exact inversed result of `variable_type()`."""
for name, t in VariableTypeCache.items():
if typ == t:
return name
name = typ.__name__
if name in VariableTypeCache:
raise ValueError(
f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
f'It cannot be linked with this type {typ}.'
)
register_variable_name_type_pair(name, typ)
return name


def register_variable_name_type_pair(name, typ):
"""Register a pair of variable type name (like Linen collections) and its NNX type."""
VariableTypeCache[name] = typ


# add known variable type names
VariableTypeCache['params'] = Param
VariableTypeCache['batch_stats'] = BatchStat
VariableTypeCache['cache'] = Cache
VariableTypeCache['intermediates'] = Intermediate
register_variable_name_type_pair('params', Param)
register_variable_name_type_pair('batch_stats', BatchStat)
register_variable_name_type_pair('cache', Cache)
register_variable_name_type_pair('intermediates', Intermediate)


def sort_variable_types(types: list[type]):
def _variable_parents_count(t: type):
return sum(1 for p in t.mro() if issubclass(p, nnx.Variable))
parent_count = {t: _variable_parents_count(t) for t in types}
return sorted(types, key=lambda t: -parent_count[t])
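
For illustration, a condensed version of the Count test later in this commit, showing how `register_variable_name_type_pair` lets the bridge map a custom Variable type to its own Linen collection (the Counter module here is just the test's example):

from flax import nnx
import jax
import jax.numpy as jnp

# A custom NNX variable type, registered under a Linen-style collection name.
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('Count', Count)

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))
  def __call__(self):
    self.count += 1

# Through ToLinen, Count state lands in its own 'Count' collection.
variables = nnx.bridge.ToLinen(Counter, skip_rng=True).init(jax.random.key(0))
assert variables['Count']['count'].value == 0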
111 changes: 105 additions & 6 deletions flax/nnx/tests/bridge/wrappers_test.py
@@ -13,7 +13,6 @@
# limitations under the License.

from absl.testing import absltest

import flax
from flax import linen as nn
from flax import nnx
@@ -31,20 +30,24 @@ def test_functional(self):
x = jax.numpy.ones((1, 32))
y, updates = functional.apply(state)(x)

##################
### LinenToNNX ###
##################

def test_linen_to_nnx(self):
## Wrapper API for Linen Modules
linen_module = nn.Dense(features=64)
x = jax.numpy.ones((1, 32))
model = bridge.LinenToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) # like linen init
model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) # like linen init
y = model(x) # like linen apply
assert y.shape == (1, 64)

def test_linen_to_nnx_submodule(self):
class NNXOuter(nnx.Module):
def __init__(self, dout: int, *, rngs: nnx.Rngs):
self.nn_dense1 = bridge.LinenToNNX(nn.Dense(dout, use_bias=False), rngs=rngs)
self.nn_dense1 = bridge.ToNNX(nn.Dense(dout, use_bias=False), rngs=rngs)
self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, dout,)))
self.batchnorm = bridge.LinenToNNX(nn.BatchNorm(use_running_average=True), rngs=rngs)
self.batchnorm = bridge.ToNNX(nn.BatchNorm(use_running_average=True), rngs=rngs)
self.rngs = rngs

def __call__(self, x):
@@ -77,7 +80,7 @@ def dot(self, x):
return x @ w

x = jax.random.normal(jax.random.key(0), (2, 4))
model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0))
model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0))
bridge.lazy_init(model, x, method=model.module.dot)
y = model(x, method=model.module.dot)
np.testing.assert_allclose(y, x @ nnx.state(model).params.w.value)
@@ -96,11 +99,107 @@ def __call__(self, x):
return x

x = lambda: jnp.zeros((), jnp.int32)
model = bridge.LinenToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x)
model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x)
assert nnx.state(model).counter.count.value == 0
y = model(x, mutable=True)
assert nnx.state(model).counter.count.value == 1

##################
### NNXToLinen ###
##################

def test_nnx_to_linen(self):
model = bridge.to_linen(nnx.Linear, 32, out_features=64)
x = jax.numpy.ones((1, 32))
y, variables = model.init_with_output(jax.random.key(0), x)
assert y.shape == (1, 64)
np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)

def test_nnx_to_linen_multiple_rngs(self):
class NNXInner(nnx.Module):
def __init__(self, din, dout, rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
def __call__(self, x):
return self.dropout(x @ self.w.value)

class LinenOuter(nn.Module):
@nn.compact
def __call__(self, x):
inner = bridge.to_linen(NNXInner, 4, 3)
return inner(x)

xkey, pkey, dkey1, dkey2 = jax.random.split(jax.random.key(0), 4)
x = jax.random.normal(xkey, (2, 4))
model = LinenOuter()
y1, var = model.init_with_output({'params': pkey, 'dropout': dkey1}, x)
y2 = model.apply(var, x, rngs={'dropout': dkey2})
assert not jnp.allclose(y1, y2) # dropout keys are different

def test_nnx_to_linen_multiple_collections(self):
class NNXInner(nnx.Module):
def __init__(self, din, dout, rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)
self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)

def __call__(self, x):
return self.bn(x @ self.w.value) + self.lora(x)

xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
model = bridge.to_linen(NNXInner, 4, 3)
var = model.init({'params': pkey, 'dropout': dkey}, x)
self.assertSameElements(var.keys(), ['nnx', 'LoRAParam', 'params', 'batch_stats'])
y = model.apply(var, x)
assert y.shape == (2, 3)

def test_nnx_to_linen_mutable(self):
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('Count', Count)

class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1

model = bridge.ToLinen(Counter, skip_rng=True)
variables = model.init(jax.random.key(0))
assert variables['Count']['count'].value == 0

_, updates = model.apply(variables, mutable='Count')
assert updates['Count']['count'].value == 1
_ = model.apply(variables | updates)

def test_nnx_to_linen_mutated_static_data(self):
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('Count', Count)

class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1
self.count_nonzero = Count(jnp.array(1))

model = bridge.ToLinen(Counter, skip_rng=True)
variables = model.init(jax.random.key(0))
assert variables['Count']['count'].value == 0

# This does not work, because the __call__ also changes the static data of the model.
_, updates = model.apply(variables, mutable='Count')
assert updates['Count']['count'].value == 1
assert updates['Count']['count_nonzero'].value == 1
with self.assertRaises(ValueError):
_ = model.apply(variables | updates)

# This makes sure the static data is updated too. Using mutable=True also works.
_, updates = model.apply(variables, mutable=['Count', 'nnx'])
assert updates['Count']['count'].value == 1
assert updates['Count']['count_nonzero'].value == 1
_ = model.apply(variables | updates)


if __name__ == '__main__':
absltest.main()
