Skip to content

Commit 7db8389

Browse files
committed
Update base for Update on "[autodiff] Support basic operations for forward mode autodiff"
Support cpu and gpu backends. The cc backend has an issue on FieldBuilder ref to #5143. The opengl backend currently does not support materializing multiple snode trees (see OpenglProgramImpl::compile_snode_tree_types), thus FieldBuilder is not supported. Related #5055 [ghstack-poisoned]
1 parent 01ffaac commit 7db8389

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

python/taichi/ad.py

+100
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,106 @@
99
from taichi.lang import impl
1010
from taichi.lang.enums import AutodiffMode
1111
from taichi.lang.field import ScalarField
12+
from taichi.lang.snode import SNode
13+
14+
from taichi import _snode
15+
16+
17+
class Tape:
18+
def __init__(self, loss=None, clear_gradients=True):
19+
"""A context manager for reverse mode autodiff :class:`~taichi.ad.Tape`. The
20+
context manager would catching all of the callings of functions that
21+
decorated by :func:`~taichi.lang.kernel_impl.kernel` or
22+
:func:`~taichi.ad.grad_replaced` under `with` statement, and calculate
23+
all the partial gradients of a given loss variable by calling all of the
24+
gradient function of the callings caught in reverse order while `with`
25+
statement ended.
26+
27+
See also :func:`~taichi.lang.kernel_impl.kernel` and
28+
:func:`~taichi.ad.grad_replaced` for gradient functions.
29+
30+
Args:
31+
loss(:class:`~taichi.lang.expr.Expr`): The loss field, which shape should be ().
32+
clear_gradients(Bool): Before `with` body start, clear all gradients or not.
33+
34+
Example::
35+
36+
>>> @ti.kernel
37+
>>> def sum(a: ti.float32):
38+
>>> for I in ti.grouped(x):
39+
>>> y[None] += x[I] ** a
40+
>>>
41+
>>> with ti.Tape(loss = y):
42+
>>> sum(2)
43+
"""
44+
self.calls = []
45+
self.entered = False
46+
self.gradient_evaluated = False
47+
self.clear_gradients = clear_gradients
48+
self.runtime = impl.get_runtime()
49+
self.eval_on_exit = loss is not None
50+
self.loss = loss
51+
52+
def __enter__(self):
53+
assert not self.entered, "Tape can be entered only once."
54+
self.entered = True
55+
56+
impl.get_runtime().materialize()
57+
if len(self.loss.shape) != 0:
58+
raise RuntimeError(
59+
'The loss of `Tape` must be a 0-D field, i.e. scalar')
60+
if not self.loss.snode.ptr.has_adjoint():
61+
raise RuntimeError(
62+
'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)'
63+
' for all fields that are required by autodiff.')
64+
if self.clear_gradients:
65+
clear_all_gradients()
66+
67+
from taichi._kernels import clear_loss # pylint: disable=C0415
68+
clear_loss(self.loss)
69+
70+
# Attach the context manager to runtime
71+
self.runtime.target_tape = self
72+
73+
def __exit__(self, _type, value, tb):
74+
self.runtime.target_tape = None
75+
if self.eval_on_exit:
76+
self.grad()
77+
78+
def insert(self, func, args):
79+
self.calls.append((func, args))
80+
81+
def grad(self):
82+
assert self.entered, "Before evaluating gradients tape must be entered."
83+
assert not self.gradient_evaluated, "Gradients of grad can be evaluated only once."
84+
for func, args in reversed(self.calls):
85+
func.grad(*args)
86+
self.gradient_evaluated = True
87+
88+
89+
def clear_all_gradients():
90+
"""Sets the gradients of all fields to zero.
91+
"""
92+
impl.get_runtime().materialize()
93+
94+
def visit(node):
95+
places = []
96+
for _i in range(node.ptr.get_num_ch()):
97+
ch = node.ptr.get_ch(_i)
98+
if not ch.is_place():
99+
visit(SNode(ch))
100+
else:
101+
if not ch.is_primal():
102+
places.append(ch.get_expr())
103+
104+
places = tuple(places)
105+
if places:
106+
from taichi._kernels import \
107+
clear_gradients # pylint: disable=C0415
108+
clear_gradients(places)
109+
110+
for root_fb in _snode.FieldsBuilder._finalized_roots():
111+
visit(root_fb)
12112

13113

14114
def grad_replaced(func):

0 commit comments

Comments
 (0)