updated NNX rnglib docstring #3980

Merged 1 commit on Jun 11, 2024
2 changes: 1 addition & 1 deletion docs/api_reference/flax.nnx/rnglib.rst
@@ -5,6 +5,6 @@ rnglib
 .. currentmodule:: flax.nnx
 
 .. autoclass:: Rngs
-  :members:
+  :members: __init__
 .. autoclass:: RngStream
   :members:
87 changes: 87 additions & 0 deletions flax/nnx/nnx/rnglib.py
@@ -105,12 +105,99 @@ def __call__(self) -> jax.Array:


class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]):
"""NNX rng container class. To instantiate the ``Rngs``, pass
in an integer, specifying the starting seed. ``Rngs`` can have
different "streams", allowing the user to generate different
rng keys. For example, to generate a key for the ``params``
and ``dropout`` stream::

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() != rng2.dropout()

Because we passed in ``params=1``, the starting seed for
``params`` is ``1``, whereas the starting seed for ``dropout``
defaults to the ``0`` we passed in, since we didn't specify
a seed for ``dropout``. If we don't specify a seed for ``params``
either, then both streams default to the ``0`` we passed in::

>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() == rng2.dropout()

The ``Rngs`` container class keeps a separate counter for
each stream. Every time a stream is called to generate a new rng
key, its counter increments by ``1``. To generate a new rng key,
the counter value for the current rng stream is folded into that
stream's starting seed. If we try to generate an rng key for
a stream we did not specify on instantiation, the ``default``
stream is used instead (i.e. the first positional argument passed
to ``Rngs`` during instantiation is the ``default`` starting seed)::

>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` was not specified, so it falls back to the
>>> # `default` stream (seed 100), counter 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # calling `rng1()` directly also uses the `default` stream,
>>> # whose counter is now 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)

Let's see an example of using ``Rngs`` in a :class:`Module` and
verifying the output by manually threading the ``Rngs``::

>>> class Model(nnx.Module):
... def __init__(self, rngs):
... # Linear uses the `params` stream twice for kernel and bias
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... # Dropout uses the `dropout` stream once
... self.dropout = nnx.Dropout(0.5, rngs=rngs)
... def __call__(self, x):
... return self.dropout(self.linear(x))

>>> def assert_same(x, rng_seed, **rng_kwargs):
... model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
... out = model(x)
...
... # manual forward propagation
... rngs = nnx.Rngs(rng_seed, **rng_kwargs)
... kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
... assert (model.linear.kernel.value == kernel).all()
... bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
... assert (model.linear.bias.value == bias).all()
... mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
... # inverted dropout scales the kept outputs by 1 / (1 - rate)
... manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
... assert (out == manual_out).all()

>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
"""
def __init__(
self,
default: RngValue | RngDict | None = None,
/,
**rngs: RngValue,
):
"""
Args:
default: the starting seed for the ``default`` stream. Any
key generated from a stream that isn't specified in the
``**rngs`` key-word arguments will default to using this
starting seed.
**rngs: optional key-word arguments to specify starting
seeds for different rng streams. The key-word is the
stream name and its value is the corresponding starting
seed for that stream.
"""
if default is not None:
if isinstance(default, tp.Mapping):
rngs = {**default, **rngs}
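
To make the counter semantics concrete, here is a minimal sketch (not part of this diff) that replays the fold-in behavior the new docstring documents; the names `k0` and `k1` are illustrative:

>>> from flax import nnx
>>> import jax
>>> rngs = nnx.Rngs(0)
>>> k0 = rngs.params()  # `params` falls back to `default` (seed 0), counter 0
>>> k1 = rngs.params()  # counter advanced to 1, so a different key
>>> assert k0 != k1
>>> assert k0 == jax.random.fold_in(jax.random.key(0), 0)
>>> assert k1 == jax.random.fold_in(jax.random.key(0), 1)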
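
The ``isinstance(default, tp.Mapping)`` branch at the end of the hunk suggests the positional argument can also be a mapping from stream names to seeds. A hedged sketch of that usage, inferred from the code rather than stated in the docstring:

>>> from flax import nnx
>>> import jax
>>> # each named stream gets its own seed; inferred usage, not documented above
>>> rngs = nnx.Rngs({'params': 0, 'dropout': 1})
>>> assert rngs.params() == jax.random.fold_in(jax.random.key(0), 0)
>>> assert rngs.dropout() == jax.random.fold_in(jax.random.key(1), 0)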
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -57,7 +57,7 @@ testing = [
"sentencepiece", # WMT/LM1B examples
"tensorflow_text>=2.11.0", # WMT/LM1B examples
"tensorflow_datasets",
"tensorflow",
"tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error
"torch",
"nbstripout",
"black[jupyter]==23.7.0",