-
Notifications
You must be signed in to change notification settings - Fork 679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] jit constrain object state #3817
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,11 @@ class JitStaticOutputs: | |
|
||
jax.tree_util.register_static(JitStaticOutputs) | ||
|
||
def _default_constrain_object_state(state: State) -> State: | ||
state_spec = spmd.get_partition_spec(state) | ||
state = jax.lax.with_sharding_constraint(state, state_spec) | ||
return state | ||
|
||
|
||
@dataclasses.dataclass | ||
class JITOptions: | ||
|
@@ -152,7 +157,9 @@ class JITOptions: | |
backend: tp.Optional[str] | ||
inline: bool | ||
abstracted_axes: tp.Optional[tp.Any] | ||
# nnx specific | ||
donate_object_state: bool | ||
constrain_object_state: tp.Callable[[State], State] | None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please document what this field does either here or elsewhere. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
@classmethod | ||
def from_jit_kwargs( | ||
|
@@ -169,6 +176,7 @@ def from_jit_kwargs( | |
inline: bool, | ||
abstracted_axes: tp.Optional[tp.Any], | ||
donate_object_state: bool, | ||
constrain_object_state: bool | tp.Callable[[State], State], | ||
): | ||
_static_argnums = _normalize_sequence(static_argnums) | ||
_static_argnames = _normalize_sequence(static_argnames) | ||
|
@@ -178,6 +186,13 @@ def from_jit_kwargs( | |
if donate_object_state: | ||
_donate_argnames = (*_donate_argnames, '_nnx_jit_state') | ||
|
||
if callable(constrain_object_state): | ||
_constrain_object_state = constrain_object_state | ||
elif constrain_object_state: | ||
_constrain_object_state = _default_constrain_object_state | ||
else: | ||
_constrain_object_state = None | ||
|
||
return cls( | ||
in_shardings=in_shardings, | ||
out_shardings=out_shardings, | ||
|
@@ -191,11 +206,13 @@ def from_jit_kwargs( | |
inline=inline, | ||
abstracted_axes=abstracted_axes, | ||
donate_object_state=donate_object_state, | ||
constrain_object_state=_constrain_object_state, | ||
) | ||
|
||
def get_jit_kwargs(self) -> dict[str, tp.Any]: | ||
kwargs = vars(self).copy() | ||
del kwargs['donate_object_state'] | ||
del kwargs['constrain_object_state'] | ||
if kwargs['in_shardings'] is UNSPECIFIED: | ||
kwargs.pop('in_shardings') | ||
if kwargs['out_shardings'] is UNSPECIFIED: | ||
|
@@ -219,6 +236,9 @@ def __call__( | |
backend: tp.Optional[str] = None, | ||
inline: bool = False, | ||
abstracted_axes: tp.Optional[tp.Any] = None, | ||
# nnx specific | ||
donate_object_state: bool = False, | ||
constrain_object_state: bool | tp.Callable[[State], State] = False, | ||
) -> tp.Callable[..., 'JIT[M]']: | ||
super_call = super().__call__ | ||
|
||
|
@@ -237,6 +257,8 @@ def _create_jit(*args, **kwargs) -> JIT[M]: | |
backend=backend, | ||
inline=inline, | ||
abstracted_axes=abstracted_axes, | ||
# nnx specific | ||
donate_object_state=donate_object_state, | ||
# submodule args | ||
module_init_args=args, | ||
module_init_kwargs=kwargs, | ||
|
@@ -267,7 +289,10 @@ def jitted_fn( | |
**kwargs: tp.Any, | ||
): | ||
graphdef = _nnx_jit_static.graphdef | ||
state = _nnx_jit_state | ||
state: State = _nnx_jit_state | ||
|
||
if options.constrain_object_state is not None: | ||
state = options.constrain_object_state(state) | ||
|
||
input_graph_nodes, outer_idx_inner_ref = graph_utils.graph_unflatten( | ||
graphdef, state | ||
|
@@ -287,6 +312,10 @@ def jitted_fn( | |
outer_idx_inner_idx = graph_utils.compose_mapping( | ||
outer_idx_inner_ref, inner_ref_inner_idx | ||
) | ||
|
||
if options.constrain_object_state is not None: | ||
state = options.constrain_object_state(state) | ||
|
||
output_static = JitStaticOutputs(graphdef, outer_idx_inner_idx) | ||
out = (out, state, output_static) | ||
return out | ||
|
@@ -343,10 +372,12 @@ def __init__( | |
backend: tp.Optional[str] = None, | ||
inline: bool = False, | ||
abstracted_axes: tp.Optional[tp.Any] = None, | ||
# nnx specific | ||
donate_object_state: bool = False, | ||
constrain_object_state: bool | tp.Callable[[State], State] = False, | ||
# submodule args | ||
module_init_args: tuple[tp.Any, ...], | ||
module_init_kwargs: dict[str, tp.Any], | ||
donate_object_state: bool = False, | ||
): | ||
self.options = JITOptions.from_jit_kwargs( | ||
in_shardings=in_shardings, | ||
|
@@ -361,6 +392,7 @@ def __init__( | |
inline=inline, | ||
abstracted_axes=abstracted_axes, | ||
donate_object_state=donate_object_state, | ||
constrain_object_state=constrain_object_state, | ||
) | ||
self.accessor: tp.Optional[DelayedAccessor] = None | ||
|
||
|
@@ -391,7 +423,7 @@ def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any: | |
|
||
|
||
def jit( | ||
f: F, | ||
fun: F, | ||
*, | ||
in_shardings: tp.Any = UNSPECIFIED, | ||
out_shardings: tp.Any = UNSPECIFIED, | ||
|
@@ -404,12 +436,137 @@ def jit( | |
backend: tp.Optional[str] = None, | ||
inline: bool = False, | ||
abstracted_axes: tp.Optional[tp.Any] = None, | ||
is_init: tp.Optional[bool] = None, | ||
# nnx specific | ||
donate_object_state: bool = False, | ||
constrain_object_state: bool | tp.Callable[[State], State] = False, | ||
) -> F: | ||
if is_init is None: | ||
is_init = f.__name__ == '__init__' | ||
""" | ||
Lifted version of ``jax.jit`` that can handle Modules / graph nodes as | ||
arguments. | ||
|
||
Args: | ||
fun: Function to be jitted. ``fun`` should be a pure function, as | ||
side-effects may only be executed once. | ||
|
||
The arguments and return value of ``fun`` should be arrays, | ||
scalars, or (nested) standard Python containers (tuple/list/dict) thereof. | ||
Positional arguments indicated by ``static_argnums`` can be anything at | ||
all, provided they are hashable and have an equality operation defined. | ||
Static arguments are included as part of a compilation cache key, which is | ||
why hash and equality operators must be defined. | ||
|
||
JAX keeps a weak reference to ``fun`` for use as a compilation cache key, | ||
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` | ||
objects will already satisfy this requirement. | ||
in_shardings: Pytree of structure matching that of arguments to ``fun``, | ||
with all actual arguments replaced by resource assignment specifications. | ||
It is also valid to specify a pytree prefix (e.g. one value in place of a | ||
whole subtree), in which case the leaves get broadcast to all values in | ||
that subtree. | ||
|
||
The ``in_shardings`` argument is optional. JAX will infer the shardings | ||
from the input :py:class:`jax.Array`'s and defaults to replicating the input | ||
if the sharding cannot be inferred. | ||
|
||
The valid resource assignment specifications are: | ||
- :py:class:`XLACompatibleSharding`, which will decide how the value | ||
will be partitioned. With this, using a mesh context manager is not | ||
required. | ||
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding | ||
it wants. | ||
For in_shardings, JAX will mark it as replicated but this behavior | ||
can change in the future. | ||
For out_shardings, we will rely on the XLA GSPMD partitioner to | ||
determine the output shardings. | ||
|
||
The size of every dimension has to be a multiple of the total number of | ||
resources assigned to it. This is similar to pjit's in_shardings. | ||
out_shardings: Like ``in_shardings``, but specifies resource | ||
assignment for function outputs. This is similar to pjit's | ||
out_shardings. | ||
|
||
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` | ||
will use GSPMD's sharding propagation to figure out what the sharding of the | ||
output(s) should be. | ||
static_argnums: An optional int or collection of ints that specify which | ||
positional arguments to treat as static (compile-time constant). | ||
Operations that only depend on static arguments will be constant-folded in | ||
Python (during tracing), and so the corresponding argument values can be | ||
any Python object. | ||
|
||
Static arguments should be hashable, meaning both ``__hash__`` and | ||
``__eq__`` are implemented, and immutable. Calling the jitted function | ||
with different values for these constants will trigger recompilation. | ||
Arguments that are not arrays or containers thereof must be marked as | ||
static. | ||
|
||
If neither ``static_argnums`` nor ``static_argnames`` is provided, no | ||
arguments are treated as static. If ``static_argnums`` is not provided but | ||
``static_argnames`` is, or vice versa, JAX uses | ||
:code:`inspect.signature(fun)` to find any positional arguments that | ||
correspond to ``static_argnames`` | ||
(or vice versa). If both ``static_argnums`` and ``static_argnames`` are | ||
provided, ``inspect.signature`` is not used, and only actual | ||
parameters listed in either ``static_argnums`` or ``static_argnames`` will | ||
be treated as static. | ||
static_argnames: An optional string or collection of strings specifying | ||
which named arguments to treat as static (compile-time constant). See the | ||
comment on ``static_argnums`` for details. If not | ||
provided but ``static_argnums`` is set, the default is based on calling | ||
``inspect.signature(fun)`` to find corresponding named arguments. | ||
donate_argnums: Specify which positional argument buffers are "donated" to | ||
the computation. It is safe to donate argument buffers if you no longer | ||
need them once the computation has finished. In some cases XLA can make | ||
use of donated buffers to reduce the amount of memory needed to perform a | ||
computation, for example recycling one of your input buffers to store a | ||
result. You should not reuse buffers that you donate to a computation, JAX | ||
will raise an error if you try to. By default, no argument buffers are | ||
donated. | ||
|
||
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no | ||
arguments are donated. If ``donate_argnums`` is not provided but | ||
``donate_argnames`` is, or vice versa, JAX uses | ||
:code:`inspect.signature(fun)` to find any positional arguments that | ||
correspond to ``donate_argnames`` | ||
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are | ||
provided, ``inspect.signature`` is not used, and only actual | ||
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will | ||
be donated. | ||
|
||
For more details on buffer donation see the | ||
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_. | ||
donate_argnames: An optional string or collection of strings specifying | ||
which named arguments are donated to the computation. See the | ||
comment on ``donate_argnums`` for details. If not | ||
provided but ``donate_argnums`` is set, the default is based on calling | ||
``inspect.signature(fun)`` to find corresponding named arguments. | ||
keep_unused: If `False` (the default), arguments that JAX determines to be | ||
unused by `fun` *may* be dropped from resulting compiled XLA executables. | ||
Such arguments will not be transferred to the device nor provided to the | ||
underlying executable. If `True`, unused arguments will not be pruned. | ||
device: This is an experimental feature and the API is likely to change. | ||
Optional, the Device the jitted function will run on. (Available devices | ||
can be retrieved via :py:func:`jax.devices`.) The default is inherited | ||
from XLA's DeviceAssignment logic and is usually to use | ||
``jax.devices()[0]``. | ||
backend: This is an experimental feature and the API is likely to change. | ||
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or | ||
``'tpu'``. | ||
inline: Specify whether this function should be inlined into enclosing | ||
jaxprs (rather than being represented as an application of the xla_call | ||
primitive with its own subjaxpr). Default False. | ||
donate_object_state: Optional, bool. If True, the object state of the | ||
graph node's state will be donated to the computation. Default False. | ||
constrain_object_state: Optional, bool or callable. If True, the object | ||
state of the graph node's state will be constrained to the partition | ||
specified by the graph node's partition spec as computed by | ||
:func:`nnx.spmd.get_partition_spec`. If a callable, the object State will be | ||
passed to the callable which must return the constrained object State. If | ||
False, the object state will not be constrained. Default False. | ||
|
||
Returns: | ||
A wrapped version of ``fun``, set up for just-in-time compilation. | ||
""" | ||
options = JITOptions.from_jit_kwargs( | ||
in_shardings=in_shardings, | ||
out_shardings=out_shardings, | ||
|
@@ -423,10 +580,11 @@ def jit( | |
inline=inline, | ||
abstracted_axes=abstracted_axes, | ||
donate_object_state=donate_object_state, | ||
constrain_object_state=constrain_object_state, | ||
) | ||
jitted_fn = get_jitted_fn(f, options) | ||
jitted_fn = get_jitted_fn(fun, options) | ||
|
||
@functools.wraps(f) | ||
@functools.wraps(fun) | ||
def jit_apply_wrapper(*args, **kwargs): | ||
return jit_apply(options, jitted_fn, args, kwargs) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is WSC (`jax.lax.with_sharding_constraint`) a no-op if there is no mesh?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, it will crash.