
[autodiff] Support explicitly differentiating a function as in JAX #5060

Open
erizmr opened this issue May 30, 2022 · 1 comment
Labels
discussion, feature request

Comments


erizmr commented May 30, 2022

As mentioned by @victoriacity in issue #5055, JAX supports explicitly differentiating a function, i.e.,

import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))  # ~0.070651

grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function f, then grad(f) is a Python function that evaluates the mathematical function ∇f. That means grad(f)(x) represents the value ∇f(x).
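
Because grad maps functions to functions, it also composes: grad(grad(f)) evaluates the second derivative. The values can be checked against the closed forms d/dx tanh(x) = 1 - tanh(x)^2 and d²/dx² tanh(x) = -2 tanh(x) (1 - tanh(x)^2):

print(1 - jnp.tanh(2.0) ** 2)                          # equals grad_tanh(2.0)
print(grad(grad(jnp.tanh))(2.0))                       # second derivative, ~-0.13622
print(-2 * jnp.tanh(2.0) * (1 - jnp.tanh(2.0) ** 2))   # closed form, same value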

It is possible to support this feature in Taichi. A proof of concept is shown below:

import taichi as ti

ti.init()

x1 = ti.field(float, shape=(), needs_grad=True)
x2 = ti.field(float, shape=(), needs_grad=True)
x3 = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x1[None]**3 + 2 * x2[None]**2 - 3 * x3[None] + 1

def backward(f, input_fields, out_field):
    import numpy as np
    def _dfdx(inputs):
        # Load the evaluation point into the input fields and clear any
        # gradients accumulated by a previous call.
        for field, value in zip(input_fields, inputs):
            field.from_numpy(np.array(value))
            field.grad[None] = 0.0
        # Seed the output gradient and run the backward kernel.
        out_field.grad[None] = 1.0
        f.grad()
        return [field.grad.to_numpy() for field in input_fields]
    return _dfdx

dfdx = backward(f, [x1, x2, x3], y)

print(dfdx([1.0, 2.0, 3.0]))  # [3, 8, -3], i.e. [3*x1**2, 4*x2, -3] at (1, 2, 3)
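
In this sketch, dfdx mirrors JAX's grad(f)(x) calling convention, but the backing fields x1, x2, x3 and y still have to be declared up front; a full implementation would presumably create and manage such fields behind the scenes, so that users only deal with plain functions and values.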

I am wondering whether we should implement this feature. Opinions are welcome!

@erizmr added the discussion and feature request labels on May 30, 2022
@taichi-ci-bot moved this to Untriaged in Taichi Lang on May 30, 2022
@ailzhang moved this from Untriaged to Backlog in Taichi Lang on Jun 2, 2022

erizmr commented Jun 8, 2022

Update: following input from @yuanming-hu, we can try to implement this feature once real functions are ready, so that function-level autodiff becomes possible. One advantage is that the differentiated functions could then be executed in parallel inside a kernel.
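
For illustration, here is a hypothetical sketch of what that could look like once real functions land. None of this is an existing Taichi API: the @ti.real_func decorator and a JAX-style ti.grad are assumptions about one possible design, not the actual plan.

import taichi as ti

ti.init()

# Assumed syntax: a differentiable "real" (non-inlined) function.
@ti.real_func
def poly(x: ti.f32) -> ti.f32:
    return x**3 + 1.0

# Assumed API: ti.grad takes a function and returns a function, as in JAX.
dpoly = ti.grad(poly)

out = ti.field(float, shape=16)

@ti.kernel
def eval_grad():
    # The differentiated function could then run in parallel inside a
    # kernel, once per element of `out`.
    for i in out:
        out[i] = dpoly(float(i))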
