NNXWrapper #4088
Hey @PhilipVinc, as you point out, #4081 is the solution we are working on to use Linen Modules in NNX and vice versa. It should be done soon-ish. In the meantime, maybe you can use something simple like:

from typing import Any

from flax import linen, nnx


class LinenToNNX(nnx.Module):
    def __init__(self, module: linen.Module, rngs: nnx.Rngs):
        self.module = module  # the wrapped Linen module
        self.rngs = rngs
        self.initialized = False

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Draw one fresh key per RNG stream; Linen expects a plain dict of keys.
        _rngs = {name: stream() for name, stream in self.rngs.items()}
        # Linen initializes parameters from the 'params' stream.
        if 'params' not in _rngs and 'default' in _rngs:
            _rngs['params'] = _rngs.pop('default')
        if not self.initialized:
            # First call: lazily initialize the Linen variables from the inputs.
            self.initialized = True
            out, variables = self.module.init_with_output(_rngs, *args, **kwargs)
            self.params = nnx.Param(variables['params'])
        else:
            variables = {'params': self.params.value}
            # With no mutable collections, apply returns only the output;
            # a forward pass does not change the params.
            out = self.module.apply(variables, *args, rngs=_rngs, **kwargs)
        return out
Hi! I am working on the
Hi,

I have a large library that we decided to build on top of `flax.linen` several years ago. I'd now like to begin testing `nnx`. However, given the size of the repo and the number of people using it, I cannot switch everything over to `nnx` at once. Instead, I would like to keep using `linen`-style code for a while, while allowing users to use models defined with `nnx` inside our library. In brief, the way we use modules right now is

I tried to use `nnx.split` to this end, but because it returns a special object rather than a plain dictionary, I could not make this approach work. By inspecting `nnx.compat`/`bridge`, I see that you have several utilities for using `linen` layers within `nnx`, but it is unclear to me how to do the opposite. It seems that `nnx.bridge.NNXWrapper` should do that, but it is unfinished, and it is not clear to me how to use it with `nnx.Module`. Is there anything I can use?