Cannot assign arrays to dataclass fields in nnx
#3627
Comments
Seems like this is intentional behavior as there's a line of code that catches whether the input is a |
It's intentional that
If I understand correctly, the arguments are first assigned directly to the module (flax/flax/experimental/nnx/nnx/module.py, line 155 in 3cd34b6), and only in a second step (when using dataclasses; flax/flax/experimental/nnx/nnx/module.py, lines 157 to 176 in 3cd34b6) are they wrapped in the given variable container. So for integers, floats, etc. there are no problems during the first step, but if the argument is an array, an error is raised. If this is intentional, I wonder why? |
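To make the ordering problem described above concrete, here is a minimal sketch in plain Python. It does not use flax or JAX and is not the code in module.py: FakeArray, Param, Module, init_assign_then_wrap and init_wrap_then_assign are all hypothetical stand-ins, and the rule "raw arrays may not be assigned directly" is assumed from the description in this thread.

```python
class FakeArray:
    """Hypothetical stand-in for a jax Array (not a real JAX type)."""

    def __init__(self, values):
        self.values = list(values)


class Param:
    """Hypothetical stand-in for a variable container such as nnx.Param."""

    def __init__(self, value):
        self.value = value


class Module:
    """Toy module that refuses raw arrays on attribute assignment."""

    def __setattr__(self, name, value):
        if isinstance(value, FakeArray):
            raise ValueError(
                f"cannot assign raw array to {name!r}; wrap it in a container first"
            )
        object.__setattr__(self, name, value)


def init_assign_then_wrap(module, **kwargs):
    """Order described in the comment: assign raw values, then wrap them."""
    # Step 1: raw assignment. Ints and floats pass, raw arrays raise here.
    for name, value in kwargs.items():
        setattr(module, name, value)
    # Step 2: wrapping. Never reached when step 1 raised.
    for name, value in kwargs.items():
        setattr(module, name, Param(value))
    return module


def init_wrap_then_assign(module, **kwargs):
    """Alternative order: wrap first, so the array check is never triggered."""
    for name, value in kwargs.items():
        setattr(module, name, Param(value))
    return module


# A float survives both steps of the first ordering.
float_module = init_assign_then_wrap(Module(), scale=2.0)

# An array fails in step 1, before it can be wrapped.
try:
    init_assign_then_wrap(Module(), w=FakeArray([1.0, 2.0]))
    array_error = None
except ValueError as exc:
    array_error = exc

# Wrapping before assignment accepts the same array.
array_module = init_wrap_then_assign(Module(), w=FakeArray([1.0, 2.0]))
```

In this toy version the float ends up wrapped in Param, the array raises ValueError during the first step, and the wrap-first variant accepts the array. Whether the real implementation can simply swap the two steps depends on details this sketch does not model.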
Since #3720 you should pass the |
When instantiating an nnx.dataclass module, if the input to a param_field (or any variable_field, actually) is a jax Array, a ValueError is raised because the value is assigned to the module without being wrapped into the nnx.Param class.

I would expect the input to be wrapped into nnx.Param before being assigned to the module, the same as it works for e.g. integers or floats.

Logs, error messages, etc: