Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k" #20114
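A minimal sketch of the failing contraction (shapes here are illustrative, not taken from the report):

import jax.numpy as jnp

a = jnp.ones((2, 3, 4))  # axes (i, j, k)
b = jnp.ones((4, 3, 2))  # axes (k, j, i)
# Contracts i and j, keeps k as a batch axis; on jax-metal this is the
# einsum reported to fail legalizing the underlying mhlo.dot_general.
out = jnp.einsum("ijk,kji->k", a, b)
print(out.shape)  # (4,)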
Comments
Also the similar but slightly different case:
Same with jax-metal 0.0.6.
I just ran into probably the same bug myself on jax-metal 0.1.1, doing
which gives me
And uninstalling jax-metal (so it's running on the CPU) fixes the immediate problem:
Similar case when I tried this Stein DMM exercise using numpyro. Mine works when
This issue also affects the gradient computation for certain implementations of multi-query attention:

import jax
import jax.numpy as jnp
from einops import einsum, rearrange
from flax import nnx
from flax.nnx.nn.linear import default_kernel_init
default_dtype = (
# jnp.bfloat16
jnp.float32
)
class MultiQueryAttention(nnx.Module):
    def __init__(
        self,
        in_dim: int,
        head_dim: int,
        num_heads: int,
        *,
        num_kv_heads: int | None = None,
        weight_init=default_kernel_init,
        dtype=default_dtype,
        rngs: nnx.Rngs,
    ):
        self.in_dim = in_dim
        self.head_dim = head_dim
        self.num_heads = num_heads
        num_kv_heads = num_kv_heads or num_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        self.scale = self.head_dim ** -0.5
        self.W_q = nnx.Param(weight_init(rngs(), (in_dim, num_heads * head_dim), dtype))
        self.W_k = nnx.Param(weight_init(rngs(), (in_dim, num_kv_heads * head_dim), dtype))
        self.W_v = nnx.Param(weight_init(rngs(), (in_dim, num_kv_heads * head_dim), dtype))
        self.W_o = nnx.Param(weight_init(rngs(), (num_heads * head_dim, in_dim), dtype))

    def __call__(self, x: jax.Array) -> jax.Array:
        q = rearrange(x @ self.W_q, "B T (K G H) -> B T K G H", K=self.num_kv_heads, G=self.group_size)
        k = rearrange(x @ self.W_k, "B S (K H) -> B S K H", K=self.num_kv_heads)
        v = rearrange(x @ self.W_v, "B S (K H) -> B S K H", K=self.num_kv_heads)
        scores = einsum(q, k, "B T K G H, B S K H -> B T S K G")
        weights = nnx.softmax(scores * self.scale, axis=2)
        attn = einsum(weights, v, "B T S K G, B S K H -> B T K G H")
        o = rearrange(attn, "B T K G H -> B T (K G H)") @ self.W_o
        return o
def reproduce_mhlo_dot_general_error_on_metal():
    # Following https://jax-ml.github.io/scaling-book/transformers/#transformer-accounting
    B, T, D, K, G, H = (3, 512, 1024, 8, 2, 64)
    for device in [
        jax.devices("cpu")[0],
        jax.devices("METAL")[0],
    ]:
        with jax.default_device(device):
            attn = MultiQueryAttention(
                in_dim=D,
                head_dim=H,
                num_heads=K * G,
                num_kv_heads=K,
                rngs=nnx.Rngs(0),
            )
            x = jnp.ones((B, T, D), default_dtype)

            @nnx.value_and_grad
            def loss_fn(model):
                return model(x).mean()

            loss, grads = loss_fn(attn)
            print(loss.item())


if __name__ == "__main__":
    reproduce_mhlo_dot_general_error_on_metal()

This succeeds for the CPU device but fails on Metal with:
A workaround is to give k an explicit (size-1) group axis and broadcast it to match q:

k = rearrange(x @ self.W_k, "B S (K G H) -> B S K G H", K=self.num_kv_heads, G=1)
k_broadcast = jnp.broadcast_to(k, q.shape)
...
scores = einsum(q, k_broadcast, "B T K G H, B S K G H -> B T S K G")
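For what it's worth, the same idea can be sketched outside the module with plain jnp.einsum (shapes made up; that this broadcast form also avoids the Metal lowering problem in other settings is an assumption, not something verified here):

import jax.numpy as jnp

B, T, S, K, G, H = 2, 4, 4, 2, 2, 8
q = jnp.ones((B, T, K, G, H))
k = jnp.ones((B, S, K, H))

# Grouped form from the report ("B T K G H, B S K H -> B T S K G"):
# scores = jnp.einsum("btkgh,bskh->btskg", q, k)

# Workaround form: materialize the group axis on k so both operands carry it.
k_broadcast = jnp.broadcast_to(k[:, :, :, None, :], (B, S, K, G, H))
scores = jnp.einsum("btkgh,bskgh->btskg", q, k_broadcast)
print(scores.shape)  # (B, T, S, K, G) == (2, 4, 4, 2, 2)

The broadcast costs some extra memory for k, but it keeps every non-contracted axis present on both operands, which appears to be the pattern the Metal plugin can legalize.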
Description
System info (python version, jaxlib version, accelerator, etc.)