Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad.
Passing a truthy value for this argument will cause JAX to raise an error. It looks like this has been the case for a little more than a year -- see jax-ml/jax#19970.

PiperOrigin-RevId: 735883853
jburnim authored and Flax Authors committed Mar 11, 2025
1 parent 7c36812 commit cc8bb66
Showing 4 changed files with 35 additions and 68 deletions.
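
A minimal sketch (hypothetical model and loss names, using the NNX API) of the user-facing effect: the public wrappers keep the ``reduce_axes`` keyword, but any truthy value now raises ``NotImplementedError`` up front via the guards added in this commit (shown in the diffs below), instead of failing deep inside JAX.

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))   # hypothetical model

def loss_fn(model, x):                       # hypothetical loss
  return jnp.mean(model(x) ** 2)

# Fine: reduce_axes defaults to an empty (falsy) tuple.
grads = nnx.grad(loss_fn)(model, jnp.ones((1, 2)))

# No longer supported: a truthy reduce_axes now raises NotImplementedError.
# nnx.grad(loss_fn, reduce_axes=('batch',))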
24 changes: 8 additions & 16 deletions flax/core/lift.py
@@ -473,13 +473,6 @@ def f(scope, x, y):
has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fn`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
VJP will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a VJP function that sums over the batch while ``vjp(f, *args)``
will create a per-example VJP.
vjp_variables: The vjpfun will return a cotangent vector for all
variable collections specified by this filter.
variables: other variables collections that are available inside `fn` but
@@ -496,6 +489,9 @@ def f(scope, x, y):
``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data
returned by ``fn``.
"""
if reduce_axes:
raise NotImplementedError('reduce_axes argument to vjp is deprecated')
del reduce_axes

def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
vjp_vars, other_vars = variable_groups
@@ -512,7 +508,7 @@ def wrapper(vjp_vars, *args):
return y, (aux, repack_fn(scope))

y, bwd, (aux, out_vars) = jax.vjp(
wrapper, vjp_vars, *args, reduce_axes=reduce_axes, has_aux=True
wrapper, vjp_vars, *args, has_aux=True
)
treedef = jax.tree_util.tree_structure(scope)
bwd = jax.tree_util.Partial(functools.partial(_bwd_wrapper, treedef), bwd)
@@ -570,13 +566,6 @@ def f(scope, x, y):
has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fn`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
VJP will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a VJP function that sums over the batch while ``vjp(f, *args)``
will create a per-example VJP.
variables: other variables collections that are available inside `fn` but
do not receive a cotangent.
rngs: the prngs that are available inside `fn`.
@@ -588,6 +577,10 @@ def f(scope, x, y):
``(primals_out, aux, grads)`` tuple where ``aux`` is the auxiliary data
returned by ``fn``.
"""
if reduce_axes:
raise NotImplementedError(
'reduce_axes argument to value_and_grad is deprecated')
del reduce_axes

def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
@functools.wraps(fn)
@@ -604,7 +597,6 @@ def wrapper(*args):
wrapper,
*args,
has_aux=True,
reduce_axes=reduce_axes,
)

inputs_grad = bwd(jax.numpy.ones_like(y))
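
For context, a standalone sketch in plain JAX (not Flax code, illustrative names) of the ``has_aux=True`` form of ``jax.vjp`` that the lifted ``vjp`` above now calls without ``reduce_axes``:

import jax
import jax.numpy as jnp

def f(x):
  y = jnp.sin(x)
  return y, {'hidden': y}              # (output, auxiliary data)

# With has_aux=True, jax.vjp returns (primals_out, vjp_fn, aux).
y, f_vjp, aux = jax.vjp(f, jnp.ones(3), has_aux=True)
(x_bar,) = f_vjp(jnp.ones(3))          # cotangents w.r.t. the primals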
35 changes: 11 additions & 24 deletions flax/linen/transforms.py
@@ -1465,13 +1465,6 @@ def vjp(
has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fn`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
VJP will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a VJP function that sums over the batch while ``vjp(f, *args)``
will create a per-example VJP.
vjp_variables: The vjpfun will return a cotangent vector for all
variable collections specified by this filter.
variables: other variables collections that are available inside ``fn`` but
@@ -1489,14 +1482,17 @@ def vjp(
``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data
returned by ``fn``.
"""
if reduce_axes:
raise NotImplementedError('reduce_axes argument to vjp is deprecated')
del reduce_axes

return lift_direct_transform(
lift.vjp,
(fn,),
mdl,
*primals,
multi_scope=multi_scope,
has_aux=has_aux,
reduce_axes=reduce_axes,
vjp_variables=vjp_variables,
variables=variables,
rngs=rngs,
@@ -1547,13 +1543,6 @@ def __call__(self, x, y):
has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fn`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
grad will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a grad function that sums over the batch while ``grad(f, *args)``
will create a per-example grad.
variables: variables collections that are available inside ``fn`` but
do not receive a cotangent.
rngs: the prngs that are available inside ``fn``.
@@ -1565,6 +1554,10 @@ def __call__(self, x, y):
``(primals_out, aux), grads`` tuple where ``aux`` is the auxiliary data
returned by ``fn``.
"""
if reduce_axes:
raise NotImplementedError(
'reduce_axes argument to value_and_grad is deprecated')
del reduce_axes

grad_partial = functools.partial(
lift_direct_transform,
@@ -1573,7 +1566,6 @@ def __call__(self, x, y):
mdl,
*primals,
has_aux=has_aux,
reduce_axes=reduce_axes,
variables=variables,
rngs=rngs,
)
@@ -1640,13 +1632,6 @@ def __call__(self, x, y):
has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fn`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
grad will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a grad function that sums over the batch while ``grad(f, *args)``
will create a per-example grad.
variables: variables collections that are available inside ``fn`` but
do not receive a cotangent.
rngs: the prngs that are available inside ``fn``.
@@ -1658,14 +1643,16 @@ def __call__(self, x, y):
``(grads, aux)`` tuple where ``aux`` is the auxiliary data
returned by ``fn``.
"""
if reduce_axes:
raise NotImplementedError('reduce_axes argument to grad is deprecated')
del reduce_axes

value_and_grad_partial = functools.partial(
value_and_grad,
fn,
mdl,
*primals,
has_aux=has_aux,
reduce_axes=reduce_axes,
variables=variables,
rngs=rngs,
)
21 changes: 8 additions & 13 deletions flax/nnx/transforms/autodiff.py
@@ -106,7 +106,6 @@ def _grad_general(
has_aux: bool,
holomorphic: bool,
allow_int: bool,
reduce_axes: tp.Sequence[AxisName],
return_value: bool,
) -> tp.Callable[..., tp.Any]:
transform = jax.value_and_grad if return_value else jax.grad
@@ -159,7 +158,6 @@ def _grad_split_fn(
has_aux=True,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)

fn_out = gradded_fn(*pure_args)
@@ -304,14 +302,10 @@ def grad(
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
"""
if reduce_axes:
raise NotImplementedError('reduce_axes argument to grad is deprecated')
del reduce_axes

if isinstance(f, Missing):
return functools.partial(
@@ -320,15 +314,13 @@
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
return_value=False,
)

@@ -364,22 +356,25 @@ def value_and_grad(
tp.Callable[..., tp.Any]
| tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
if reduce_axes:
raise NotImplementedError(
'reduce_axes argument to value_and_grad is deprecated')
del reduce_axes

if f is Missing:
return functools.partial(
value_and_grad,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
return_value=True,
)

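
For context, a standalone sketch in plain JAX (illustrative names) of the ``jax.value_and_grad`` call that ``_grad_general`` now builds: ``has_aux``, ``holomorphic``, and ``allow_int`` are still forwarded, while ``reduce_axes`` no longer is.

import jax
import jax.numpy as jnp

def loss_with_aux(params, x):          # hypothetical function
  y = x @ params
  return jnp.sum(y ** 2), {'pred': y}  # (loss, auxiliary data)

# With has_aux=True, jax.value_and_grad returns ((value, aux), grads).
(value, aux), grads = jax.value_and_grad(
    loss_with_aux, argnums=0, has_aux=True, holomorphic=False, allow_int=False
)(jnp.ones(3), jnp.ones((4, 3)))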
23 changes: 8 additions & 15 deletions flax/nnx/transforms/deprecated.py
@@ -1621,7 +1621,6 @@ def _grad_general(
has_aux: bool,
holomorphic: bool,
allow_int: bool,
reduce_axes: tp.Sequence[AxisName],
wrt: filterlib.Filter,
return_value: bool,
) -> tp.Callable[..., tp.Any]:
@@ -1662,7 +1661,6 @@ def only_diff(path: tuple, value: tp.Any) -> bool:
has_aux=True,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*args, f, graphdef, non_diff_state, has_aux, diff_args)

if return_value:
@@ -1748,25 +1746,20 @@ def grad(
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
wrt: Optional, filterlib.Filter. Filter to extract the differentiable state
of each graph node. Default is `nnx.Param`.
"""
if reduce_axes:
raise NotImplementedError('reduce_axes argument to grad is deprecated')
del reduce_axes

return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
wrt,
return_value=False,
)
@@ -1782,13 +1775,17 @@ def value_and_grad(
*,
wrt: filterlib.Filter = variablelib.Param,
) -> tp.Callable[..., tp.Any]:
if reduce_axes:
raise NotImplementedError(
'reduce_axes argument to value_and_grad is deprecated')
del reduce_axes

return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
wrt,
return_value=True,
)
@@ -1801,7 +1798,6 @@ def constructor(
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
return_value: bool = False,
*,
wrt: filterlib.Filter = variablelib.Param,
@@ -1813,7 +1809,6 @@ def _create_grad(*args, **kwargs):
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
return_value=return_value,
# submodule args
module_init_args=args,
@@ -1829,7 +1824,6 @@ def __init__(
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
*,
wrt: filterlib.Filter = variablelib.Param,
# submodule args
@@ -1847,7 +1841,6 @@ def __init__(
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
wrt=wrt,
)
def grad_call_apply(module, *args):
