Skip to content

Commit

Permalink
Update Why Flax NNX doc
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Oct 10, 2024
1 parent 28a423c commit 50df809
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions docs_nnx/why.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ Inspection
The first improvement is that Flax NNX Modules are regular Python objects. This means that you can easily
construct and inspect ``Module`` objects.

Compare this to Flax Linen, where Modules are lazy, which means some attributes are not available upon construction
and are only accessible at runtime, making Linen not easy to inspect and debug.
On the other hand, Flax Linen Modules are not easy to inspect and debug because they are lazy, which means some attributes are not available upon construction and are only accessible at runtime.

.. codediff::
:title: Linen, NNX
Expand Down Expand Up @@ -130,7 +129,7 @@ In Flax Linen, all top-level computation must be done through the ``flax.linen.M
parameters or any other type of state are handled as a separate structure. This creates an asymmetry between: 1) code that runs inside
``apply`` that can run methods and other ``Module`` objects directly; and 2) code that runs outside of ``apply`` that must use the ``apply`` method.

In Flax NNX, there's no special context because parameters are held as attributes and methods can be called directly.
In Flax NNX, there's no special context because parameters are held as attributes and methods can be called directly. That means your NNX Module's ``__init__`` and ``__call__`` methods are not treated differently from other class methods, whereas Flax Linen Module's ``setup()`` and ``__call__`` methods are special.

.. codediff::
:title: Linen, NNX
Expand Down Expand Up @@ -185,7 +184,7 @@ In Flax NNX, there's no special context because parameters are held as attribute
In Flax Linen, calling sub-Modules directly is not possible because they are not initialized.
Therefore, what you must do is construct a new instance and then provide a proper parameter structure.

But In Flax NNX you can call sub-Modules directly without any issues.
But in Flax NNX you can call sub-Modules directly without any issues.

State handling
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -283,7 +282,7 @@ Model surgery
In Flax Linen, model surgery has historically been challenging because of two reasons:

1. Due to lazy initialization, it is not guaranteed that you can replace a sub-``Module`` with a new one.
2. The parameter structure is separate from the ``flax.linen.Module`` structure, which means you have to manually keep them in sync.
2. The parameter structure is separated from the ``flax.linen.Module`` structure, which means you have to manually keep them in sync.

In Flax NNX, you can replace sub-Modules directly as per the Python semantics. Since parameters are
part of the ``nnx.Module`` structure, they are never out of sync. Below is an example of how you can
Expand Down Expand Up @@ -365,10 +364,10 @@ Transforms
Flax Linen transforms are very powerful in that they enable fine-grained control over the model's state.
However, Flax Linen transforms have drawbacks, such as:

1. They expose additional APIs that are not part of JAX (making it not easy to interact with JAX transforms).
1. They expose additional APIs that are not part of JAX, making their behavior confusing and sometimes divergent from their JAX counterparts. This also constrains your ways to interact with `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`_ and keep up with JAX API changes.
2. They work on functions with very specific signatures, namely:
- A ``flax.linen.Module`` must be the first argument.
- They accept other ``Module`` objects as arguments but not as return values.
- A ``flax.linen.Module`` must be the first argument.
- They accept other ``Module`` objects as arguments but not as return values.
3. They can only be used inside ``flax.linen.Module.apply``.

On the other hand, `Flax NNX transforms <https://flax.readthedocs.io/en/latest/guides/transforms.html>`_
Expand Down

0 comments on commit 50df809

Please sign in to comment.