Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add extract APIs #4078

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions flax/nnx/nnx/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import abc
import contextlib
import dataclasses
import threading
import typing as tp

import jax
from jax._src.tree_util import broadcast_prefix

from flax import struct
from flax.nnx.nnx.state import State
from flax.typing import PathParts
from flax.nnx.nnx import graph


class Missing:
pass


MISSING = Missing()
A = tp.TypeVar('A')
E = tp.TypeVar('E', bound='Extractable')
Index = int
KeyEntry = tp.TypeVar('KeyEntry', bound=tp.Hashable)
KeyPath = tuple[KeyEntry, ...]
Prefix = tp.Any
Leaf = tp.Any


class Extractable(abc.ABC):
@property
@abc.abstractmethod
def index(self) -> Index: ...


class ExtractableStates(Extractable):
@property
@abc.abstractmethod
def states(self) -> tp.Iterable[State]: ...

@property
@abc.abstractmethod
def graphdef(self) -> graph.GraphDef[tp.Any]: ...


class ExtractionIndex(struct.PyTreeNode, Extractable):
"""Index of a graph node in a Pytree structure."""

_index: Index = struct.field(pytree_node=False)

@property
def index(self) -> Index:
return self._index


@tp.overload
def extract_graph_nodes(
pytree: A,
/,
*,
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> tuple[A, tuple[tp.Any, ...]]: ...


@tp.overload
def extract_graph_nodes(
pytree: A,
/,
*,
prefix: tp.Any,
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ...


def extract_graph_nodes(
pytree: A,
/,
*,
prefix: tp.Any = MISSING,
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> (
tuple[A, tuple[tp.Any, ...]]
| tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
"""Extracts all graph nodes from a pytree."""
nodes = graph.RefMap[tp.Any, Index]()
node_prefixes = []
leaves = []

prefix_leaves = broadcast_prefix(
prefix,
pytree,
is_leaf=lambda x: x is None,
)
key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree)

assert len(key_leaves) == len(prefix_leaves)

for (keypath, leaf), prefix_leaf in zip(key_leaves, prefix_leaves):
if validate_fn:
validate_fn(keypath, prefix_leaf, leaf)
if graph.is_graph_node(leaf):
if leaf not in nodes:
index = nodes[leaf] = len(nodes)
node_prefixes.append(prefix_leaf)
else:
index = nodes[leaf]
# check consistent aliasing
if prefix_leaf != node_prefixes[index]:
path_str = jax.tree_util.keystr(keypath)
raise ValueError(
f'Inconsistent aliasing detected. Node {type(leaf)} at path {path_str} '
f'has different prefixes: {prefix_leaf} and {node_prefixes[index]}.'
)
leaves.append(ExtractionIndex(index))
else:
leaves.append(leaf)

pytree_out = jax.tree.unflatten(treedef, leaves)

if prefix is MISSING:
return pytree_out, tuple(nodes) # type: ignore[bad-return-type]
else:
return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type]


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
"""Inserts graph nodes into a pytree."""

def _maybe_insert(x):
if isinstance(x, Extractable):
return nodes[x.index]
return x

return jax.tree_util.tree_map(
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, Extractable)
)


def extract_indexes(
pytree,
/,
types: tuple[type[E], ...] | type[E] = Extractable, # type: ignore[assignment]
) -> tuple[E, ...]:
"""Extracts all indexes from a pytree."""
indexes: list[E] = []
for x in jax.tree.leaves(
pytree, is_leaf=lambda x: isinstance(x, Extractable)
):
if isinstance(x, Extractable):
if not isinstance(x, types):
raise ValueError(f'Expected Extractable of type {types}, got {type(x)}')
indexes.append(x) # type: ignore[arg-type]
return tuple(indexes)


def replace_indexes(
pytree: A,
replace_fn: tp.Callable[[Extractable], tp.Any],
/,
clear: bool = False,
) -> A:
def _replace_map_fn(x):
if isinstance(x, Extractable):
return replace_fn(x)
elif clear:
return None
return x

return jax.tree_util.tree_map(
_replace_map_fn, pytree, is_leaf=lambda x: isinstance(x, Extractable)
)


def merge_extractable_states(
extractable_states: tp.Sequence[ExtractableStates], /
):
if len(extractable_states) == 0:
raise ValueError('Expected at least one ExtractableStates object')

graphdef = extractable_states[0].graphdef
flat_state: list[tuple[PathParts, tp.Any]] = []

for extractable_state in extractable_states:
flat_state.extend(
((extractable_state.index, *path), value)
for state in extractable_state.states
for path, value in state.flat_state().items()
)

state = State.from_flat_path(flat_state)
return graphdef, state


def check_consistent_aliasing(
nodes: tuple[tp.Any, ...], prefixes: tuple[tp.Any, ...]
):
node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()

# collect all paths and prefixes for each node
for node, prefix in zip(nodes, prefixes):
for path, value in graph.iter_graph(node):
if graph.is_graph_node(value):
if value in node_prefixes:
paths_prefixes = node_prefixes[value]
paths_prefixes.append((path, prefix))
else:
node_prefixes[value] = [(path, prefix)]

# check for inconsistent aliasing
node_msgs = []
for node, paths_prefixes in node_prefixes.items():
unique_prefixes = {prefix for _, prefix in paths_prefixes}
if len(unique_prefixes) > 1:
path_prefix_repr = '\n'.join(
f' {"/".join(map(str,path)) if path else "<root>"}: {prefix}'
for path, prefix in paths_prefixes
)
nodes_msg = f'Node: {type(node)}\n{path_prefix_repr}'
node_msgs.append(nodes_msg)

if node_msgs:
raise ValueError(
'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
+ '\n'.join(node_msgs)
)

# -----------------------------
# broadcast
# -----------------------------


@dataclasses.dataclass
class BroadcastContext(threading.local):
broadcast_state_stacks: dict[str, list[tp.Any]] = dataclasses.field(
default_factory=dict
)


BROADCAST_CONTEXT = BroadcastContext()


@contextlib.contextmanager
def broadcast_state(tag: str, state: tp.Any):
if tag in BROADCAST_CONTEXT.broadcast_state_stacks:
stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag]
else:
stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag] = []
stack.append(state)
try:
yield
finally:
stack.pop()
if not stack:
del BROADCAST_CONTEXT.broadcast_state_stacks[tag]


def get_broadcast_state(tag: str) -> tp.Any:
if tag not in BROADCAST_CONTEXT.broadcast_state_stacks:
raise ValueError(f'No broadcast state found for {tag!r}')

stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag]

if not stack:
raise RuntimeError(
f'Empty broadcast state stack for {tag!r}, this is a bug'
)

return stack[-1]
44 changes: 0 additions & 44 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,50 +1595,6 @@ class Static(tp.Generic[A]):

jax.tree_util.register_static(Static)

# ---------------------------------------------------------
# insert/extract_graph_nodes API
# ---------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class GraphNodeIndex:
"""Index of a graph node in a Pytree structure."""

index: Index


jax.tree_util.register_static(GraphNodeIndex)


def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]:
"""Extracts all graph nodes from a pytree."""
nodes = RefMap[tp.Any, Index]()

def _maybe_extract(x):
if is_graph_node(x):
if x not in nodes:
index = nodes[x] = len(nodes)
else:
index = nodes[x]
return GraphNodeIndex(index)
return x

return jax.tree_util.tree_map(_maybe_extract, pytree), tuple(nodes)


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
"""Inserts graph nodes into a pytree."""

def _maybe_insert(x):
if isinstance(x, GraphNodeIndex):
return nodes[x.index]
return x

return jax.tree_util.tree_map(
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, GraphNodeIndex)
)


# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
Expand Down
13 changes: 12 additions & 1 deletion flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ def flat_state(self) -> FlatState[V]:
return traversals.flatten_mapping(self._mapping)

@classmethod
def from_flat_path(cls, flat_state: tp.Mapping[PathParts, V], /) -> State:
def from_flat_path(
cls,
flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
/,
) -> State:
if not isinstance(flat_state, tp.Mapping):
flat_state = dict(flat_state)
nested_state = traversals.unflatten_mapping(flat_state)
return cls(nested_state)

Expand All @@ -176,7 +182,12 @@ def split(
*filters: filterlib.Filter,
) -> tuple[State[K, V], ...]: ...

@tp.overload
def split(
self, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: ...

def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Split a ``State`` into one or more ``State``'s. The
Expand Down
15 changes: 7 additions & 8 deletions flax/nnx/nnx/transforms/looping.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx.nnx import filterlib, graph, rnglib, spmd
from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.proxy_caller import DelayedAccessor
from flax.nnx.nnx.state import State
Expand Down Expand Up @@ -254,7 +254,7 @@ def scan_fn(
input_graph_nodes = ctx.merge(
graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state
)
(args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes)
(args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes)

out = f(*args, **kwargs)

Expand All @@ -271,10 +271,9 @@ def scan_fn(
carry_arg_out = out
scan_args_out = None

(
(carry_arg_out, scan_args_out),
output_graph_nodes,
) = graph.extract_graph_nodes((carry_arg_out, scan_args_out))
((carry_arg_out, scan_args_out), output_graph_nodes) = (
extract.extract_graph_nodes((carry_arg_out, scan_args_out))
)

# split module state
(
Expand Down Expand Up @@ -353,7 +352,7 @@ def scan(
@graph.update_context('scan')
def scan_apply_wrapper(*args, **kwargs):
# extract nodes
(args, kwargs), input_graph_nodes = graph.extract_graph_nodes(
(args, kwargs), input_graph_nodes = extract.extract_graph_nodes(
(args, kwargs)
)
input_rng_streams = rnglib.backup_keys(input_graph_nodes)
Expand Down Expand Up @@ -465,7 +464,7 @@ def scan_apply_wrapper(*args, **kwargs):
broadcast_rng_state_out,
)

carry_arg_out, scan_args_out = graph.insert_graph_nodes(
carry_arg_out, scan_args_out = extract.insert_graph_nodes(
(carry_arg_out, scan_args_out), output_graph_nodes
)

Expand Down
Loading
Loading