
[RFC] Add forward mode for autodiff #5055

Open
18 of 20 tasks
erizmr opened this issue May 27, 2022 · 2 comments

erizmr commented May 27, 2022

In this issue, we would like to share a draft implementation plan for forward-mode autodiff.

Background

In general, there are two modes of autodiff: reverse mode and forward mode. Each mode has its advantage in different scenarios. The reverse mode is more efficient when the number of inputs is much larger than the number of outputs (e.g., machine learning cases with thousands of trainable parameters and one scalar loss). Conversely, the forward mode is more efficient when the number of outputs is much larger than the number of inputs. In addition, second-order derivatives can be computed efficiently by combining the forward and reverse modes.
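
To make the two modes concrete, here is a minimal plain-Python sketch (not Taichi code; the Dual class and df_reverse below are illustrative assumptions, not proposed APIs). Forward mode carries a derivative alongside each value, so a single evaluation of f yields both f(x) and df/dx; pushing duals through a hand-written reverse (adjoint) pass then yields a second derivative (forward-over-reverse).

class Dual:
    """Forward-mode value: `val` is the value, `dot` its derivative w.r.t. the input."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def _wrap(self, o):
        return o if isinstance(o, Dual) else Dual(o)
    def __add__(self, o):
        o = self._wrap(o)
        return Dual(self.val + o.val, self.dot + o.dot)
    __radd__ = __add__
    def __sub__(self, o):
        o = self._wrap(o)
        return Dual(self.val - o.val, self.dot - o.dot)
    def __mul__(self, o):
        o = self._wrap(o)
        return Dual(self.val * o.val, self.val * o.dot + self.dot * o.val)
    __rmul__ = __mul__

def f(x):  # f(x) = x**3 + 2*x**2 - 3*x + 1
    return x * x * x + 2 * (x * x) - 3 * x + 1

def df_reverse(x):
    """Hand-written reverse-mode (adjoint) pass for f: returns df/dx."""
    t1 = x * x                                  # primal sweep: x**2
    t1_bar = 1.0 * x + 2.0                      # adjoint of t1 (from t1*x and 2*t1)
    return 1.0 * t1 + t1_bar * 2.0 * x - 3.0    # adjoint of x, i.e. df/dx

# Forward mode: one evaluation of f gives the value and the first derivative.
out = f(Dual(1.0, 1.0))                         # seed dx/dx = 1
print(out.val, out.dot)                         # 1.0 4.0

# Forward-over-reverse: duals through the adjoint pass give the second derivative.
print(df_reverse(Dual(1.0, 1.0)).dot)           # 10.0 (= 6*x + 4 at x = 1)

Reverse mode, by contrast, records the computation and runs a separate backward sweep, which is why the input/output cardinality decides which mode is cheaper.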

For a roadmap for the autodiff feature in Taichi, please check out #5050.

Goals

  • Implement forward-mode autodiff.
  • Design a Python interface for forward and reverse modes.
  • Make it possible to apply the forward and reverse modes in a nested fashion (e.g., forward(reverse())), preparing for computing second-order derivatives.

Implementation Roadmap

Discussions

  • How many kernels do we need to compile for forward-mode autodiff?

Currently in reverse mode, two kernels are compiled: the original kernel and a grad kernel, which evaluate the function values and compute the gradients respectively. In forward mode, however, the derivatives are computed eagerly during function evaluation, i.e., the function values and gradients can be computed with a single kernel. This raises the question of whether one or two kernels need to be compiled.

Update: three kinds of kernels are generated (primal, forward AD, and reverse AD) according to the autodiff mode; see #5098.
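
To make the one-kernel vs. two-kernel question concrete, here is a plain-Python sketch with hand-written derivatives (f_primal, f_grad, and f_fwd are hypothetical names, not Taichi APIs):

# Reverse-mode style: two separate callables ("kernels"), primal and grad.
def f_primal(x):
    return x**3 + 2*x**2 - 3*x + 1

def f_grad(x, dy=1.0):            # adjoint pass, seeded with dL/dy = dy
    return dy * (3*x**2 + 4*x - 3)

# Forward-mode style: one callable that evaluates value and derivative eagerly.
def f_fwd(x, dx=1.0):             # tangent pass, seeded with dx/dx = dx
    return x**3 + 2*x**2 - 3*x + 1, dx * (3*x**2 + 4*x - 3)

print(f_primal(1.0), f_grad(1.0))   # 1.0 4.0
print(f_fwd(1.0))                   # (1.0, 4.0)

This is why reverse mode maps naturally onto an original kernel plus a grad kernel, while forward mode can in principle fuse value and derivative evaluation into one kernel.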

erizmr added the RFC label May 27, 2022

victoriacity (Member) commented:

I wonder if explicitly differentiating a function as in JAX will be supported, for example,

@ti.func
def f(x): return x**3 + 2*x**2 - 3*x + 1

dfdx = forward(f)

@ti.kernel
def k() -> float:
    return dfdx(1.0)
k() # returns 4.0
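
For reference, a sketch of the same function in JAX itself (assuming JAX is installed; jax.jacfwd and jax.jvp are JAX's forward-mode entry points, not proposed Taichi APIs):

import jax

def f(x):
    return x**3 + 2*x**2 - 3*x + 1.0

print(jax.jacfwd(f)(1.0))            # 4.0, forward-mode derivative
y, dy = jax.jvp(f, (1.0,), (1.0,))   # value and derivative from one forward pass
print(y, dy)                         # 1.0 4.0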

erizmr commented May 27, 2022

I think it is possible to support a similar feature. A naive equivalent in current Taichi is:

import taichi as ti

ti.init()

x = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x[None]**3 + 2*x[None]**2 - 3*x[None] + 1

def dfdx(_x):
    x[None] = _x
    y.grad[None] = 1.0  # seed the output gradient
    f.grad()            # run the reverse-mode (grad) kernel
    return x.grad[None]

print(dfdx(1.0))

For a more general case, the user may need to specify the inputs and outputs if we want to generate dfdx for them. A possible implementation might be:

import taichi as ti

ti.init()

x1 = ti.field(float, shape=(), needs_grad=True)
x2 = ti.field(float, shape=(), needs_grad=True)
x3 = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x1[None]**3 + 2*x2[None]**2 - 3*x3[None] + 1

def backward(f, input_fields, out_field):
    import numpy as np
    out_field.grad[None] = 1.0  # seed the output gradient
    def _dfdx(inputs):
        for field, value in zip(input_fields, inputs):
            field.from_numpy(np.array(value))
        f.grad()  # run the reverse-mode (grad) kernel
        return [field.grad.to_numpy() for field in input_fields]
    return _dfdx

dfdx = backward(f, [x1, x2, x3], y)

print(dfdx([1.0, 2.0, 3.0])) # [3, 8, -3]

erizmr added a commit that referenced this issue Jun 14, 2022
…rward mode autodiff"


Support the CPU and GPU backends.
The cc backend has an issue with FieldBuilder; see #5143.
The opengl backend currently does not support materializing multiple SNode trees (see OpenglProgramImpl::compile_snode_tree_types), so FieldBuilder is not supported there.

Related #5055 

erizmr added a commit that referenced this issue Jun 14, 2022
The primal kernels inside the context manager will be transformed into forward-AD kernels.
They will be recovered to primal kernels after exiting the context manager for further non-AD use.

Related #5055 
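
A plain-Python sketch of the transform-and-restore pattern described above (illustrative only; the Kernel class and forward_mode helper are assumptions, not the Taichi implementation): launches inside the context manager are routed to a forward-AD variant, and the primal version is restored on exit so later non-AD use is unaffected.

from contextlib import contextmanager

class Kernel:
    def __init__(self, primal, forward_ad):
        self._primal, self._forward_ad = primal, forward_ad
        self._impl = primal
    def __call__(self, *args):
        return self._impl(*args)

@contextmanager
def forward_mode(*kernels):
    for k in kernels:
        k._impl = k._forward_ad      # transform: route launches to the forward-AD kernel
    try:
        yield
    finally:
        for k in kernels:
            k._impl = k._primal      # recover the primal kernel for non-AD use

f = Kernel(primal=lambda x: x * x, forward_ad=lambda x: (x * x, 2 * x))
print(f(3.0))                        # 9.0 (primal)
with forward_mode(f):
    print(f(3.0))                    # (9.0, 6.0) value and derivative together
print(f(3.0))                        # 9.0 again after exiting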
