I am currently trying to debug some fairly complex multi-GPU code and am wondering whether some of the issues are caused by bounds checks.
I wanted to try checkify to help with this, but there seem to be some bugs when using checkify in multi-GPU setups (especially with shard_map).
Thankfully, I was able to reduce all issues to short reproducible snippets:
Running the snippet results in: Primitive gt requires argument replication types to match, but got (set(), {'gpus'}).
The full output (with JAX_TRACEBACK_FILTERING=off):
Traceback (most recent call last):
File "checkify_test.py", line 31, in <module>
err, delta = checkify.checkify(test_function, errors=checkify.index_checks)(X)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: Exception: Primitive gt requires argument replication types to match, but got (set(), {'gpus'}). Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_rep=False argument to shard_map
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "checkify_test.py", line 31, in <module>
err, delta = checkify.checkify(test_function, errors=checkify.index_checks)(X)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
File "venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 1230, in checked_fun
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 396, in checkify_jaxpr
return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
enabled_errors, err_tree, *err_vals, *args)
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 427, in checkify_jaxpr_flat
error, outvals = checkify_rule(error, enabled_errors,
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
*invals, **eqn.params)
^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 910, in pjit_error_check
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
err_tree, *in_avals)
^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 754, in jaxpr_to_checkify_jaxpr
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 334, in wrapper
return func(*args, **kwargs)
File "venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "venv/lib/python3.13/site-packages/jax/_src/linear_util.py", line 210, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 337, in _flatten_and_get_error_metadata_thunk
error, out = f(*invals)
~^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 427, in checkify_jaxpr_flat
error, outvals = checkify_rule(error, enabled_errors,
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
*invals, **eqn.params)
^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/checkify.py", line 1011, in shard_map_error_check
err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params)
File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 461, in bind
return self._true_bind(*args, **params)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/_src/core.py", line 520, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 466, in bind_with_trace
return trace.process_shard_map(shard_map_p, fun, args, **params)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 524, in _shard_map_staging
out_rep = _check_rep(mesh, jaxpr, in_rep)
File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 647, in _check_rep
out_rep = rule(mesh, *map(read, e.invars), **e.params)
File "venv/lib/python3.13/site-packages/jax/experimental/shard_map.py", line 1134, in _standard_check
raise Exception(f"Primitive {prim} requires argument replication types "
...<2 lines>...
"workaround pass the check_rep=False argument to shard_map")
Exception: Primitive gt requires argument replication types to match, but got (set(), {'gpus'}). Please open an issue at https://github.com/jax-ml/jax/issues and as a temporary workaround pass the check_rep=False argument to shard_map
Adding check_rep=False, as mentioned in the exception, is indeed a workaround.
This only happens if test_function is jitted.
Using order = jax.numpy.array([1, 2, 3]) also causes the same issue.
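For readers unfamiliar with checkify, here is a minimal single-device sketch of the index checks this report is trying to use. It is not the failing snippet from this issue: it deliberately leaves out shard_map, and the names `pick`, `x`, and the index arrays are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


# Hypothetical stand-in for a jitted function that indexes into an array.
# This is NOT the test_function from the report and uses no shard_map.
@jax.jit
def pick(x, idx):
    # Without checkify, an out-of-bounds index here does not raise an error.
    return x[idx]


# Wrap the jitted function so that indexing operations are bounds-checked.
checked_pick = checkify.checkify(pick, errors=checkify.index_checks)

x = jnp.arange(4.0)

# All indices are within bounds: no error is recorded.
err_ok, _ = checked_pick(x, jnp.array([1, 2, 3]))

# Index 7 is out of bounds for an array of length 4: an error is recorded.
err_bad, _ = checked_pick(x, jnp.array([0, 7]))

print(err_ok.get())   # None when no check failed
print(err_bad.get())  # a message describing the out-of-bounds access
```

Calling `err_bad.throw()` instead of `get()` would raise the recorded error as a Python exception, which is usually more convenient while debugging.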
Also unrelated, but is the " in print_environment_info() here intentional?
https://github.com/jax-ml/jax/blob/ed4a7bbab10aae1a38ea8127c1349962d13ddba4/jax/_src/environment_info.py#L49C106-L49C107
System info (python version, jaxlib version, accelerator, etc.)