With jax.debug_nans enabled, neither the fast nor the slow dispatch path of shard_map pinpoints the exact line where a NaN occurs: the error only points at the call site. Ideally, shard_map should behave more like jit, where both dispatch paths identify the exact line that produced the NaN.
Here is the test case:
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
jax.clear_caches()
P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((1,), ('x',))
@jax.jit
def f(x):
    y = jnp.square(x)
    return jnp.log(-y)
f_shard_map = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
with jax.debug_nans(True):
    one = jnp.ones([1])
    # print(f_shard_map(jnp.zeros([1])))
    print(f_shard_map(one))
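Whether a NaN error "points inside" the user function can be checked programmatically. The sketch below is a hypothetical helper (error_points_inside and nan_error_frames are not part of JAX): it calls a function, catches the FloatingPointError, and reports whether any traceback frame belongs to the named user function. It only uses the standard traceback module.

```python
import traceback


def nan_error_frames(fn, *args):
    """Call fn(*args); return the frame names in the traceback of the
    FloatingPointError it raises, or None if no such error is raised."""
    try:
        fn(*args)
    except FloatingPointError as e:
        return [frame.name for frame in traceback.extract_tb(e.__traceback__)]
    return None


def error_points_inside(fn, func_name, *args):
    """True if fn(*args) raises a FloatingPointError whose traceback
    contains a frame of the function named func_name."""
    frames = nan_error_frames(fn, *args)
    return frames is not None and func_name in frames
```

With the definitions above and jax.debug_nans(True) active, error_points_inside(f, "f", jnp.ones([1])) would be expected to return True, matching the jit traceback further down, whereas according to the shard_map output reported here error_points_inside(f_shard_map, "f", jnp.ones([1])) returns False.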
Exercising the slow path (the very first call to f_shard_map already produces the NaN):
with jax.debug_nans(True):
    one = jnp.ones([1])
    # print(f_shard_map(jnp.zeros([1])))
    print(f_shard_map(one))
and the output is:
Traceback (most recent call last):
File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 22, in <module>
print(f_shard_map(one))
^^^^^^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in sharded computation.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
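Exercising the fast path (the first call, with zeros, does not produce a NaN, so the second call goes through the already-cached fast path):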
f_shard_map = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
with jax.debug_nans(True):
    one = jnp.ones([1])
    print(f_shard_map(jnp.zeros([1])))
    print(f_shard_map(one))
and the output is:
File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 22, in <module>
print(f_shard_map(one))
^^^^^^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in sharded computation.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
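The fast-path reproduction relies on the warm-up call with zeros not tripping debug_nans: square(0) = 0, so f computes log(-0.0) = -inf, which is not a NaN (infinities are covered by the separate jax_debug_infs option), while ones give log(-1.0) = nan. The sketch below checks this arithmetic with NumPy as a stand-in for jax.numpy, assuming both follow the same IEEE-754 rules for log.

```python
import numpy as np


def f_numpy(x):
    # Same arithmetic as the issue's f, with NumPy standing in for jax.numpy.
    y = np.square(x)
    return np.log(-y)


with np.errstate(divide="ignore", invalid="ignore"):  # silence RuntimeWarnings
    warmup = f_numpy(np.zeros([1]))   # log(-0.0) -> -inf, not a NaN
    failing = f_numpy(np.ones([1]))   # log(-1.0) -> nan

print(warmup, failing)                                   # [-inf] [nan]
print(np.isnan(warmup).any(), np.isnan(failing).any())   # False True
```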
Here is the debug_nans output for jit as a reference:
with jax.debug_nans(True):
    one = jnp.ones([1])
    # print(f(one))
    print(f(jnp.zeros([1])))
    print(f(one))
Traceback (most recent call last):
File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 28, in <module>
print(f(one))
^^^^^^
File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 14, in f
return jnp.log(-y)
^^^^^^^^^^^
File "/usr/local/google/home/stellasyan/miniconda3/envs/jax/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py", line 489, in log
return lax.log(*promote_args_inexact('log', x))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
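JAX's NaN-debugging documentation explains that when a jit-compiled function outputs a NaN, JAX re-runs it in de-optimized, op-by-op mode, which is why the jit traceback above ends inside f at jnp.log. The following is a toy model of that idea in plain Python, not JAX's actual implementation; RAW_OPS, checked, and call_with_nan_debugging are hypothetical names, and the scalar f is a stand-in for the issue's function.

```python
import math


def _log(x):
    # Scalar log with IEEE-754 semantics (as jnp.log has): -inf at zero,
    # nan for negative or nan inputs. math.log would raise instead.
    if math.isnan(x) or x < 0:
        return float("nan")
    return float("-inf") if x == 0 else math.log(x)


# The "primitives" the toy function is built from.
RAW_OPS = {"square": lambda x: x * x, "neg": lambda x: -x, "log": _log}


def checked(ops):
    """Wrap every op so it raises as soon as it produces a NaN."""
    def wrap(name, op):
        def run(x):
            out = op(x)
            if math.isnan(out):
                raise FloatingPointError(f"invalid value (nan) encountered in {name}")
            return out
        return run
    return {name: wrap(name, op) for name, op in ops.items()}


def f(x, ops):
    # Hypothetical stand-in for the issue's f, parameterised by its ops.
    y = ops["square"](x)
    return ops["log"](ops["neg"](y))


def call_with_nan_debugging(fn, x):
    # "Compiled" run: only the final output can be inspected.
    out = fn(x, RAW_OPS)
    if math.isnan(out):
        # Re-run op by op; the failing op raises from inside fn, so the
        # traceback includes the line of fn that produced the NaN.
        fn(x, checked(RAW_OPS))
        # Only reached if the op-by-op re-run did not reproduce the NaN.
        raise FloatingPointError("invalid value (nan) encountered in output")
    return out
```

Here call_with_nan_debugging(f, 0.0) returns -inf, while call_with_nan_debugging(f, 1.0) raises "invalid value (nan) encountered in log" from inside f. The request in this issue is for shard_map to surface the same op-level location rather than the generic "encountered in sharded computation" message.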