|
172 | 172 | "\n",
|
173 | 173 | "Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
|
174 | 174 | "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
|
175 |
| - "use the following filters:" |
| 175 | + "use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n", |
| 176 | + "to specify how `model`'s various substates should be vectorized:" |
176 | 177 | ]
|
177 | 178 | },
|
178 | 179 | {
|
|
182 | 183 | "metadata": {},
|
183 | 184 | "outputs": [],
|
184 | 185 | "source": [
|
185 |
| - "from functools import partial\n", |
| 186 | + "state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n", |
186 | 187 | "\n",
|
187 |
| - "@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n", |
| 188 | + "@nnx.vmap(in_axes=(state_axes, 0))\n", |
188 | 189 | "def forward(model, x):\n",
|
189 | 190 | " ..."
|
190 | 191 | ]
|
|
275 | 276 | "KeyPath = tuple[nnx.graph.Key, ...]\n",
|
276 | 277 | "\n",
|
277 | 278 | "def split(node, *filters):\n",
|
278 |
| - " graphdef, state, _ = nnx.graph.flatten(node)\n", |
| 279 | + " graphdef, state = nnx.graph.flatten(node)\n", |
279 | 280 | " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
|
280 | 281 | " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
|
281 | 282 | "\n",
|
|
0 commit comments