Using a device function inside the host function of host_callback fails confusingly #5934
Comments
The success of the reentrant call depends on the device configuration.
It is less clear what happens under transformations like …
Allowing reentrant calls is a larger project, but I will look into providing a better error message. The rule is that when a callback executes on a device, it blocks that device until the callback finishes, so the callback cannot launch other computations on the same device.
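A minimal sketch of a host function that respects this rule, assuming a recent JAX with `jax.experimental.io_callback` (the API used later in this thread). The name `host_sum` is hypothetical. The host side uses only NumPy, so it never tries to launch work on the device that is blocked waiting for it; the explicit cast guards against `np.sum` widening small integer dtypes, which would not match the declared result dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_sum(m: np.ndarray) -> np.ndarray:
    # Hypothetical host function: runs on the host while the device waits,
    # so it uses NumPy only. Calling jnp here would dispatch work to the
    # same device that is blocked on this callback.
    m = np.asarray(m)
    # np.sum can widen small integer dtypes (e.g. int32 -> int64), so cast
    # back to the dtype declared in ShapeDtypeStruct below.
    return np.asarray(np.sum(m), dtype=m.dtype)


@jax.jit
def device_fun(m):
    # Declare the shape/dtype of the callback's result so jit can trace it.
    return io_callback(host_sum, jax.ShapeDtypeStruct((), m.dtype), m)


print(device_fun(jnp.arange(4.0)))
```

Swapping `np.sum` for `jnp.sum` inside `host_sum` reintroduces exactly the pattern this issue is about.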
This is related to the potential reentrant call, but looks at it from a more general perspective. One of the potential use cases for …
It would be fantastic if the inputs and outputs to the …
Quick question regarding host_callback.call within jitted functions and multiprocessing... Say I have a machine with a GPU and 4 CPU cores. If I jit and vmap a function that calls a Python function through host_callback.call, with a batch dimension of size 4, will the Python functions run concurrently on all four CPU cores, or will JAX run them one after another on the same core? I'm wondering if this is a hack to get some concurrency for Python functions through JAX.
Inside a vmap computation you will see a single call to the host, with the
entire batch. You'd have to write your host function to split the data into
batches and run it in several threads.
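A minimal sketch of that approach, assuming a recent JAX. It uses `jax.pure_callback` in place of `host_callback.call` (host_callback is deprecated in newer releases) and calls it on an already-batched array instead of under `vmap`, which models the single whole-batch host call described above. `per_example` and `host_fun_batched` are hypothetical names; the thread pool splits the batch along its leading axis.

```python
import concurrent.futures

import jax
import jax.numpy as jnp
import numpy as np


def per_example(x: np.ndarray) -> np.ndarray:
    # Hypothetical per-example Python work (stand-in for the user's function).
    return np.asarray(np.sum(x ** 2), dtype=x.dtype)


def host_fun_batched(batch: np.ndarray, max_workers: int = 4) -> np.ndarray:
    # The host receives the entire batch in one call; split it along the
    # leading axis and fan the rows out to a thread pool.
    batch = np.asarray(batch)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(per_example, batch))
    # Reassemble one result per batch element, preserving order.
    return np.stack(results)


@jax.jit
def device_fun(batch):
    out_shape = jax.ShapeDtypeStruct(batch.shape[:1], batch.dtype)
    return jax.pure_callback(host_fun_batched, out_shape, batch)


print(device_fun(jnp.arange(8.0).reshape(4, 2)))
```

Threads only run in parallel when the per-example work releases the GIL, as many NumPy routines do; for CPU-bound pure-Python work, `concurrent.futures.ProcessPoolExecutor` is the usual alternative.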
(Email reply to Evan Walters's question of Tue, Apr 26, 2022, 16:28.)
@sharadmv please close this issue when the replacement is ready!
Hi @tomhennigan, the issue seems to be fixed in later versions of JAX. Since …

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback
import numpy as np

print(jax.__version__, jax.devices())

def host_fun(m: np.ndarray) -> np.ndarray:
  # return np.sum(m)  # works
  return jnp.sum(m)  # causes errors

def device_fun(m):
  return io_callback(host_fun,
                     jax.ShapeDtypeStruct(m.shape, m.dtype),
                     m)

jax.jit(device_fun)(0)
```

Output:
Attaching a gist for reference. Could you please verify whether the issue still persists with the latest JAX version? Thank you.
Thank you @rajasekharporeddy for investigating. I propose to close this issue now.
Using a JAX device function in the host function of host_callback causes hangs, long compiles, or out-of-memory errors (the failure mode varies by backend). This not being supported is probably WAI (working as intended), but we could provide a better error message to avoid user confusion. Minimal reproducer below: