Cannot assign arrays to dataclass fields in nnx #3627

Open
frazane opened this issue Jan 17, 2024 · 3 comments

frazane commented Jan 17, 2024

When instantiating an nnx.dataclass module, if the input to a param_field (or, in fact, any variable_field) is a jax.Array, a ValueError is raised because the value is assigned to the module without first being wrapped in the nnx.Param class.

import jax
import jax.numpy as jnp
from flax.experimental import nnx

@nnx.dataclass
class Foo(nnx.Module):
    x: jax.Array = nnx.param_field()

foo = Foo(jnp.array(0.2))

I would expect the input to be wrapped in nnx.Param before being assigned to the module, the same way it works for e.g. integers or floats.

Logs, error messages, etc:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/users/fzanetta/pyprojects/GPJax/_debug/variables.ipynb Cell 2 line 5
      [1] @nnx.dataclass
      [2] class Foo(nnx.Module):
      [3]        x: jax.Array = nnx.param_field()
Foo(jnp.array(0.2))

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:150, in ModuleMeta.__call__(self, *args, **kwargs)
    149 def __call__(self, *args: Any, **kwargs: Any) -> Any:
--> 150   return self._meta_call(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:155, in ModuleMeta._meta_call(cls, *args, **kwargs)
    153 module = cls.__new__(cls, *args, **kwargs)
    154 vars(module)['_module__state'] = ModuleState()
--> 155 module.__init__(*args, **kwargs)
    157 if dataclasses.is_dataclass(module):
    158   if isinstance(module, _HasSetup):

File <string>:3, in __init__(self, x)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:207, in Module.__setattr__(self, name, value)
    206 def __setattr__(self, name: str, value: Any) -> None:
--> 207   self._setattr(name, value)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:232, in Module._setattr(self, name, value)
    230 else:
    231   if isinstance(value, (jax.Array, np.ndarray, State)):
--> 232     raise ValueError(
    233       f"Trying to assign a '{type(value).__name__}' to the Module"
    234       f" attribute '{name}'. This is not supported. Non-hashable "
    235       'objects are not valid static state in JAX. Please wrap '
    236       'the value in a Variable type instead.'
    237     )
    238   vars_dict[name] = value

ValueError: Trying to assign a 'ArrayImpl' to the Module attribute 'x'. This is not supported. Non-hashable objects are not valid static state in JAX. Please wrap the value in a Variable type instead.

chiamp (Collaborator) commented Jan 25, 2024

Seems like this is intentional behavior as there's a line of code that catches whether the input is a jax.Array or not. Any thoughts @cgarciae?

frazane (Author) commented Jan 25, 2024

> Seems like this is intentional behavior as there's a line of code that catches whether the input is a jax.Array or not. Any thoughts @cgarciae?

It's intentional that jax.Array cannot be assigned directly, but I thought the point of using nnx.param_field is that the jax.Array is first wrapped into nnx.Param before being assigned to the module. Same for nnx.variable_field where the array would be wrapped into the specified variable.

If I understand correctly, first the arguments are assigned directly to the module

module.__init__(*args, **kwargs)

and only in a second step (when using dataclasses)

if dataclasses.is_dataclass(module):
  if isinstance(module, _HasSetup):
    module.setup()
  assert isinstance(module, Module)
  for field in dataclasses.fields(module):
    if not field.init:
      continue
    value = vars(module)[field.name]
    # set Rngs instances to None
    if isinstance(value, Rngs):
      vars(module)[field.name] = None
      continue
    if 'nnx_variable_constructor' not in field.metadata:
      continue
    variable_constructor = field.metadata['nnx_variable_constructor']
    vars(module)[field.name] = variable_constructor(value)

are they wrapped in the given variable container. So for integers, floats, etc. there are no problems during the first step, but if the argument is an array we have an error. If this is intentional, I wonder why?
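The ordering described above can be reproduced without flax. What follows is only a minimal sketch using the standard library: `FakeArray`, `Param`, `ModuleMeta`, `Module`, `param_field` and the `'variable_constructor'` metadata key are all hypothetical stand-ins written for illustration, not the actual nnx implementation.

```python
import dataclasses


class FakeArray:
    """Hypothetical stand-in for jax.Array (not a real JAX type)."""

    def __init__(self, data):
        self.data = data


class Param:
    """Hypothetical stand-in for nnx.Param: a plain value container."""

    def __init__(self, value):
        self.value = value


class ModuleMeta(type):
    def __call__(cls, *args, **kwargs):
        module = cls.__new__(cls)
        # Step 1: the dataclass-generated __init__ assigns the raw arguments,
        # so Module.__setattr__ sees the unwrapped values.
        module.__init__(*args, **kwargs)
        # Step 2: only now are fields that carry a constructor wrapped.
        for field in dataclasses.fields(module):
            constructor = field.metadata.get('variable_constructor')
            if constructor is not None:
                vars(module)[field.name] = constructor(vars(module)[field.name])
        return module


class Module(metaclass=ModuleMeta):
    def __setattr__(self, name, value):
        # Simplified version of the guard seen in the traceback:
        # raw array-like values are rejected on assignment.
        if isinstance(value, FakeArray):
            raise ValueError(
                f"Trying to assign a '{type(value).__name__}' to '{name}'."
            )
        vars(self)[name] = value


def param_field():
    # Assumed helper: stores the wrapper in the field metadata.
    return dataclasses.field(metadata={'variable_constructor': Param})


@dataclasses.dataclass
class Foo(Module):
    x: object = param_field()


foo = Foo(0.2)  # a float passes step 1 and is wrapped in step 2
print(type(foo.x).__name__)

try:
    Foo(FakeArray([0.2]))  # rejected in step 1, before any wrapping happens
except ValueError as err:
    print(err)
```

In this sketch the float reaches step 2 and ends up inside `Param`, while the array-like value never gets that far, which matches the behaviour reported in this issue.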

cgarciae (Collaborator) commented Mar 7, 2024

Since #3720 you should pass the Param directly. nnx.dataclasses will be removed soon.
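As a rough illustration of that suggestion, the sketch below wraps the value at the call site, so the assignment guard only ever sees a container. It uses hypothetical stand-ins (`FakeArray`, `Param`, `Module`) rather than flax itself; with nnx the analogous step would presumably be wrapping the array in `nnx.Param` before passing it in, but the exact API after #3720 is not shown in this thread.

```python
class FakeArray:
    """Hypothetical stand-in for jax.Array (not a real JAX type)."""

    def __init__(self, data):
        self.data = data


class Param:
    """Hypothetical stand-in for nnx.Param: a plain value container."""

    def __init__(self, value):
        self.value = value


class Module:
    def __setattr__(self, name, value):
        # Simplified guard: raw array-like values are rejected,
        # wrapped values are accepted.
        if isinstance(value, FakeArray):
            raise ValueError(
                f"Trying to assign a '{type(value).__name__}' to '{name}'."
            )
        vars(self)[name] = value


class Foo(Module):
    def __init__(self, x: Param):
        self.x = x  # already wrapped by the caller, so the guard passes


foo = Foo(Param(FakeArray([0.2])))
print(type(foo.x).__name__)
```

The design choice here is that the caller, not the module machinery, decides how a value is wrapped, so no post-`__init__` wrapping step is needed.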
