LLVM ERROR: Failed to infer result type(s). on Jax Metal 0.1.1 #25302

Open
dlwh opened this issue Dec 6, 2024 · 0 comments
Labels
Apple GPU (Metal) plugin bug Something isn't working


dlwh commented Dec 6, 2024

Description

LLVM ERROR: Failed to infer result type(s).

repro:

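```python
import jax
from jax import numpy as jnp

jax.print_environment_info()


def fn(s):
    # (32, 1024) -> (1024, 32, 1024): operand dims 0 and 1 map to output dims 1 and 2
    ig = jax.lax.broadcast_in_dim(s, shape=(1024, 32, 1024), broadcast_dimensions=(1, 2))
    # Reverse the axis order; the result is again (1024, 32, 1024)
    ih = jax.lax.transpose(ig, permutation=(2, 1, 0))
    return ih


s = jnp.zeros((32, 1024))
jax.jit(fn)(s)  # aborts with "LLVM ERROR: Failed to infer result type(s)." on jax-metal 0.1.1
```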
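For anyone checking a workaround, here is a plain-NumPy restatement of what `fn` should return, based on the documented semantics of `lax.broadcast_in_dim` and `lax.transpose`. `fn_reference` is a hypothetical helper name, not part of the repro; the useful part is the invariant `out[i, j, k] == s[j, i]`.

```python
import numpy as np


def fn_reference(s: np.ndarray) -> np.ndarray:
    # broadcast_in_dim with broadcast_dimensions=(1, 2): operand dims 0 and 1
    # land on output dims 1 and 2, and a new leading axis of size 1024 is added.
    ig = np.broadcast_to(s[None, :, :], (1024,) + s.shape)
    # transpose with permutation=(2, 1, 0) reverses the axis order.
    return np.transpose(ig, (2, 1, 0))


s = np.arange(32 * 1024, dtype=np.float32).reshape(32, 1024)
out = fn_reference(s)
print(out.shape)  # (1024, 32, 1024)
# Every element satisfies out[i, j, k] == s[j, i].
```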

Removing the `jit`, the broadcast, or the transpose makes the error go away (though in my real code it's not so easy to do that...).
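Since dropping the transpose reportedly avoids the error, one possible workaround to try is moving the transpose onto the small 2-D input and broadcasting afterwards. This is an assumption, not a confirmed fix: it has not been tested on jax-metal and it still contains a (2-D) transpose. `fn_transpose_first` and the `new_axis` parameter are hypothetical names added for illustration. The sketch only shows that the two formulations compute the same array, which can be checked on CPU.

```python
import jax
import jax.numpy as jnp


def fn(s, new_axis=1024):
    # Original formulation from the repro: add a leading axis of size
    # `new_axis`, then reverse the axis order.
    ig = jax.lax.broadcast_in_dim(
        s, shape=(new_axis,) + s.shape, broadcast_dimensions=(1, 2))
    return jax.lax.transpose(ig, permutation=(2, 1, 0))


def fn_transpose_first(s, new_axis=1024):
    # Hypothetical rewrite: transpose the 2-D operand first, then broadcast
    # along a new trailing axis. Both functions satisfy out[i, j, k] == s[j, i].
    st = jax.lax.transpose(s, permutation=(1, 0))
    return jax.lax.broadcast_in_dim(
        st, shape=st.shape + (new_axis,), broadcast_dimensions=(0, 1))


# new_axis determines the output shape, so it must be static under jit.
fn_jit = jax.jit(fn, static_argnames="new_axis")
fn_transpose_first_jit = jax.jit(fn_transpose_first, static_argnames="new_axis")

s = jnp.zeros((32, 1024))
out = fn_transpose_first_jit(s)
print(out.shape)  # (1024, 32, 1024)
```

Comparing `fn_jit` and `fn_transpose_first_jit` on a small input on CPU confirms the rewrite is equivalent; whether it compiles on Metal still needs to be tried.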

System info (python version, jaxlib version, accelerator, etc.)

Apple M1 Pro

jax-metal 0.1.1

jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
device info: Metal-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-57.local', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')