Skip to content

Commit 5d31452

Browse files
author
Flax Authors
committed
Merge pull request #4250 from google:enable-notebook-doctest
PiperOrigin-RevId: 682363934
2 parents 2d64500 + 146d2ff commit 5d31452

File tree

6 files changed

+20
-12
lines changed

6 files changed

+20
-12
lines changed

docs/conf.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137

138138
# -- Options for myst ----------------------------------------------
139139
# uncomment line below to avoid running notebooks during development
140-
nb_execution_mode = 'off'
140+
# nb_execution_mode = 'off'
141141
# Notebook cell execution timeout; defaults to 30.
142142
nb_execution_timeout = 100
143143
# List of patterns, relative to source directory, that match notebook
@@ -147,6 +147,8 @@
147147
'quick_start.ipynb', # <-- times out
148148
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
149149
'flax/nnx', # exclude nnx
150+
'guides/quantization/fp8_basics.ipynb',
151+
'guides/training_techniques/use_checkpointing.ipynb', # TODO(IvyZX): needs to be updated
150152
]
151153
# raise exceptions on execution so CI can catch errors
152154
nb_execution_allow_errors = False

docs_nnx/conf.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137

138138
# -- Options for myst ----------------------------------------------
139139
# uncomment line below to avoid running notebooks during development
140-
nb_execution_mode = 'off'
140+
# nb_execution_mode = 'off'
141141
# Notebook cell execution timeout; defaults to 30.
142142
nb_execution_timeout = 100
143143
# List of patterns, relative to source directory, that match notebook
@@ -147,6 +147,10 @@
147147
'quick_start.ipynb', # <-- times out
148148
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
149149
'flax/nnx', # exclude nnx
150+
'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update
151+
'guides/why.ipynb', # TODO(cgarciae): broken, remove in favor on the new guide
152+
'guides/flax_gspmd.ipynb', # TODO(IvyZX): broken, needs to be updated
153+
'guides/surgery.ipynb', # TODO(IvyZX): broken, needs to be updated
150154
]
151155
# raise exceptions on execution so CI can catch errors
152156
nb_execution_allow_errors = False

docs_nnx/guides/demo.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
" pass\n",
8989
"\n",
9090
"model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
91-
"model.set_attributes(deterministic=False, use_running_average=False) # set flags\n",
91+
"model.set_attributes(use_running_average=False) # set flags\n",
9292
"y = model(jnp.ones((2, 4))) # call methods directly\n",
9393
"\n",
9494
"print(f'{model = }'[:500] + '\\n...')"

docs_nnx/guides/demo.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Count(nnx.Variable): # custom Variable types define the "collections"
4848
pass
4949
5050
model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method
51-
model.set_attributes(deterministic=False, use_running_average=False) # set flags
51+
model.set_attributes(use_running_average=False) # set flags
5252
y = model(jnp.ones((2, 4))) # call methods directly
5353
5454
print(f'{model = }'[:500] + '\n...')

docs_nnx/guides/filters_guide.ipynb

+5-4
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
"\n",
173173
"Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
174174
"and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
175-
"use the following filters:"
175+
"use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n",
176+
"to specify how `model`'s various substates should be vectorized:"
176177
]
177178
},
178179
{
@@ -182,9 +183,9 @@
182183
"metadata": {},
183184
"outputs": [],
184185
"source": [
185-
"from functools import partial\n",
186+
"state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n",
186187
"\n",
187-
"@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n",
188+
"@nnx.vmap(in_axes=(state_axes, 0))\n",
188189
"def forward(model, x):\n",
189190
" ..."
190191
]
@@ -275,7 +276,7 @@
275276
"KeyPath = tuple[nnx.graph.Key, ...]\n",
276277
"\n",
277278
"def split(node, *filters):\n",
278-
" graphdef, state, _ = nnx.graph.flatten(node)\n",
279+
" graphdef, state = nnx.graph.flatten(node)\n",
279280
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
280281
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
281282
"\n",

docs_nnx/guides/filters_guide.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li
9898

9999
Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters
100100
and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can
101-
use the following filters:
101+
use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`
102+
to specify how `model`'s various substates should be vectorized:
102103

103104
```{code-cell} ipython3
104-
from functools import partial
105+
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
105106
106-
@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})
107+
@nnx.vmap(in_axes=(state_axes, 0))
107108
def forward(model, x):
108109
...
109110
```
@@ -140,7 +141,7 @@ from typing import Any
140141
KeyPath = tuple[nnx.graph.Key, ...]
141142
142143
def split(node, *filters):
143-
graphdef, state, _ = nnx.graph.flatten(node)
144+
graphdef, state = nnx.graph.flatten(node)
144145
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
145146
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
146147

0 commit comments

Comments
 (0)