Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k" #20114

Open
dlwh opened this issue Mar 7, 2024 · 6 comments
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

dlwh (Contributor) commented Mar 7, 2024

Description

>>> import jax.numpy as jnp
>>> a = jnp.ones((2,3,4))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:43:05.128623: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> b = jnp.ones((4,3,2))
>>> jnp.einsum("ijk,kji->k", a, b)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
    return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [2, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<2x3x4xf32>, tensor<4x3x2xf32>) -> tensor<4xf32>
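
For reference, the same contraction can be written without a multi-axis dot_general by transposing b into the same ijk layout, multiplying elementwise, and summing over i and j (a minimal sketch; whether the Metal plugin lowers this pattern is untested):

>>> # equivalent to jnp.einsum("ijk,kji->k", a, b)
>>> (a * b.transpose(2, 1, 0)).sum(axis=(0, 1))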

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:45:49.148886: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
dlwh added the bug (Something isn't working) label on Mar 7, 2024
dlwh changed the title from Metal: failed to legalize operation 'mhlo.dot_general' to Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k" on Mar 7, 2024
dlwh (Contributor, Author) commented Mar 7, 2024

Also, a similar but slightly different case:

>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
    operands, contractions = contract_path(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)

steeve commented Mar 21, 2024

Same with jax-metal 0.0.6

ramithuh commented Mar 26, 2024

Encountered a similar issue, though I didn't narrow it down to the exact computation (a matrix multiplication).

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!

jax:       0.4.25
jax-metal: 0.0.6
jaxlib:    0.4.23
see current operation: %7513 = "mhlo.dot_general"(%7499, %7395) 
{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], 
lhs_contracting_dimensions = [2, 3], rhs_contracting_dimensions = [1, 3]>, 
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : 
(tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
[Screenshot 2024-03-25 at 22:06:56 attached]

danielpcox commented Nov 9, 2024

I just ran into what is probably the same bug on jax-metal 0.1.1, doing:

import jax.numpy as jnp
a = jnp.arange(12).reshape((3, 2, 2))
b = jnp.arange(12).reshape((3, 2, 2))
jnp.einsum("b...,b...->b", a, b)

which gives me

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 9343, in einsum
    return einsum(operands, contractions, precision,
  File "python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 2]>} : (tensor<3x2x2xsi32>, tensor<3x2x2xsi32>) -> tensor<3xsi32>

And uninstalling jax-metal (so it's running on the CPU) fixes the immediate problem:

Array([ 14, 126, 366], dtype=int32)
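
If uninstalling jax-metal is not convenient, pinning the computation to the CPU backend should also sidestep the failing lowering; a minimal sketch, assuming the CPU backend is still available alongside the Metal plugin:

import jax
import jax.numpy as jnp

# run this computation on the CPU backend instead of METAL
with jax.default_device(jax.devices("cpu")[0]):
    a = jnp.arange(12).reshape((3, 2, 2))
    b = jnp.arange(12).reshape((3, 2, 2))
    out = jnp.einsum("b...,b...->b", a, b)  # Array([ 14, 126, 366], dtype=int32)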

jtdimasaka commented

Similar case here: I hit it when trying this Stein DMM exercise using numpyro with device = "METAL".

Mine works when device = "cpu"; there is no need to uninstall jax-metal. I am using jax-metal 0.1.0 on an Apple M3 Max.
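
For the numpyro case, forcing the platform to CPU up front (rather than uninstalling jax-metal) presumably amounts to something like:

import numpyro

# pin numpyro/JAX to the CPU backend before any computation runs
numpyro.set_platform("cpu")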

jkyl commented Feb 26, 2025

This issue also affects the gradient computation for certain implementations of multi-query attention:

import jax
import jax.numpy as jnp

from einops import einsum, rearrange
from flax import nnx
from flax.nnx.nn.linear import default_kernel_init

default_dtype = (
    # jnp.bfloat16
    jnp.float32
)

class MultiQueryAttention(nnx.Module):

    def __init__(
        self, 
        in_dim: int,
        head_dim: int, 
        num_heads: int, 
        *,
        num_kv_heads: int | None = None,
        weight_init = default_kernel_init,
        dtype = default_dtype,
        rngs: nnx.Rngs,
    ):
        self.in_dim = in_dim
        self.head_dim = head_dim
        self.num_heads = num_heads
        num_kv_heads = num_kv_heads or num_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        self.scale = self.head_dim ** -0.5
        self.W_q = nnx.Param(weight_init(rngs(), (in_dim, num_heads * head_dim), dtype))
        self.W_k = nnx.Param(weight_init(rngs(), (in_dim, num_kv_heads * head_dim), dtype))
        self.W_v = nnx.Param(weight_init(rngs(), (in_dim, num_kv_heads * head_dim), dtype))
        self.W_o = nnx.Param(weight_init(rngs(), (num_heads * head_dim, in_dim), dtype))

    def __call__(self, x: jax.Array) -> jax.Array:
        q = rearrange(x @ self.W_q, "B T (K G H) -> B T K G H", K=self.num_kv_heads, G=self.group_size)    
        k = rearrange(x @ self.W_k, "B S (K H) -> B S K H", K=self.num_kv_heads)
        v = rearrange(x @ self.W_v, "B S (K H) -> B S K H", K=self.num_kv_heads)
        scores = einsum(q, k, "B T K G H, B S K H -> B T S K G")
        weights = nnx.softmax(scores * self.scale, axis=2)
        attn = einsum(weights, v, "B T S K G, B S K H -> B T K G H")
        o = rearrange(attn, "B T K G H -> B T (K G H)") @ self.W_o
        return o

def reproduce_mhlo_dot_general_error_on_metal():

    # Following https://jax-ml.github.io/scaling-book/transformers/#transformer-accounting
    B, T, D, K, G, H = (3, 512, 1024, 8, 2, 64)

    for device in [
        jax.devices("cpu")[0], 
        jax.devices("METAL")[0],
    ]:
        with jax.default_device(device):
            attn = MultiQueryAttention(
                in_dim=D, 
                head_dim=H, 
                num_heads=K*G, 
                num_kv_heads=K, 
                rngs=nnx.Rngs(0),
            )
            x = jnp.ones((B, T, D), default_dtype)

            @nnx.value_and_grad
            def loss_fn(model):
                return model(x).mean()
            
            loss, grads = loss_fn(attn)
            print(loss.item())


if __name__ == "__main__":
    reproduce_mhlo_dot_general_error_on_metal()

This succeeds for the CPU device but fails on Metal with:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /opt/homebrew/Cellar/[email protected]/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85:23: error: failed to legalize operation 'mhlo.dot_general'
                return func(*args, **kwds)
                      ^
/Users/<redacted>/venv/lib/python3.13/site-packages/einops/_backends.py:193:15: note: called from
        return self.np.einsum(pattern, *x)
              ^
/Users/<redacted>/venv/lib/python3.13/site-packages/einops/einops.py:916:11: note: called from
    return get_backend(tensors[0]).einsum(pattern, *tensors)
          ^
/Users/<redacted>/repro_jax_metal_einsum_error.py:44:15: note: called from
        attn = einsum(weights, v, "B T S K G, B S K H -> B T K G H")
              ^
/Users/<redacted>/repro_jax_metal_einsum_error.py:69:23: note: called from
                return model(x).mean()
                      ^
/Users/<redacted>/venv/lib/python3.13/site-packages/flax/nnx/transforms/autodiff.py:86:10: note: called from
    out = self.f(*args)
         ^
/Users/<redacted>/venv/lib/python3.13/site-packages/flax/nnx/transforms/autodiff.py:164:15: note: called from
      fn_out = gradded_fn(*pure_args)
              ^
/Users/<redacted>/venv/lib/python3.13/site-packages/flax/nnx/graph.py:1082:15: note: called from
        return f(*args, **kwargs)
              ^
/Users/<redacted>/repro_jax_metal_einsum_error.py:71:26: note: called from
            loss, grads = loss_fn(attn)
                         ^
/Users/<redacted>/repro_jax_metal_einsum_error.py:76:4: note: called from
    reproduce_mhlo_dot_general_error_on_metal()
   ^
/opt/homebrew/Cellar/[email protected]/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85:23: note: see current operation: %16 = "mhlo.dot_general"(%2, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], rhs_batching_dimensions = [0, 3], lhs_contracting_dimensions = [3, 4], rhs_contracting_dimensions = [1, 4]>} : (tensor<3x8x64x512x2xf32>, tensor<3x512x512x8x2xf32>) -> tensor<3x8x64x512xf32>
                return func(*args, **kwds)

A workaround is to give k a dummy G axis and broadcast it explicitly before einsum:

k = rearrange(x @ self.W_k, "B S (K G H) -> B S K G H", K=self.num_kv_heads, G=1)
k_broadcast = jnp.broadcast_to(k, q.shape)
...
scores = einsum(q, k_broadcast, "B T K G H, B S K G H -> B T S K G")
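
Presumably this works because broadcasting k up to q's shape gives the forward and backward dot_generals identically laid-out batch and contracting dimensions, which the Metal backend can legalize, at the cost of materializing the repeated KV heads.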
