Using a device function inside the host function of host_callback fails confusingly #5934

Closed
tomhennigan opened this issue Mar 4, 2021 · 8 comments
Labels: bug · NVIDIA GPU · P1 (soon)

Comments

@tomhennigan (Collaborator)

Using a JAX device function in the host function of host_callback causes hangs, long compiles, or OOMs (the failure mode changes with the backend).

This not being supported is probably working as intended, but we could provide a better error message to avoid user confusion. Minimal reproducer below:

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def host_fun(m: np.ndarray) -> np.ndarray:
  return np.sum(m)  # works
  # return jnp.sum(m)  # causes errors

def device_fun(m):
  return hcb.call(host_fun, m,
                  result_shape=jax.ShapeDtypeStruct(m.shape, m.dtype))

jax.jit(device_fun)(0)
@tomhennigan added the bug label Mar 4, 2021
@ahoenselaar (Contributor)

The success of the reentrant call depends on the device configuration.

  • If there is only a single CPU device available, the call causes a deadlock.
  • If two or more CPU devices are explicitly made visible, one can create a working configuration by JITing device_fun onto one device and forcing host_fun onto a different device via the device kwarg of jit() (see the sketch below).
  • Similarly, a working configuration can be accomplished by spreading host_fun and device_fun across a CPU and a GPU device.

It is less clear what happens under transformations like pmap.
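
A minimal sketch of the two-CPU-device configuration described above (hypothetical: it assumes a JAX version that still ships jax.experimental.host_callback and accepts the device kwarg of jit(), with at least two CPU devices made visible, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=2):

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

cpu0, cpu1 = jax.devices("cpu")[:2]

# Pin the host-side computation to the second CPU device, so it does not need
# the device that is blocked while the callback runs.
host_fun = jax.jit(jnp.sum, device=cpu1)

def device_fun(m):
  return hcb.call(host_fun, m,
                  result_shape=jax.ShapeDtypeStruct((), m.dtype))

# Pin the outer computation to the first CPU device.
print(jax.jit(device_fun, device=cpu0)(jnp.arange(4)))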

@gnecula (Collaborator) commented Mar 16, 2021

Allowing re-entrant calls is a larger project, but I will look into providing a better message.

The rule is that when a callback executes on a device, it blocks that device until the callback finishes; the callback cannot launch other computations on the same device. In the reproducer above, np.sum stays on the host, while jnp.sum tries to dispatch a new computation to the very device that is blocked waiting for the callback, hence the hang.

@alvarosg (Contributor) commented Mar 17, 2021

Related to this potential reentrant call, but looking at it from a more general perspective.

One of the potential use cases for jax.experimental.host_callback.call is to get around jax.jit's fixed-shape requirements for some specific intermediate computations, since anything that happens in the host_callback can take arbitrary shapes. This can be very valuable for prototyping and research code. Of course, using jax.experimental.host_callback.call for this purpose means:

  • Data to and from jax.experimental.host_callback.call needs to be copied from the device to the host and vice versa, which adds overhead.
  • There cannot be any device parallelization in the callback.

It would be fantastic if the inputs and outputs to the host_fun remained DeviceArrays instead of NumPy arrays. I understand this may cause complications for XLA optimization now that there is arbitrary host code being executed, but I could imagine adding a parameter to jax.experimental.host_callback.call indicating how much device memory will be required. I guess since host_callback is still experimental, it is not too late to think about whether we want this to be a pure "host" callback, or more of a "return control to host" callback, or even something like jax.unjitted_call, which may also make use of the device. E.g.:

@jax.jit
def fun(input):
  aux1 = input ** 2 + 5  # This will run jitted.

  def unjitted_fn(x):
    # This will give control back to Python, but keep the input array 'x' on device.
    print("host")  # This will print every single time.
    # Users may still feed x to a CPU-only library (e.g. scipy), which will trigger automatic
    # transfer of the array to host memory if required (as usual), but only if actually necessary.
    return x ** 2 + 1  # This will run un-jitted, but on device.

  aux2 = jax.unjitted_call(unjitted_fn, aux1,
                           result_shape=jax.ShapeDtypeStruct(aux1.shape, aux1.dtype),
                           max_memory=...)

  output = aux2 ** 3 + 5  # This will run jitted.
  return output

@evanatyourservice

Quick question regarding host_callback.call within jitted functions and multiprocessing... Say I have a machine with a GPU and 4 CPU cores. If I jit and vmap a function that calls a Python function through host_callback.call with batch_dim=4, is this a way to have the Python functions run concurrently on all four CPU cores at once, or will JAX force them all to run on the same CPU core one after another? I'm wondering if this is sort of a hack to get some concurrency with Python functions through JAX.
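
Not an answer to how host_callback itself schedules these calls, but a hedged sketch of one alternative for host-side concurrency, assuming a JAX version that still ships host_callback and per-row work (like large NumPy ops) that releases the GIL: fan the batch out over a thread pool inside a single callback.

import concurrent.futures
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def host_fn(batch: np.ndarray) -> np.ndarray:
  # Process the rows concurrently on the host with a thread pool.
  with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    rows = list(pool.map(np.sum, batch))
  return np.stack(rows)

def device_fun(batch):
  return hcb.call(host_fn, batch,
                  result_shape=jax.ShapeDtypeStruct((batch.shape[0],), batch.dtype))

print(jax.jit(device_fun)(jnp.ones((4, 8))))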

@gnecula (Collaborator) commented Apr 27, 2022 via email

@sudhakarsingh27 added the NVIDIA GPU, P0 (urgent), and P1 (soon) labels and removed the P0 (urgent) label Aug 10, 2022
@sudhakarsingh27 assigned hawkinsp and unassigned gnecula Aug 15, 2022
@sudhakarsingh27 assigned sharadmv and unassigned hawkinsp Sep 7, 2022
@mattjj (Collaborator) commented Sep 7, 2022

@sharadmv please close this issue when the replacement is ready!

@rajasekharporeddy (Contributor) commented Dec 10, 2024

Hi @tomhennigan

The issue seems to be fixed in later versions of JAX. Since jax.experimental.host_callback was deprecated in JAX 0.4.26 and removed in JAX 0.4.35, I tested the issue on a Colab T4 GPU with the new external callback jax.experimental.io_callback, as suggested in #20385. The code now executes without any hangs, long compiles, or OOMs.

import jax
import jax.numpy as jnp
from jax.experimental import io_callback
import numpy as np

print(jax.__version__, jax.devices())

def host_fun(m: np.ndarray) -> np.ndarray:
  # return np.sum(m)  # NumPy version, which always worked
  return jnp.sum(m)  # previously caused the hang; works here with io_callback

def device_fun(m):
  return io_callback(host_fun,
                     jax.ShapeDtypeStruct(m.shape, m.dtype),
                     m)

jax.jit(device_fun)(0)

Output:

0.4.37 [CudaDevice(id=0)]
Array(0, dtype=int32)

Attaching a gist for reference. Could you please verify whether the issue still persists with the latest JAX version?

Thank you.

@gnecula (Collaborator) commented Dec 10, 2024

Thank you @rajasekharporeddy for investigating. I propose to close this issue now.

@gnecula closed this as completed Dec 10, 2024