diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 871f092a4a28f..798d00509a0ee 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -488,10 +488,9 @@ class ReverseOuterLoops : public BasicStmtVisitor {
   }
 };
 
-// Generate the adjoint version of an independent block
-
-class MakeAdjoint : public IRVisitor {
- private:
+// Base class for both reverse (make adjoint) and forward (make dual) mode
+class ADTransform : public IRVisitor {
+ protected:
   Stmt *constant(float32 x) {
     return insert<ConstStmt>(TypedConstant(x));
   }
@@ -556,6 +555,107 @@ class MakeAdjoint : public IRVisitor {
   }
 
  public:
+  virtual Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) = 0;
+
+  template <typename T, typename... Args>
+  Stmt *insert(Args &&...args) {
+    return insert_grad_stmt(Stmt::make<T>(args...));
+  }
+
+  void visit(AllocaStmt *alloca) override {
+    // do nothing.
+  }
+
+  void visit(AdStackAllocaStmt *alloca) override {
+    // do nothing.
+  }
+
+  void visit(ArgLoadStmt *stmt) override {
+    // do nothing.
+  }
+
+  void visit(LoopIndexStmt *stmt) override {
+    // do nothing.
+  }
+
+  void visit(PrintStmt *print_stmt) override {
+    // do nothing
+  }
+
+  void visit(ConstStmt *const_stmt) override {
+    // do nothing
+  }
+
+  void visit(WhileControlStmt *stmt) override {
+    TI_NOT_IMPLEMENTED
+  }
+
+  void visit(ContinueStmt *stmt) override {
+    TI_NOT_IMPLEMENTED;
+  }
+
+  void visit(WhileStmt *stmt) override {
+    TI_NOT_IMPLEMENTED
+  }
+
+  void visit(GlobalPtrStmt *stmt) override {
+    // do nothing
+  }
+
+  Stmt *load(Stmt *alloc) {
+    TI_ASSERT(alloc != nullptr);
+    if (alloc->is<AllocaStmt>()) {
+      return insert<LocalLoadStmt>(LocalAddress(alloc, 0));
+    } else {
+      // non alloca
+      return alloc;
+    }
+  }
+
+  bool gradients_stopped(GlobalLoadStmt *stmt, SNode *snode) {
+    for (auto block = stmt->parent; block; block = block->parent_block()) {
+      for (auto s : block->stop_gradients) {
+        if (s == snode) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  void visit(ElementShuffleStmt *stmt) override {
+    TI_NOT_IMPLEMENTED
+  }
+
+  void visit(AssertStmt *stmt) override {
+    // do nothing
+  }
+
+  void visit(RangeAssumptionStmt *stmt) override {
+    // do nothing
+  }
+
+  void visit(LinearizeStmt *stmt) override {
+    // do nothing
+  }
+
+  void visit(BitExtractStmt *stmt) override {
+    // do nothing
+  }
+
+  void visit(IntegerOffsetStmt *stmt) override {
+    // do nothing
+  }
+
+  void visit(RandStmt *stmt) override {
+    TI_ERROR("RandStmt not supported in AutoDiff for now.");
+  }
+};
+
+// Generate the adjoint version of an independent block
+class MakeAdjoint : public ADTransform {
+ public:
+  using ADTransform::visit;
   Block *current_block;
   Block *alloca_block;
   // Backup the forward pass (the forward pass might be modified during the
@@ -593,17 +693,12 @@ class MakeAdjoint : public IRVisitor {
     }
   }
 
-  Stmt *insert_back(std::unique_ptr<Stmt> &&stmt) {
+  Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) override {
     auto ptr = stmt.get();
     current_block->insert(std::move(stmt), -1);
     return ptr;
   }
 
-  template <typename T, typename... Args>
-  Stmt *insert(Args &&...args) {
-    return insert_back(Stmt::make<T>(args...));
-  }
-
   // Accumulate [value] to the adjoint of [primal]
   void accumulate(Stmt *primal, Stmt *value) {
     auto alloca_ = adjoint(primal);
@@ -675,22 +770,6 @@ class MakeAdjoint : public IRVisitor {
     return adjoint_stmt[stmt];
   }
 
-  void visit(AllocaStmt *alloca) override {
-    // do nothing.
-  }
-
-  void visit(AdStackAllocaStmt *alloca) override {
-    // do nothing.
-  }
-
-  void visit(ArgLoadStmt *stmt) override {
-    // do nothing.
-  }
-
-  void visit(LoopIndexStmt *stmt) override {
-    // do nothing.
-  }
-
   void visit(UnaryOpStmt *stmt) override {
     if (stmt->op_type == UnaryOpType::floor ||
         stmt->op_type == UnaryOpType::ceil) {
@@ -827,34 +906,14 @@ class MakeAdjoint : public IRVisitor {
       }
       current_block = old_current_block;
     }
-    insert_back(std::move(new_if));
-  }
-
-  void visit(PrintStmt *print_stmt) override {
-    // do nothing
-  }
-
-  void visit(ConstStmt *const_stmt) override {
-    // do nothing
-  }
-
-  void visit(WhileControlStmt *stmt) override {
-    TI_NOT_IMPLEMENTED
-  }
-
-  void visit(ContinueStmt *stmt) override {
-    TI_NOT_IMPLEMENTED;
-  }
-
-  void visit(WhileStmt *stmt) override {
-    TI_NOT_IMPLEMENTED
+    insert_grad_stmt(std::move(new_if));
   }
 
   void visit(RangeForStmt *for_stmt) override {
     auto new_for = for_stmt->clone();
     auto new_for_ptr = new_for->as<RangeForStmt>();
     new_for_ptr->reversed = !new_for_ptr->reversed;
-    insert_back(std::move(new_for));
+    insert_grad_stmt(std::move(new_for));
 
     const int len = new_for_ptr->body->size();
     for (int i = 0; i < len; i++) {
@@ -889,10 +948,6 @@ class MakeAdjoint : public IRVisitor {
     for_stmt->body->accept(this);
   }
 
-  void visit(GlobalPtrStmt *stmt) override {
-    // do nothing
-  }
-
   // Equivalent to AdStackLoadTopStmt when no stack is needed
   void visit(LocalLoadStmt *stmt) override {
     // TI_ASSERT(!needs_grad(stmt->ret_type));
@@ -999,34 +1054,6 @@ class MakeAdjoint : public IRVisitor {
     }
     stmt->parent->erase(stmt);
   }
-
-  void visit(ElementShuffleStmt *stmt) override {
-    TI_NOT_IMPLEMENTED
-  }
-
-  void visit(AssertStmt *stmt) override {
-    // do nothing
-  }
-
-  void visit(RangeAssumptionStmt *stmt) override {
-    // do nothing
-  }
-
-  void visit(LinearizeStmt *stmt) override {
-    // do nothing
-  }
-
-  void visit(BitExtractStmt *stmt) override {
-    // do nothing
-  }
-
-  void visit(IntegerOffsetStmt *stmt) override {
-    // do nothing
-  }
-
-  void visit(RandStmt *stmt) override {
-    TI_ERROR("RandStmt not supported in AutoDiff for now.");
-  }
 };
 
 class BackupSSA : public BasicStmtVisitor {