Skip to content

Commit 2a0ab39

Browse files
committed
[autodiff] Support control flow for forward mode
ghstack-source-id: 06dfa909c4e0a4410a08693a54da15d35ecf84d8 Pull Request resolved: #5231
1 parent 975e414 commit 2a0ab39

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

taichi/transforms/auto_diff.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -1137,12 +1137,39 @@ class MakeDual : public ADTransform {
11371137
// d (x * y) = y * dx + x * dy
11381138
accumulate(bin, mul(bin->lhs, dual(bin->rhs)));
11391139
accumulate(bin, mul(bin->rhs, dual(bin->lhs)));
1140+
} else if (is_comparison(bin->op_type) || is_bit_op(bin->op_type)) {
1141+
// do nothing
11401142
} else {
11411143
TI_WARN("gradient of binary op {}", binary_op_type_name(bin->op_type));
11421144
TI_NOT_IMPLEMENTED
11431145
}
11441146
}
11451147

1148+
// Forward-mode AD for `if`: differentiate the statements inside both
// branches. The branch condition itself carries no derivative, so only
// the branch bodies are transformed.
void visit(IfStmt *if_stmt) override {
  // Both branches are processed identically; share the traversal logic
  // instead of duplicating it (the original repeated this verbatim).
  auto visit_branch = [this](auto &branch) {
    if (!branch)
      return;
    // Snapshot the raw statement pointers first: visiting a statement may
    // insert newly created statements into the block, which would
    // invalidate iteration over the owning statement list.
    std::vector<Stmt *> statements;
    statements.reserve(branch->statements.size());
    for (auto &stmt : branch->statements) {
      statements.push_back(stmt.get());
    }
    for (auto *stmt : statements) {
      current_stmt = stmt;
      stmt->accept(this);
    }
  };
  visit_branch(if_stmt->true_statements);
  visit_branch(if_stmt->false_statements);
}
1172+
11461173
void visit(RangeForStmt *for_stmt) override {
11471174
std::vector<Stmt *> statements;
11481175
// always make a copy since the list can be modified.

tests/python/test_ad_if_fwd.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from taichi.lang import impl
2+
from taichi.lang.misc import get_host_arch_list
3+
4+
import taichi as ti
5+
from tests import test_utils
6+
7+
8+
@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_ad_if_simple_fwd():
    """Forward-mode AD through a kernel with a single (taken) `if` branch."""
    x = ti.field(ti.f32, shape=())
    y = ti.field(ti.f32, shape=())

    @ti.kernel
    def copy_if_positive():
        # The taken branch forwards x unchanged, so dy/dx == 1 there.
        if x[None] > 0.:
            y[None] = x[None]

    x[None] = 1
    # Seed dx = 1.0 and propagate the dual through the kernel.
    with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0]):
        copy_if_positive()

    assert y.grad[None] == 1
23+
24+
25+
@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_ad_if():
    """Forward-mode AD where different entries take different branches."""
    x = ti.field(ti.f32, shape=2)
    y = ti.field(ti.f32, shape=2)

    @ti.kernel
    def branch_kernel(i: ti.i32):
        # Positive entries forward x (grad 1); others scale it by 2 (grad 2).
        if x[i] > 0:
            y[i] = x[i]
        else:
            y[i] = 2 * x[i]

    x[0] = 0
    x[1] = 1
    with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0, 1.0]):
        branch_kernel(0)
        branch_kernel(1)

    # x[0] == 0 takes the else-branch; x[1] == 1 takes the if-branch.
    assert y.grad[0] == 2
    assert y.grad[1] == 1
44+
45+
46+
@test_utils.test(arch=[ti.cpu, ti.gpu])
def test_ad_if_nested():
    """Forward-mode AD through two levels of nested `if`/`else`."""
    n = 20
    x = ti.field(ti.f32, shape=n)
    y = ti.field(ti.f32, shape=n)
    z = ti.field(ti.f32, shape=n)

    ti.root.lazy_grad()

    @ti.kernel
    def nested_branches():
        for i in x:
            # Effectively y[i] = z[i] * (x[i] % 4), selected via nesting,
            # so dy/dz[i] == x[i] % 4 in forward mode.
            if x[i] < 2:
                if x[i] == 0:
                    y[i] = 0
                else:
                    y[i] = z[i] * 1
            else:
                if x[i] == 2:
                    y[i] = z[i] * 2
                else:
                    y[i] = z[i] * 3

    z.fill(1)
    for i in range(n):
        x[i] = i % 4

    # Sanity-check the primal values before differentiating.
    nested_branches()
    for i in range(n):
        assert y[i] == i % 4

    with ti.ad.FwdMode(loss=y, parameters=z, seed=[1.0] * n):
        nested_branches()

    for i in range(n):
        assert y.grad[i] == i % 4

0 commit comments

Comments
 (0)