Commit

added Einsum layer

chiamp committed Mar 4, 2024
1 parent c83ee86 commit 8e56442

Showing 4 changed files with 371 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/api_reference/flax.linen/layers.rst
@@ -26,6 +26,10 @@ Linear Modules
  :module: flax.linen
  :class: ConvLocal

.. flax_module::
  :module: flax.linen
  :class: Einsum

.. flax_module::
  :module: flax.linen
  :class: Embed
@@ -153,6 +157,7 @@ BatchApply
  Conv
  ConvTranspose
  ConvLocal
  Einsum
  Embed
  BatchNorm
  LayerNorm
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -99,6 +99,7 @@
  Conv as Conv,
  DenseGeneral as DenseGeneral,
  Dense as Dense,
  Einsum as Einsum,
  Embed as Embed,
)
from .module import (
126 changes: 125 additions & 1 deletion flax/linen/linear.py
@@ -31,9 +31,12 @@
from jax import eval_shape, lax
from jax.core import ShapedArray

import opt_einsum

from flax.core import meta
from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen import module
from flax.linen.module import Module, compact
from flax.typing import (
  Array,
@@ -282,6 +285,128 @@ def __call__(self, inputs: Array) -> Array:
    return y


class Einsum(Module):
  """An einsum transformation with learnable kernel and bias.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde')
    >>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}

  Attributes:
    shape: the shape of the kernel.
    einsum_str: a string denoting the einsum equation. The equation must
      have exactly two operands: the lhs is the input passed in, and the
      rhs is the learnable kernel. Exactly one of the constructor argument
      and the call argument ``einsum_str`` must be not None, while the
      other must be None.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation; see ``jax.lax.Precision``
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
  """

  shape: Shape
  einsum_str: Optional[str] = None
  use_bias: bool = True
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Initializer = default_kernel_init
  bias_init: Initializer = initializers.zeros_init()

  @compact
  def __call__(self, inputs: Array, einsum_str: Optional[str] = None) -> Array:
    """Applies the einsum transformation to the inputs.

    Args:
      inputs: the nd-array to be transformed.
      einsum_str: a string denoting the einsum equation. The equation must
        have exactly two operands: the lhs is the input passed in, and the
        rhs is the learnable kernel. Exactly one of the constructor argument
        and the call argument ``einsum_str`` must be not None, while the
        other must be None.

    Returns:
      The transformed input.
    """
    einsum_str = module.merge_param('einsum_str', self.einsum_str, einsum_str)

    einsum_str = einsum_str.replace(' ', '')
    if '->' not in einsum_str:
      raise ValueError(
        '`einsum_str` equation must be explicit and include "->".'
      )
    if einsum_str.count(',') != 1:
      raise ValueError(
        '`einsum_str` equation must have exactly two operands and '
        'therefore, exactly one comma character, instead of '
        f'{einsum_str.count(",")}'
      )
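    # Illustrative inputs for the checks above (example strings, not from
    # this diff): 'abc,cde->abde' passes both; 'abc,cde' (no explicit '->'
    # output) and 'ab,bc,cd->ad' (three operands, two commas) each raise a
    # ValueError.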

    kernel = self.param(
      'kernel',
      self.kernel_init,
      self.shape,
      self.param_dtype,
    )

    if self.use_bias:
      bias_shape, broadcasted_bias_shape = self._get_bias_shape(
        einsum_str, inputs, kernel
      )
      bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype)
    else:
      bias = None

    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)

    y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision)

    if bias is not None:
      y += jnp.reshape(bias, broadcasted_bias_shape)
    return y

  def _get_bias_shape(self, einsum_str: str, lhs: Array, rhs: Array):
    """Infers the bias shape and broadcasted bias shape.

    Given the ``einsum_str``, ``lhs`` and ``rhs`` arrays, this is needed for
    instantiating the bias parameter and for adding the bias to the output
    during forward inference.

    This function first replaces all ellipses with actual letter characters,
    then computes the bias shape by checking which axes of the rhs array
    remain in the result after einsumming. These axes are the
    embedding/feature dimensions; all other rhs axes are reduction axes.
    """
    # More details on the parsing function:
    # https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/parser.py#L246
    # Returns the einsum string representation of the operands and result,
    # with ellipses replaced by actual letter characters.
    operands_str, result_str, _ = opt_einsum.parser.parse_einsum_input(
      (einsum_str, lhs, rhs)
    )

    # rhs_dict is a {character: index} dict that maps every character in the
    # rhs einsum string representation to its index position in that string.
    rhs_dict = {c: i for i, c in enumerate(operands_str.split(',')[1])}
    assert len(rhs_dict) == len(self.shape)

    broadcasted_bias_shape = [1] * len(result_str)
    bias_shape = []
    for i, c in enumerate(result_str):
      if c in rhs_dict:
        broadcasted_bias_shape[i] = self.shape[rhs_dict[c]]
        bias_shape.append(self.shape[rhs_dict[c]])

    return bias_shape, broadcasted_bias_shape
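
For reference, a minimal sketch of the call-time ``einsum_str`` path described
in the docstring above (shapes and equation chosen for illustration; assumes
the layer behaves as defined in this diff)::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp
    >>> # leave the constructor's einsum_str as None, so it must be passed per call
    >>> layer = nn.Einsum((5, 6, 7))
    >>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5)), 'abc,cde->abde')
    >>> y = layer.apply(variables, jnp.ones((3, 4, 5)), 'abc,cde->abde')
    >>> y.shape
    (3, 4, 6, 7)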


def _conv_dimension_numbers(input_shape):
"""Computes the dimension numbers based on the input shape."""
ndim = len(input_shape)
@@ -291,7 +416,6 @@ def _conv_dimension_numbers(input_shape):
  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)



def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
""" "Canonicalizes conv padding to a jax.lax supported format."""
if isinstance(padding, str):