from taichi.lang import impl
from taichi.lang.enums import AutodiffMode
from taichi.lang.field import ScalarField
from taichi.lang.snode import SNode

from taichi import _snode


class Tape:
    def __init__(self, loss=None, clear_gradients=True):
| 19 | + """A context manager for reverse mode autodiff :class:`~taichi.ad.Tape`. The |
| 20 | + context manager would catching all of the callings of functions that |
| 21 | + decorated by :func:`~taichi.lang.kernel_impl.kernel` or |
| 22 | + :func:`~taichi.ad.grad_replaced` under `with` statement, and calculate |
| 23 | + all the partial gradients of a given loss variable by calling all of the |
| 24 | + gradient function of the callings caught in reverse order while `with` |
| 25 | + statement ended. |
| 26 | +
|
| 27 | + See also :func:`~taichi.lang.kernel_impl.kernel` and |
| 28 | + :func:`~taichi.ad.grad_replaced` for gradient functions. |
| 29 | +
|
| 30 | + Args: |
| 31 | + loss(:class:`~taichi.lang.expr.Expr`): The loss field, which shape should be (). |
| 32 | + clear_gradients(Bool): Before `with` body start, clear all gradients or not. |
| 33 | +
|
| 34 | + Example:: |
| 35 | +
|
| 36 | + >>> @ti.kernel |
| 37 | + >>> def sum(a: ti.float32): |
| 38 | + >>> for I in ti.grouped(x): |
| 39 | + >>> y[None] += x[I] ** a |
| 40 | + >>> |
| 41 | + >>> with ti.Tape(loss = y): |
| 42 | + >>> sum(2) |
| 43 | + """ |
        self.calls = []
        self.entered = False
        self.gradient_evaluated = False
        self.clear_gradients = clear_gradients
        self.runtime = impl.get_runtime()
        self.eval_on_exit = loss is not None
        self.loss = loss

    def __enter__(self):
        assert not self.entered, "Tape can be entered only once."
        self.entered = True

        impl.get_runtime().materialize()
        if len(self.loss.shape) != 0:
            raise RuntimeError(
                'The loss of `Tape` must be a 0-D field, i.e. scalar')
        if not self.loss.snode.ptr.has_adjoint():
            raise RuntimeError(
                'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)'
                ' for all fields that are required by autodiff.')
        if self.clear_gradients:
            clear_all_gradients()

        from taichi._kernels import clear_loss  # pylint: disable=C0415
        clear_loss(self.loss)

        # Attach the context manager to runtime
        self.runtime.target_tape = self

    def __exit__(self, _type, value, tb):
        # Detach from the runtime; evaluate gradients now if a loss was given.
        self.runtime.target_tape = None
        if self.eval_on_exit:
            self.grad()

    def insert(self, func, args):
        # Record a decorated-function call for later replay in reverse order.
        self.calls.append((func, args))

    def grad(self):
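        """Evaluate the gradients of all recorded calls, in reverse order.

        This runs automatically on exiting the ``with`` block when a ``loss``
        was passed to the constructor; otherwise call it manually after the
        block. Gradients can be evaluated at most once per tape.
        """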
        assert self.entered, "The tape must be entered before evaluating gradients."
        assert not self.gradient_evaluated, "Gradients can be evaluated only once per tape."
        for func, args in reversed(self.calls):
            func.grad(*args)
        self.gradient_evaluated = True


def clear_all_gradients():
    """Sets the gradients of all fields to zero.
| 91 | + """ |
| 92 | + impl.get_runtime().materialize() |
| 93 | + |
| 94 | + def visit(node): |
| 95 | + places = [] |
| 96 | + for _i in range(node.ptr.get_num_ch()): |
| 97 | + ch = node.ptr.get_ch(_i) |
| 98 | + if not ch.is_place(): |
| 99 | + visit(SNode(ch)) |
| 100 | + else: |
| 101 | + if not ch.is_primal(): |
| 102 | + places.append(ch.get_expr()) |
| 103 | + |
| 104 | + places = tuple(places) |
| 105 | + if places: |
| 106 | + from taichi._kernels import \ |
| 107 | + clear_gradients # pylint: disable=C0415 |
| 108 | + clear_gradients(places) |
| 109 | + |
| 110 | + for root_fb in _snode.FieldsBuilder._finalized_roots(): |
| 111 | + visit(root_fb) |


def grad_replaced(func):