
"Primitive gt requires argument replication types to match" when using checkify with shard_map #26887

Open
Findus23 opened this issue Mar 3, 2025 · 0 comments
Labels
bug Something isn't working


Findus23 commented Mar 3, 2025

Description

I am currently debugging some fairly complex multi-GPU code, and I suspect that some of the issues are caused by out-of-bounds indexing.
I wanted to use `checkify` to investigate, but there seem to be some bugs when using `checkify` in multi-GPU setups (especially with `shard_map`).
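For context, here is a minimal single-device sketch of what `checkify.index_checks` is expected to do, without `shard_map`. The function name `gather` and the inputs are illustrative and are not part of the original report. Indexing with an integer array lowers to a gather, which is the same kind of operation as `p[order]` in the reproducer below.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def gather(x, idx):
    # Hypothetical helper: integer-array indexing lowers to a gather,
    # which checkify.index_checks instruments with a bounds check.
    return x[idx]


checked_gather = checkify.checkify(jax.jit(gather), errors=checkify.index_checks)

x = jnp.arange(4.0)

# All indices are in bounds, so no error is recorded.
err_ok, out_ok = checked_gather(x, jnp.array([0, 3]))

# Index 10 is out of bounds for a length-4 array, so checkify records an
# error instead of silently returning a value.
err_oob, _ = checked_gather(x, jnp.array([0, 10]))

print(err_ok.get())   # None
print(err_oob.get())  # a message describing the out-of-bounds index
```

`err.get()` returns `None` when no check failed. `err.throw()` (used in the reproducer) raises instead.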

Fortunately, I was able to reduce all of these issues to short, reproducible snippets:

```python
import os

# Force 4 host (CPU) devices; set this before JAX initializes its backends.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

import jax
from jax.experimental import checkify
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('gpus',))

def test_function_int(p):
    order = jax.numpy.argsort(p)
    # order = jax.numpy.array([1, 2, 3])
    return p[order]

test_function = shard_map(
    test_function_int,
    mesh=mesh,
    in_specs=P('gpus'),
    out_specs=P('gpus'),
    # check_rep=False,
)

test_function = jax.jit(test_function)

X = jax.numpy.zeros(1000)

err, delta = checkify.checkify(test_function, errors=checkify.index_checks)(X)
err.throw()
```

This results in `Primitive gt requires argument replication types to match, but got (set(), {'gpus'})`.

The full output (with `JAX_TRACEBACK_FILTERING=off`):

```text
Traceback (most recent call last):
  File "checkify_test.py", line 31, in <module>
    err, delta = checkify.checkify(test_function, errors=checkify.index_checks)(X)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: Exception: Primitive gt requires argument replication types to match, but got (set(), {'gpus'}). Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_rep=False argument to shard_map

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "checkify_test.py", line 31, in <module>
    err, delta = checkify.checkify(test_function, errors=checkify.index_checks)(X)
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 1230, in checked_fun
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
                      ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 396, in checkify_jaxpr
    return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
                               enabled_errors, err_tree, *err_vals, *args)
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 427, in checkify_jaxpr_flat
    error, outvals = checkify_rule(error, enabled_errors,
                     ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
                                   *invals, **eqn.params)
                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 910, in pjit_error_check
    checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                 ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
                                                         err_tree, *in_avals)
                                                         ^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 754, in jaxpr_to_checkify_jaxpr
    new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
                               ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 334, in wrapper
    return func(*args, **kwargs)
  File "venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 210, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 337, in _flatten_and_get_error_metadata_thunk
    error, out = f(*invals)
                 ~^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 427, in checkify_jaxpr_flat
    error, outvals = checkify_rule(error, enabled_errors,
                     ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
                                   *invals, **eqn.params)
                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 1011, in shard_map_error_check
    err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params)
  File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 461, in bind
    return self._true_bind(*args, **params)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 466, in bind_with_trace
    return trace.process_shard_map(shard_map_p, fun, args, **params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 524, in _shard_map_staging
    out_rep = _check_rep(mesh, jaxpr, in_rep)
  File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 647, in _check_rep
    out_rep = rule(mesh, *map(read, e.invars), **e.params)
  File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 1134, in _standard_check
    raise Exception(f"Primitive {prim} requires argument replication types "
    ...<2 lines>...
                    "workaround pass the check_rep=False argument to shard_map")
Exception: Primitive gt requires argument replication types to match, but got (set(), {'gpus'}). Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_rep=False argument to shard_map
```

Unrelated, but is the stray `"` in the `print_environment_info()` output (at the end of the `device info` line below) intentional? It seems to come from here:
https://github.com/jax-ml/jax/blob/ed4a7bbab10aae1a38ea8127c1349962d13ddba4/jax/_src/environment_info.py#L49C106-L49C107

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.1
jaxlib: 0.5.1
numpy:  2.2.3
python: 3.13.2 (main, Feb  5 2025, 01:23:35) [GCC 14.2.0]
device info: cpu-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='lukasnotebook', release='6.12.12-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.12.12-1 (2025-02-02)', machine='x86_64')