Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove markdown from section titles #4322

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs_nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Flax `nnx.Module` system\n",
"## The Flax NNX Module system\n",
"\n",
"The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:\n",
"\n",
Expand Down Expand Up @@ -190,7 +190,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Nested `nnx.Module`s\n",
"### Nested Modules\n",
"\n",
"Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.\n",
"\n",
Expand Down Expand Up @@ -355,7 +355,7 @@
"1. The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n",
"2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone.\n",
"\n",
"### `nnx.scan` over layers\n",
"### Scan over layers\n",
"\n",
"The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) to create a stack of multiple MLP layers and [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) to iteratively apply each layer of the stack to the input.\n",
"\n",
Expand Down Expand Up @@ -474,7 +474,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### `State` and `GraphDef`\n",
"### State and GraphDef\n",
"\n",
"A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function:\n",
"\n",
Expand Down Expand Up @@ -522,7 +522,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### `split`, `merge`, and `update`\n",
"### Split, merge, and update\n",
"\n",
"Flax's [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) is the reverse of [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). It takes the [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) + [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and reconstructs the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The example below demonstrates this as follows:\n",
"\n",
Expand Down Expand Up @@ -574,14 +574,14 @@
"source": [
"The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries.\n",
"\n",
"**Why aren't Flax `nnx.Module`s just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about."
"**Why aren't Modules just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fine-grained `State` control\n",
"### Fine-grained State control\n",
"\n",
"Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n",
"\n",
Expand Down
14 changes: 7 additions & 7 deletions docs_nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import jax
import jax.numpy as jnp
```

## The Flax `nnx.Module` system
## The Flax NNX Module system

The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:

Expand Down Expand Up @@ -112,7 +112,7 @@ to handle them, as demonstrated in later sections of this guide.

+++

### Nested `nnx.Module`s
### Nested Modules

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

Expand Down Expand Up @@ -209,7 +209,7 @@ There are two things happening in this example that are worth mentioning:
1. The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside.
2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone.

### `nnx.scan` over layers
### Scan over layers

The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) to create a stack of multiple MLP layers and [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) to iteratively apply each layer of the stack to the input.

Expand Down Expand Up @@ -272,7 +272,7 @@ y = model(jnp.ones((1, 3)))
nnx.display(model)
```

### `State` and `GraphDef`
### State and GraphDef

A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function:

Expand All @@ -285,7 +285,7 @@ graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
```

### `split`, `merge`, and `update`
### Split, merge, and update

Flax's [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) is the reverse of [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). It takes the [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) + [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and reconstructs the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The example below demonstrates this as follows:

Expand Down Expand Up @@ -318,11 +318,11 @@ print(f'{model.count.value = }')

The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries.

**Why aren't Flax `nnx.Module`s just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about.
**Why aren't Modules just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about.

+++

### Fine-grained `State` control
### Fine-grained State control

Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).

Expand Down
Loading