Improve debug_nan in shard_map to pinpoint exact line of NaN occurrence #26796

Open
Stella-S-Yan opened this issue Feb 27, 2025 · 0 comments
Labels
enhancement New feature or request

Comments

@Stella-S-Yan
Collaborator

With jax.debug_nans enabled, neither the fast nor the slow dispatch path of shard_map pinpoints the exact line where a NaN occurs; both report only that a NaN appeared somewhere in the sharded computation. Ideally, shard_map should behave like jit, where both dispatch paths identify the exact line that produced the NaN.

Here is the test case:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map

jax.clear_caches()

P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((1,), ('x',))

@jax.jit
def f(x):
    y = jnp.square(x)
    return jnp.log(-y)

f_shard_map = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))

with jax.debug_nans(True):
    one = jnp.ones([1])
   
    # print(f_shard_map(jnp.zeros([1]))) 
    print(f_shard_map(one))

Turning on the slow path (the warm-up call is commented out, so this is the first call):

with jax.debug_nans(True):
  one = jnp.ones([1])

  # print(f_shard_map(jnp.zeros([1]))) 
  print(f_shard_map(one))

and the output is:

Traceback (most recent call last):
  File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 22, in <module>
    print(f_shard_map(one))
          ^^^^^^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in sharded computation.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Turning on the fast path (a warm-up call with jnp.zeros runs first):

f_shard_map = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))

with jax.debug_nans(True):
    one = jnp.ones([1])
   
    print(f_shard_map(jnp.zeros([1]))) 
    print(f_shard_map(one))

and the output is:

  File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 22, in <module>
    print(f_shard_map(one))
          ^^^^^^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in sharded computation.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
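
The slow-path and fast-path snippets differ only in whether a warm-up call with jnp.zeros runs first. As a rough sketch, that toggle can be wrapped in a small helper. The name nan_error and the warm_up flag are made up for illustration, and the demo uses a jit-compiled f rather than f_shard_map so that it depends only on core JAX; passing f_shard_map instead should exercise the shard_map paths described in this issue.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.square(x)
    return jnp.log(-y)

def nan_error(fn, x, warm_up):
    """Call fn(x) under debug_nans and return the FloatingPointError, if any.

    Hypothetical helper: warm_up=False mimics the "slow path" snippet
    (first call after clearing caches), warm_up=True mimics the
    "fast path" snippet (a prior call has already compiled fn).
    """
    jax.clear_caches()
    with jax.debug_nans(True):
        if warm_up:
            # log(-0.0) is -inf, not NaN, so the warm-up call does not raise.
            fn(jnp.zeros_like(x))
        try:
            fn(x)
        except FloatingPointError as err:
            return err
    return None

slow_err = nan_error(f, jnp.ones([1]), warm_up=False)
fast_err = nan_error(f, jnp.ones([1]), warm_up=True)
print(type(slow_err).__name__, type(fast_err).__name__)
```

A FloatingPointError is expected in both cases; what this issue asks for is more location detail in its traceback when fn is a shard_map.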

For reference, here is the debug_nans output when the jit-compiled f is called directly, without shard_map:

with jax.debug_nans(True):
    one = jnp.ones([1])
    # print(f(one))
    print(f(jnp.zeros([1])))
    print(f(one))

and the output is:

Traceback (most recent call last):
  File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 28, in <module>
    print(f(one))
          ^^^^^^
  File "/usr/local/google/home/stellasyan/Documents/test_jax/test_nan.py", line 14, in f
    return jnp.log(-y)
           ^^^^^^^^^^^
  File "/usr/local/google/home/stellasyan/miniconda3/envs/jax/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
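
Until shard_map reports the offending line itself, one possible workaround is to re-run the per-shard body op by op under jax.disable_jit() with NaN checking on, so that the error is raised at the primitive that produced the NaN. The following is a minimal sketch, not a proposed fix: the helper name localize_nan is hypothetical, and it assumes the caller passes a single shard's block of the input (with the one-device mesh in the test case, that block is the whole array).

```python
import traceback

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.square(x)
    return jnp.log(-y)

def localize_nan(body, *shard_args):
    """Run `body` eagerly with NaN checking; return the error or None.

    Hypothetical helper: `shard_args` are assumed to be the per-shard
    blocks that shard_map would hand to `body`.
    """
    with jax.debug_nans(True), jax.disable_jit():
        try:
            body(*shard_args)
        except FloatingPointError as err:
            return err
    return None

err = localize_nan(f, jnp.ones([1]))
if err is not None:
    # The printed traceback should pass through the frame of `f`.
    traceback.print_exception(type(err), err, err.__traceback__)
```

This only applies to bodies that do not use collectives such as jax.lax.psum, since those need the mesh axis name to be bound and will not run outside shard_map.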
@Stella-S-Yan Stella-S-Yan added the enhancement New feature or request label Feb 27, 2025