As mentioned by @victoriacity in issue #5055, JAX supports explicitly differentiating a function, for example:
```python
from jax import grad
import jax.numpy as jnp

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))  # ~0.0707, i.e. 1 - tanh(2)^2
```
grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function f, then grad(f) is a Python function that evaluates the mathematical function ∇f. That means grad(f)(x) represents the value ∇f(x).
It is possible to support this feature in Taichi. A proof of concept is shown below:
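(The original proof-of-concept snippet did not survive here; what follows is only a minimal sketch of what such an interface could look like today, assuming Taichi's existing ti.ad.Tape autodiff. The grad wrapper and the scalar-field-based calling convention are assumptions for illustration, not the original proof of concept.)

```python
import taichi as ti

ti.init()

# Scalar fields acting as the input and output of the function to differentiate.
x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
y = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def tanh_kernel():
    y[None] = ti.tanh(x[None])

def grad(kernel, inp, out):
    # Hypothetical wrapper: returns a Python function that evaluates
    # d(out)/d(inp) at a given point by replaying the kernel under a tape.
    def grad_fn(v):
        inp[None] = v
        with ti.ad.Tape(loss=out):
            kernel()
        return inp.grad[None]
    return grad_fn

grad_tanh = grad(tanh_kernel, x, y)
print(grad_tanh(2.0))  # ~0.0707, matching 1 - tanh(2)^2
```

Note that this only mimics JAX's grad at the granularity of whole kernels; the point of this issue is to push differentiation down to the function level so the differentiated code can still be called from inside kernels.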
Update: based on input from @yuanming-hu, we can try to implement this feature once real functions are ready, so that function-level autodiff becomes possible. One advantage is that the differentiated functions could then be executed in parallel inside a kernel.
I am wondering: should we implement this feature? Opinions are welcome!