Merge pull request #3769 from chiamp:note
PiperOrigin-RevId: 617213238
Flax Authors committed Mar 19, 2024
2 parents b626099 + f14d8cc commit 8220154
Showing 9 changed files with 85 additions and 75 deletions.
56 changes: 27 additions & 29 deletions flax/cursor.py
@@ -229,8 +229,9 @@ def build(self) -> A:
"""Create and return a copy of the original object with accumulated changes.
This method is to be called after making changes to the Cursor object.
NOTE: The new object is built bottom-up; the changes will first be applied
to the leaf nodes, then to their parents, all the way up to the root.
.. note::
The new object is built bottom-up; the changes will first be applied
to the leaf nodes, then to their parents, all the way up to the root.
Example::
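The file's own example is cut off by this hunk; as a hedged stand-in, here is a minimal sketch of the bottom-up build described above, using a plain nested container with made-up values::

    from flax.cursor import cursor

    tree = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
    c = cursor(tree)
    c['b'][0] = 10        # changes are recorded on child cursors...
    c['a'] = 100
    new_tree = c.build()  # ...and materialized leaf-first, up to the root
    # new_tree == {'a': 100, 'b': (10, 3), 'c': [4, 5]}; `tree` itself is unchanged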
@@ -300,19 +301,18 @@ def apply_update(
- The output is the new value (either modified by the ``update_fn`` or same as the
input value if the condition wasn't fulfilled)
NOTES:
- If the ``update_fn`` returns a modified value, this method will not recurse any further
down that branch to record changes. For example, if we intend to replace an attribute that points
to a dictionary with an int, we don't need to look for further changes inside the dictionary,
since the dictionary will be replaced anyway.
- The ``is`` operator is used to determine whether the return value is modified (by comparing it
to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists,
dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as
unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value.
- ``.apply_update`` WILL NOT call the ``update_fn`` on the value at the top-most level of
the pytree (i.e. the root node). The ``update_fn`` will first be called on the root node's
children, and then the pytree traversal will continue recursively from there.
.. note::
- If the ``update_fn`` returns a modified value, this method will not recurse any further
down that branch to record changes. For example, if we intend to replace an attribute that points
to a dictionary with an int, we don't need to look for further changes inside the dictionary,
since the dictionary will be replaced anyway.
- The ``is`` operator is used to determine whether the return value is modified (by comparing it
to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists,
dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as
unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value.
- ``.apply_update`` WILL NOT call the ``update_fn`` on the value at the top-most level of
the pytree (i.e. the root node). The ``update_fn`` will first be called on the root node's
children, and then the pytree traversal will continue recursively from there.
Example::
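The original example is truncated here; the following is a hedged sketch of ``apply_update`` on a plain dict, following the key-path convention described above (the config keys and values are invented)::

    from flax.cursor import cursor

    def update_fn(path, value):
      # Double integer leaves stored under a key named 'b'; everything else
      # is returned unchanged (same object, so no change is recorded).
      if path.endswith('b') and isinstance(value, int):
        return value * 2
      return value

    config = {'a': {'b': 2, 'c': 3}, 'b': 4}
    new_config = cursor(config).apply_update(update_fn).build()
    # expected: {'a': {'b': 4, 'c': 3}, 'b': 8} -- the root dict itself is
    # never passed to update_fn, only its children.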
@@ -396,13 +396,12 @@ def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]':
error because the user should always expect this method to return the only object whose
corresponding key path and value fulfill the condition of the ``cond_fn``.
NOTES:
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child node
that fulfills the condition in ``cond_fn`` in a particular key path
- ``.find`` WILL NOT search the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
.. note::
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child node
that fulfills the condition in ``cond_fn`` in a particular key path
- ``.find`` WILL NOT search the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
Example::
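The docstring's own example is cut off by the hunk; a small hedged sketch of ``find`` on a nested dict, under the same key-path convention (the config is invented, and changes recorded through the returned child cursor are then materialized with ``build()`` on the root, mirroring the library's own examples)::

    from flax.cursor import cursor

    config = {'encoder': {'dropout': 0.1, 'dim': 256},
              'decoder': {'dropout': 0.1, 'dim': 256}}

    c = cursor(config)
    # Exactly one key path ends in 'decoder', so this returns that child cursor;
    # the search starts at the root's children, never at the root itself.
    decoder = c.find(lambda path, value: path.endswith('decoder'))
    decoder['dim'] = 512
    new_config = c.build()
    # {'encoder': {'dropout': 0.1, 'dim': 256}, 'decoder': {'dropout': 0.1, 'dim': 512}}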
@@ -485,13 +484,12 @@ def find_all(
by ``'/'``) and value at that current key path
- The output is a boolean, denoting whether to return the child Cursor object at this path
NOTES:
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child nodes
that fulfill the condition in ``cond_fn`` in a particular key path
- ``.find_all`` WILL NOT search the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
.. note::
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child nodes
that fulfill the condition in ``cond_fn`` in a particular key path
- ``.find_all`` WILL NOT search the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
Example::
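Again the original example is elided; a matching hedged sketch for ``find_all``, which returns every child cursor satisfying the condition (same invented config as in the ``find`` sketch)::

    from flax.cursor import cursor

    config = {'encoder': {'dropout': 0.1, 'dim': 256},
              'decoder': {'dropout': 0.1, 'dim': 256}}

    c = cursor(config)
    for block in c.find_all(
        lambda path, value: isinstance(value, dict) and 'dropout' in value):
      block['dropout'] = 0.2
    new_config = c.build()   # both 'encoder' and 'decoder' now have dropout == 0.2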
15 changes: 8 additions & 7 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -153,15 +153,16 @@ def dot_product_attention(
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
.. note::
``query``, ``key``, ``value`` needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of `[batch..., q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch..., kv_length,
num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch..., kv_length,
num_heads, v_depth_per_head]`.
query: queries for calculating attention with shape of ``[batch..., q_length,
num_heads, qk_depth_per_head]``.
key: keys for calculating attention with shape of ``[batch..., kv_length,
num_heads, qk_depth_per_head]``.
value: values to be used in attention with shape of ``[batch..., kv_length,
num_heads, v_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
3 changes: 2 additions & 1 deletion flax/linen/attention.py
@@ -152,7 +152,8 @@ def dot_product_attention(
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
.. note::
``query``, ``key``, ``value`` needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
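Below the truncated hunk, a hedged sketch of the unbatched-shape convention noted above (the sizes are arbitrary)::

    import jax.numpy as jnp
    import flax.linen as nn

    # No leading batch dimensions: [length, num_heads, depth_per_head].
    q = jnp.ones((4, 2, 8))   # [q_length, num_heads, qk_depth_per_head]
    k = jnp.ones((6, 2, 8))   # [kv_length, num_heads, qk_depth_per_head]
    v = jnp.ones((6, 2, 8))   # [kv_length, num_heads, v_depth_per_head]

    out = nn.dot_product_attention(q, k, v)
    # out.shape == (4, 2, 8): [q_length, num_heads, v_depth_per_head]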
9 changes: 5 additions & 4 deletions flax/linen/module.py
@@ -2653,10 +2653,11 @@ def perturb(
intermediate gradients of ``value`` by running ``jax.grad`` on the perturbation
argument.
Note: this is an experimental API and may be tweaked later for better
performance and usability.
At its current stage, it creates extra dummy variables that occupy extra
memory space. Use it only to debug gradients in training.
.. note::
This is an experimental API and may be tweaked later for better
performance and usability.
At its current stage, it creates extra dummy variables that occupy extra
memory space. Use it only to debug gradients in training.
Example::
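The docstring's example is cut off by the hunk; a rough sketch of the debugging workflow ``perturb`` enables (the module and shapes are invented, treat this as illustrative rather than the elided original)::

    import jax, jax.numpy as jnp
    import flax.linen as nn

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense_out', x)   # records a zero-valued dummy variable
        return nn.Dense(1)(x)

    x = jnp.ones((2, 4))
    model = Model()
    variables = model.init(jax.random.PRNGKey(0), x)

    def loss_fn(params, perturbations):
      y = model.apply({'params': params, 'perturbations': perturbations}, x)
      return jnp.mean(y ** 2)

    # Gradients w.r.t. the perturbations are the intermediate gradients of `x`.
    intm_grads = jax.grad(loss_fn, argnums=1)(
        variables['params'], variables['perturbations'])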
53 changes: 28 additions & 25 deletions flax/linen/normalization.py
@@ -310,12 +310,12 @@ def __call__(
):
"""Normalizes the input using batch statistics.
NOTE:
During initialization (when ``self.is_initializing()`` is ``True``) the running
average of the batch statistics will not be updated. Therefore, the inputs
fed during initialization don't need to match that of the actual input
distribution and the reduction axis (set with ``axis_name``) does not have
to exist.
.. note::
During initialization (when ``self.is_initializing()`` is ``True``) the running
average of the batch statistics will not be updated. Therefore, the inputs
fed during initialization don't need to follow the actual input
distribution and the reduction axis (set with ``axis_name``) does not have
to exist.
Args:
x: the input to be normalized.
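A small sketch of the initialization behaviour described in the note (layer sizes are arbitrary): the init input only fixes shapes, and the running averages are refreshed later, once ``batch_stats`` is marked mutable::

    import jax, jax.numpy as jnp
    import flax.linen as nn

    bn = nn.BatchNorm(use_running_average=False, momentum=0.9)
    # Init only creates the variables; the dummy input need not follow the
    # real data distribution.
    variables = bn.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

    x = jnp.arange(8.).reshape((2, 4))
    y, updates = bn.apply(variables, x, mutable=['batch_stats'])
    # updates['batch_stats'] now holds the refreshed running mean/var.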
@@ -389,9 +389,10 @@ class LayerNorm(Module):
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
NOTE: This normalization operation is identical to InstanceNorm and GroupNorm;
the difference is simply which axes are reduced and the shape of the feature axes
(i.e. the shape of the learnable scale and bias parameters).
.. note::
This normalization operation is identical to InstanceNorm and GroupNorm;
the difference is simply which axes are reduced and the shape of the feature
axes (i.e. the shape of the learnable scale and bias parameters).
Example usage::
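The file's own usage example is truncated in this hunk; a minimal hedged sketch (shapes are arbitrary)::

    import jax, jax.numpy as jnp
    import flax.linen as nn

    x = jnp.arange(12.).reshape((2, 6))           # (batch, features)
    layer = nn.LayerNorm()
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)                 # per-example mean ~0, std ~1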
@@ -602,8 +603,9 @@ class GroupNorm(Module):
The user should either specify the total number of channel groups or the
number of channels per group.
NOTE: LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
InstanceNorm is a special case of GroupNorm where ``group_size=1``.
.. note::
LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
InstanceNorm is a special case of GroupNorm where ``group_size=1``.
Example usage::
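The original usage example is elided; a hedged sketch of the ``num_groups=1`` special case mentioned in the note, assuming default initializers so the learnable scale/bias start at 1 and 0 and the outputs coincide::

    import numpy as np
    import jax, jax.numpy as jnp
    import flax.linen as nn

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 6))   # NHWC

    # One group reduces over all non-batch axes, matching a LayerNorm that
    # reduces over the same axes.
    gn = nn.GroupNorm(num_groups=1)
    ln = nn.LayerNorm(reduction_axes=(1, 2, 3))

    y_gn = gn.apply(gn.init(jax.random.PRNGKey(1), x), x)
    y_ln = ln.apply(ln.init(jax.random.PRNGKey(1), x), x)
    np.testing.assert_allclose(y_gn, y_ln, atol=1e-5)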
@@ -778,9 +780,10 @@ class InstanceNorm(Module):
within each channel within each example close to 0 and the activation standard
deviation close to 1.
NOTE: This normalization operation is identical to LayerNorm and GroupNorm; the
difference is simply which axes are reduced and the shape of the feature axes
(i.e. the shape of the learnable scale and bias parameters).
.. note::
This normalization operation is identical to LayerNorm and GroupNorm; the
difference is simply which axes are reduced and the shape of the feature axes
(i.e. the shape of the learnable scale and bias parameters).
Example usage::
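The original example is truncated here; a minimal hedged sketch on an NHWC image batch with arbitrary sizes::

    import jax, jax.numpy as jnp
    import flax.linen as nn

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))   # NHWC
    layer = nn.InstanceNorm()
    variables = layer.init(jax.random.PRNGKey(1), x)
    y = layer.apply(variables, x)   # normalized over H, W per example and channel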
@@ -903,17 +906,17 @@ class SpectralNorm(Module):
where each wrapped layer will have its params spectral normalized before
computing its ``__call__`` output.
Usage Note:
The initialized variables dict will contain, in addition to a 'params'
collection, a separate 'batch_stats' collection that will contain a
``u`` vector and ``sigma`` value, which are intermediate values used
when performing spectral normalization. During training, we pass in
``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
and ``sigma`` are updated with the most recently computed values using
power iteration. This will help the power iteration method approximate
the true singular value more accurately over time. During eval, we pass
in ``update_stats=False`` to ensure we get deterministic behavior from
the model. For example::
.. note::
The initialized variables dict will contain, in addition to a 'params'
collection, a separate 'batch_stats' collection that will contain a
``u`` vector and ``sigma`` value, which are intermediate values used
when performing spectral normalization. During training, we pass in
``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
and ``sigma`` are updated with the most recently computed values using
power iteration. This will help the power iteration method approximate
the true singular value more accurately over time. During eval, we pass
in ``update_stats=False`` to ensure we get deterministic behavior from
the model.
Example usage::
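The file's own example is cut off by the hunk; a hedged sketch of the train/eval pattern described in the note (the wrapped model is invented)::

    import jax, jax.numpy as jnp
    import flax.linen as nn

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x, *, train: bool):
        x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
        return nn.Dense(1)(x)

    x = jnp.ones((2, 3))
    model = Model()
    variables = model.init(jax.random.PRNGKey(0), x, train=False)

    # Training: let power iteration refresh `u` and `sigma`.
    y, updates = model.apply(variables, x, train=True, mutable=['batch_stats'])

    # Eval: keep the stored statistics fixed for deterministic outputs.
    y_eval = model.apply(variables, x, train=False)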
8 changes: 5 additions & 3 deletions flax/linen/pooling.py
@@ -23,9 +23,11 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
"""Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op.
NOTE: Be aware that pooling is not generally differentiable.
That means providing a reduce_fn that is differentiable does not imply that
pool is differentiable.
.. note::
Be aware that pooling is not generally differentiable.
That means providing a reduce_fn that is differentiable does not imply that
pool is differentiable.
Args:
inputs: input data with dimensions (batch, window dims..., features).
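For orientation, a small sketch of the public pooling wrappers built on this helper (sizes are arbitrary); note that even though the reductions used here are simple, pooling as an operation is not differentiable in general::

    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.arange(16.).reshape((1, 4, 4, 1))           # (batch, H, W, features)
    y_max = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    y_avg = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # y_max.shape == y_avg.shape == (1, 2, 2, 1)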
2 changes: 1 addition & 1 deletion flax/linen/recurrent.py
@@ -763,7 +763,7 @@ class ConvLSTMCell(RNNCellBase):
i_t, f_t, o_t are input, forget and output gate activations,
and g_t is a vector of cell updates.
Notes:
.. note::
Forget gate initialization:
Following jozefowicz2015empirical we add 1.0 to b_f
after initialization in order to reduce the scale of forgetting in
9 changes: 6 additions & 3 deletions flax/linen/stochastic.py
@@ -26,9 +26,12 @@
class Dropout(Module):
"""Create a dropout layer.
Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
variable initialization. Example usage::
.. note::
When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
variable initialization.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
5 changes: 3 additions & 2 deletions flax/struct.py
@@ -35,8 +35,9 @@ def field(pytree_node=True, **kwargs):
def dataclass(clz: _T, **kwargs) -> _T:
"""Create a class which can be passed to functional transformations.
NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
using PyType.
.. note::
Inherit from ``PyTreeNode`` instead to avoid type checking issues when
using PyType.
Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
immutable and can be mapped over using the ``jax.tree_util`` methods.
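A short hedged sketch of the ``PyTreeNode`` route recommended in the note (the class and field names are illustrative)::

    import jax
    import jax.numpy as jnp
    from flax import struct

    class Point(struct.PyTreeNode):
      x: jax.Array
      y: jax.Array
      name: str = struct.field(pytree_node=False)    # static, not traced or mapped

    p = Point(x=jnp.ones(3), y=jnp.zeros(3), name='origin')
    p2 = jax.tree_util.tree_map(lambda a: a * 2, p)  # maps over x and y only
    p3 = p.replace(y=jnp.ones(3))                    # functional update; p stays immutable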
