Commit

Adds axis_name and axis_index_groups to LayerNorm and GroupNorm.
PiperOrigin-RevId: 468499562
Flax Team committed Aug 18, 2022
1 parent c3e5363 commit 2b73efd
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions flax/linen/normalization.py
@@ -296,6 +296,15 @@ class LayerNorm(Module):
     scale_init: Initializer for scale, by default, one.
     reduction_axes: Axes for computing normalization statistics.
     feature_axes: Feature axes for learned bias and scaling.
+    axis_name: the axis name used to combine batch statistics from multiple
+      devices. See `jax.pmap` for a description of axis names (default: None).
+      This is only needed if the model is subdivided across devices, i.e. the
+      array being normalized is sharded across devices within a pmap.
+    axis_index_groups: groups of axis indices within that named axis
+      representing subsets of devices to reduce over (default: None). For
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+      the examples on the first two and last two devices. See `jax.lax.psum`
+      for more details.
   """
   epsilon: float = 1e-6
   dtype: Optional[Dtype] = None
@@ -306,6 +315,8 @@ class LayerNorm(Module):
   scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
   reduction_axes: Axes = -1
   feature_axes: Axes = -1
+  axis_name: Optional[str] = None
+  axis_index_groups: Any = None
 
   @compact
   def __call__(self, x):
@@ -317,8 +328,8 @@ def __call__(self, x):
     Returns:
       Normalized inputs (the same shape as inputs).
     """
-    # TODO(jheek) suport axis_name for model parallelism?
-    mean, var = _compute_stats(x, self.reduction_axes, self.dtype, None, None)
+    mean, var = _compute_stats(x, self.reduction_axes, self.dtype,
+                               self.axis_name, self.axis_index_groups)
 
     return _normalize(
         self, x, mean, var, self.reduction_axes, self.feature_axes,
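To make the new arguments concrete, here is a minimal usage sketch (not part of this commit; the module name `ShardedNorm`, the array shapes, and the axis name 'model' are made up for illustration). Inside a pmap, each device holds one slice of the feature axis, and passing axis_name tells LayerNorm to combine the per-shard statistics across devices:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ShardedNorm(nn.Module):
  @nn.compact
  def __call__(self, x):
    # 'model' must match the axis_name given to jax.pmap below, so the
    # per-device feature statistics are combined across all shards.
    return nn.LayerNorm(axis_name='model')(x)

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 2, 8))  # one (batch=2, features=8) shard per device
rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)

variables = jax.pmap(ShardedNorm().init, axis_name='model')(rngs, x)
y = jax.pmap(ShardedNorm().apply, axis_name='model')(variables, x)
```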
@@ -350,6 +361,15 @@ class GroupNorm(Module):
       be done by the next layer.
     bias_init: Initializer for bias, by default, zero.
     scale_init: Initializer for scale, by default, one.
+    axis_name: the axis name used to combine batch statistics from multiple
+      devices. See `jax.pmap` for a description of axis names (default: None).
+      This is only needed if the model is subdivided across devices, i.e. the
+      array being normalized is sharded across devices within a pmap.
+    axis_index_groups: groups of axis indices within that named axis
+      representing subsets of devices to reduce over (default: None). For
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+      the examples on the first two and last two devices. See `jax.lax.psum`
+      for more details.
   """
   num_groups: Optional[int] = 32
   group_size: Optional[int] = None
@@ -360,6 +380,8 @@
   use_scale: bool = True
   bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
   scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+  axis_name: Optional[str] = None
+  axis_index_groups: Any = None
 
   @compact
   def __call__(self, x):
@@ -405,9 +427,9 @@ def broadcast_stat(stat):
                               (x.shape[0], num_groups, group_size))
       return stat.reshape((x.shape[0], num_groups * group_size))
 
-    # TODO(jheek): suport axis_name for model parallelism?
     mean, var = _compute_stats(
-        x.reshape(group_shape), reduction_axes, self.dtype, None, None)
+        x.reshape(group_shape), reduction_axes, self.dtype, self.axis_name,
+        self.axis_index_groups)
     mean = broadcast_stat(mean)
     var = broadcast_stat(var)
 
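Similarly, a hypothetical sketch of axis_index_groups with GroupNorm (again not part of the commit; the name `Net`, the axis name 'devices', and the shapes are illustrative, and the fallback assumes singleton groups are accepted when the device count is odd). Devices are paired so statistics are combined only within each pair, mirroring the `[[0, 1], [2, 3]]` example from the docstring:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

n_dev = jax.local_device_count()
# Pair devices, e.g. [[0, 1], [2, 3]] on four devices; fall back to
# singleton groups when the device count is odd (including n_dev == 1).
if n_dev % 2 == 0:
  groups = [[i, i + 1] for i in range(0, n_dev, 2)]
else:
  groups = [[i] for i in range(n_dev)]

class Net(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Statistics are reduced only within each device group.
    return nn.GroupNorm(num_groups=2, axis_name='devices',
                        axis_index_groups=groups)(x)

x = jnp.ones((n_dev, 1, 4, 4, 8))  # one NHWC shard per device
rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)
variables = jax.pmap(Net().init, axis_name='devices')(rngs, x)
y = jax.pmap(Net().apply, axis_name='devices')(variables, x)
```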
