Commit

Copybara import of the project:
--
525c3d8 by Cristian Garcia <[email protected]>:

add promote_dtype as a config option for multiple layers

PiperOrigin-RevId: 735924878
Cristian Garcia authored and Flax Authors committed Mar 11, 2025
1 parent 187b910 commit b252bb2
Showing 6 changed files with 159 additions and 39 deletions.
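
A minimal sketch of how the new option can be used with a Linen layer such as ``nn.Dense``, assuming any callable matching the ``PromoteDtypeFn`` protocol added to ``flax/linen/linear.py`` below; the helper ``cast_to_bfloat16`` is hypothetical and not part of this change.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical promotion function matching the PromoteDtypeFn protocol:
# it receives (inputs, kernel, bias) plus a `dtype` keyword and returns the
# arrays cast as desired, passing None entries through unchanged.
def cast_to_bfloat16(*args, dtype=None, inexact=True):
  return [None if a is None else a.astype(jnp.bfloat16) for a in args]

layer = nn.Dense(features=4, promote_dtype=cast_to_bfloat16)
x = jnp.ones((2, 3), dtype=jnp.float32)
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)  # inputs, kernel and bias are promoted to bfloat16
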
3 changes: 2 additions & 1 deletion flax/linen/dtypes.py
@@ -13,10 +13,11 @@
# limitations under the License.
"""APIs for handling dtypes in Linen Modules."""

from typing import Any
from typing import Any, TypeVar
from flax.typing import Dtype
from jax import numpy as jnp

T = TypeVar('T', bound=tuple)

def canonicalize_dtype(
*args, dtype: Dtype | None = None, inexact: bool = True
119 changes: 90 additions & 29 deletions flax/linen/linear.py
@@ -14,37 +14,37 @@

"""Linear modules."""

from typing import (
Any,
)
from collections.abc import Iterable, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from jax import eval_shape, lax
from jax.core import ShapedArray

import opt_einsum
from typing import Any, Protocol

from flax.core import meta
from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen import module
from flax.linen.dtypes import promote_dtype
from flax.linen.module import Module, compact
from flax.typing import (
Array,
PRNGKey as PRNGKey,
Dtype,
Shape as Shape,
Initializer,
PrecisionLike,
DotGeneralT,
ConvGeneralDilatedT,
PaddingLike,
LaxPadding,
Array,
ConvGeneralDilatedT,
DotGeneralT,
Dtype,
Initializer,
LaxPadding,
PRNGKey as PRNGKey,
PaddingLike,
PrecisionLike,
Shape as Shape,
)
import jax
from jax import eval_shape, lax
from jax.core import ShapedArray
import jax.numpy as jnp
import numpy as np
import opt_einsum

class PromoteDtypeFn(Protocol):
def __call__(
self, *args: jax.Array | None, dtype: Any = None, inexact: bool = True
) -> list[jax.Array | None]: ...

default_kernel_init = initializers.lecun_normal()

@@ -94,6 +94,10 @@ class DenseGeneral(Module):
bias_init: initializer function for the bias.
precision: numerical precision of the computation see ``jax.lax.Precision``
for details.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

features: int | Sequence[int]
@@ -105,6 +109,7 @@ class DenseGeneral(Module):
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = initializers.zeros_init()
precision: PrecisionLike = None
promote_dtype: PromoteDtypeFn = promote_dtype
# Deprecated. Will be removed.
dot_general: DotGeneralT | None = None
dot_general_cls: Any = None
@@ -181,7 +186,9 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
else:
bias = None

inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)

if self.dot_general_cls is not None:
dot_general = self.dot_general_cls()
@@ -225,6 +232,10 @@ class Dense(Module):
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

features: int
@@ -234,6 +245,7 @@ class Dense(Module):
precision: PrecisionLike = None
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = initializers.zeros_init()
promote_dtype: PromoteDtypeFn = promote_dtype
# Deprecated. Will be removed.
dot_general: DotGeneralT | None = None
dot_general_cls: Any = None
@@ -260,7 +272,11 @@ def __call__(self, inputs: Array) -> Array:
)
else:
bias = None
inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)
assert inputs is not None
assert kernel is not None

if self.dot_general_cls is not None:
dot_general = self.dot_general_cls()
@@ -306,6 +322,10 @@ class Einsum(Module):
for details.
kernel_init: initializer function for the weight matrix.
bias_init: initializer function for the bias.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

shape: Shape
@@ -316,6 +336,7 @@ class Einsum(Module):
precision: PrecisionLike = None
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = initializers.zeros_init()
promote_dtype: PromoteDtypeFn = promote_dtype

@compact
def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array:
@@ -361,7 +382,9 @@ def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array:
else:
bias = None

inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)

y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision)

@@ -469,6 +492,10 @@ class _Conv(Module):
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

features: int
@@ -485,6 +512,7 @@ class _Conv(Module):
precision: PrecisionLike = None
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = initializers.zeros_init()
promote_dtype: PromoteDtypeFn = promote_dtype
# Deprecated. Will be removed.
conv_general_dilated: ConvGeneralDilatedT | None = None
conv_general_dilated_cls: Any = None
@@ -647,7 +675,12 @@ def maybe_broadcast(
else:
bias = None

inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)
assert inputs is not None
assert kernel is not None

if self.shared_weights:
if self.conv_general_dilated_cls is not None:
conv_general_dilated = self.conv_general_dilated_cls()
@@ -749,6 +782,10 @@ class Conv(_Conv):
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

@property
@@ -816,6 +853,10 @@ class ConvLocal(_Conv):
for details.
kernel_init: initializer for the convolutional kernel.
bias_init: initializer for the bias.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

@property
@@ -879,6 +920,10 @@ class ConvTranspose(Module):
bias_init: initializer for the bias.
transpose_kernel: if ``True`` flips spatial axes and swaps the input/output
channel axes of the kernel.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
and a ``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
"""

features: int
@@ -894,6 +939,7 @@ class ConvTranspose(Module):
kernel_init: Initializer = default_kernel_init
bias_init: Initializer = initializers.zeros_init()
transpose_kernel: bool = False
promote_dtype: PromoteDtypeFn = promote_dtype

@compact
def __call__(self, inputs: Array) -> Array:
@@ -976,7 +1022,11 @@ def maybe_broadcast(
else:
bias = None

inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
inputs, kernel, bias = self.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)
assert inputs is not None
assert kernel is not None

y = lax.conv_transpose(
inputs,
@@ -1089,13 +1139,18 @@ class Embed(Module):
dtype: the dtype of the embedding vectors (default: same as embedding).
param_dtype: the dtype passed to parameter initializers (default: float32).
embedding_init: embedding initializer.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(embedding,)`` during ``__call__``
or ``(query, embedding)`` during ``attend``, and a ``dtype`` keyword argument,
and return a tuple of arrays with the promoted dtype.
"""

num_embeddings: int
features: int
dtype: Dtype | None = None
param_dtype: Dtype = jnp.float32
embedding_init: Initializer = default_embed_init
promote_dtype: PromoteDtypeFn = promote_dtype

def setup(self):
self.embedding = self.param(
@@ -1120,9 +1175,10 @@ def __call__(self, inputs: Array) -> Array:
raise ValueError('Input type must be an integer or unsigned integer.')
# Use take because fancy indexing numpy arrays with JAX indices does not
# work correctly.
(embedding,) = promote_dtype(
(embedding,) = self.promote_dtype(
self.embedding, dtype=self.dtype, inexact=False
)
assert embedding is not None
if self.num_embeddings == 1:
return jnp.broadcast_to(embedding, inputs.shape + (self.features,))
return jnp.take(embedding, inputs, axis=0)
@@ -1140,5 +1196,10 @@ def attend(self, query: Array) -> Array:
Commonly used for weight-sharing between embeddings and logit transform
in NLP models.
"""
query, embedding = promote_dtype(query, self.embedding, dtype=self.dtype)
embedding: Array
query, embedding = self.promote_dtype(
query, self.embedding, dtype=self.dtype
)
assert query is not None
assert embedding is not None
return jnp.dot(query, embedding.T)
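
A minimal sketch of the ``Embed`` behavior documented above: the function receives only ``(embedding,)`` during ``__call__`` (with ``inexact=False``) and ``(query, embedding)`` during ``attend``. The helper ``cast_if_requested`` is hypothetical and not part of this change.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical promotion function: cast only when a dtype is requested,
# otherwise return the arrays unchanged.
def cast_if_requested(*args, dtype=None, inexact=True):
  if dtype is None:
    return list(args)
  return [None if a is None else a.astype(dtype) for a in args]

embed = nn.Embed(num_embeddings=16, features=4, dtype=jnp.bfloat16,
                 promote_dtype=cast_if_requested)
tokens = jnp.array([[1, 2, 3]])
variables = embed.init(jax.random.PRNGKey(0), tokens)
vectors = embed.apply(variables, tokens)                       # promotes (embedding,)
logits = embed.apply(variables, vectors, method=embed.attend)  # promotes (query, embedding)
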
12 changes: 11 additions & 1 deletion flax/nnx/nn/attention.py
@@ -28,14 +28,15 @@
from flax.nnx import rnglib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import initializers
from flax.nnx.nn.dtypes import promote_dtype
from flax.nnx.nn import dtypes
from flax.nnx.nn.linear import (
LinearGeneral,
default_kernel_init,
)
from flax.nnx.nn.normalization import LayerNorm
from flax.typing import (
Dtype,
PromoteDtypeFn,
Shape,
Initializer,
PrecisionLike,
@@ -57,6 +58,7 @@ def dot_product_attention_weights(
dtype: Dtype | None = None,
precision: PrecisionLike = None,
module: Module | None = None,
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
):
"""Computes dot-product attention weights given query and key.
@@ -86,6 +88,9 @@
module: the Module that will sow the attention weights into the
``nnx.Intermediate`` collection. If ``module`` is None, the attention
weights will not be sowed.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype``
keyword argument, and return a tuple of arrays with the promoted dtype.
Returns:
Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -148,6 +153,7 @@ def dot_product_attention(
dtype: Dtype | None = None,
precision: PrecisionLike = None,
module: Module | None = None,
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
):
"""Computes dot-product attention given query, key, and value.
@@ -182,6 +188,10 @@
module: the Module that will sow the attention weights into the
``nnx.Intermediate`` collection. If ``module`` is None, the attention
weights will not be sowed.
promote_dtype: function to promote the dtype of the arrays to the desired
dtype. The function should accept a tuple of ``(query, key, value)`` and a
``dtype`` keyword argument, and return a tuple of arrays with the promoted
dtype.
Returns:
Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
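
A minimal sketch for the NNX attention functions, assuming ``promote_dtype`` is accepted as a keyword as shown in the diff above. The hypothetical wrapper ``promote_in_float32`` delegates to the library default in ``flax.nnx.nn.dtypes`` (so it stays agnostic to the exact calling convention) while forcing the promotion dtype to float32.

import jax.numpy as jnp
from flax.nnx.nn import dtypes
from flax.nnx.nn.attention import dot_product_attention

# Hypothetical override: forward everything to the default promote_dtype,
# but always request float32 regardless of the dtype the caller asked for.
def promote_in_float32(*args, **kwargs):
  kwargs['dtype'] = jnp.float32
  return dtypes.promote_dtype(*args, **kwargs)

# query/key/value shaped [batch, length, num_heads, head_dim]
q = k = v = jnp.ones((1, 4, 2, 8), dtype=jnp.bfloat16)
out = dot_product_attention(q, k, v, promote_dtype=promote_in_float32)
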
(Diffs for the remaining 3 changed files are not shown.)
