From fb4d55b1dcd41c463563cb90af96a341af09496b Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 2 Jun 2022 12:31:39 +0800 Subject: [PATCH 1/4] extract shared components for reverse and forward mode --- taichi/transforms/auto_diff.cpp | 251 ++++++++++++++++++-------------- 1 file changed, 139 insertions(+), 112 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 871f092a4a28f..31fb5b6f8f04f 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -488,74 +488,174 @@ class ReverseOuterLoops : public BasicStmtVisitor { } }; -// Generate the adjoint version of an independent block +// Base class for both reverse (make adjoint) and forward (make dual) mode autodiff +class ADTransform : public IRVisitor{ + protected: + Stmt *constant(float32 x) { + return insert(TypedConstant(x)); + } -class MakeAdjoint : public IRVisitor { - private: - Stmt *constant(float32 x) { - return insert(TypedConstant(x)); + // utils + Stmt *sgn(Stmt *inp) { + return insert(UnaryOpType::sgn, load(inp)); + } + + // utils + Stmt *negate(Stmt *inp) { + return insert(UnaryOpType::neg, load(inp)); + } + + Stmt *sqrt(Stmt *inp) { + return insert(UnaryOpType::sqrt, load(inp)); + } + + Stmt *mul(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::mul, load(op1), load(op2)); + } + + Stmt *sqr(Stmt *op1) { + return mul(op1, op1); + } + + Stmt *add(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::add, load(op1), load(op2)); + } + + Stmt *cmp_lt(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::cmp_lt, load(op1), load(op2)); + } + + Stmt *sub(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::sub, load(op1), load(op2)); + } + + Stmt *div(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::div, load(op1), load(op2)); + } + + Stmt *sel(Stmt *op1, Stmt *op2, Stmt *op3) { + return insert(TernaryOpType::select, load(op1), load(op2), + load(op3)); + } + + Stmt *cos(Stmt *op1) { + return insert(UnaryOpType::cos, load(op1)); + } + + Stmt *sin(Stmt *op1) { + return insert(UnaryOpType::sin, load(op1)); + } + + Stmt *log(Stmt *op1) { + return insert(UnaryOpType::log, load(op1)); + } + + Stmt *pow(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::pow, load(op1), load(op2)); + } + +public: + virtual Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) = 0; + + template + Stmt *insert(Args &&...args) { + return insert_ad_transform_stmt(Stmt::make(args...)); } - // utils - Stmt *sgn(Stmt *inp) { - return insert(UnaryOpType::sgn, load(inp)); + void visit(AllocaStmt *alloca) override { + // do nothing. } - // utils - Stmt *negate(Stmt *inp) { - return insert(UnaryOpType::neg, load(inp)); + void visit(AdStackAllocaStmt *alloca) override { + // do nothing. } - Stmt *sqrt(Stmt *inp) { - return insert(UnaryOpType::sqrt, load(inp)); + void visit(ArgLoadStmt *stmt) override { + // do nothing. } - Stmt *mul(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::mul, load(op1), load(op2)); + void visit(LoopIndexStmt *stmt) override { + // do nothing. } - Stmt *sqr(Stmt *op1) { - return mul(op1, op1); + void visit(PrintStmt *print_stmt) override { + // do nothing } - Stmt *add(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::add, load(op1), load(op2)); + void visit(ConstStmt *const_stmt) override { + // do nothing } - Stmt *cmp_lt(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::cmp_lt, load(op1), load(op2)); + void visit(WhileControlStmt *stmt) override { + TI_NOT_IMPLEMENTED } - Stmt *sub(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::sub, load(op1), load(op2)); + void visit(ContinueStmt *stmt) override { + TI_NOT_IMPLEMENTED; + } + + void visit(WhileStmt *stmt) override { + TI_NOT_IMPLEMENTED } - Stmt *div(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::div, load(op1), load(op2)); + void visit(GlobalPtrStmt *stmt) override { + // do nothing + } + + Stmt *load(Stmt *alloc) { + TI_ASSERT(alloc != nullptr); + if (alloc->is()) { + return insert(LocalAddress(alloc, 0)); + } else { + // non alloca + return alloc; + } } - Stmt *sel(Stmt *op1, Stmt *op2, Stmt *op3) { - return insert(TernaryOpType::select, load(op1), load(op2), - load(op3)); + bool gradients_stopped(GlobalLoadStmt *stmt, SNode *snode) { + for (auto block = stmt->parent; block; block = block->parent_block()) { + for (auto s : block->stop_gradients) { + if (s == snode) { + return true; + } + } + } + return false; } - Stmt *cos(Stmt *op1) { - return insert(UnaryOpType::cos, load(op1)); + void visit(ElementShuffleStmt *stmt) override { + TI_NOT_IMPLEMENTED } - Stmt *sin(Stmt *op1) { - return insert(UnaryOpType::sin, load(op1)); + void visit(AssertStmt *stmt) override { + // do nothing } - Stmt *log(Stmt *op1) { - return insert(UnaryOpType::log, load(op1)); + void visit(RangeAssumptionStmt *stmt) override { + // do nothing } - Stmt *pow(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::pow, load(op1), load(op2)); + void visit(LinearizeStmt *stmt) override { + // do nothing } + void visit(BitExtractStmt *stmt) override { + // do nothing + } + + void visit(IntegerOffsetStmt *stmt) override { + // do nothing + } + + void visit(RandStmt *stmt) override { + TI_ERROR("RandStmt not supported in AutoDiff for now."); + } +}; + +// Generate the adjoint version of an independent block +class MakeAdjoint : public ADTransform { public: + using ADTransform::visit; Block *current_block; Block *alloca_block; // Backup the forward pass (the forward pass might be modified during the @@ -593,17 +693,12 @@ class MakeAdjoint : public IRVisitor { } } - Stmt *insert_back(std::unique_ptr &&stmt) { + Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) override{ auto ptr = stmt.get(); current_block->insert(std::move(stmt), -1); return ptr; } - template - Stmt *insert(Args &&...args) { - return insert_back(Stmt::make(args...)); - } - // Accumulate [value] to the adjoint of [primal] void accumulate(Stmt *primal, Stmt *value) { auto alloca_ = adjoint(primal); @@ -675,22 +770,6 @@ class MakeAdjoint : public IRVisitor { return adjoint_stmt[stmt]; } - void visit(AllocaStmt *alloca) override { - // do nothing. - } - - void visit(AdStackAllocaStmt *alloca) override { - // do nothing. - } - - void visit(ArgLoadStmt *stmt) override { - // do nothing. - } - - void visit(LoopIndexStmt *stmt) override { - // do nothing. - } - void visit(UnaryOpStmt *stmt) override { if (stmt->op_type == UnaryOpType::floor || stmt->op_type == UnaryOpType::ceil) { @@ -827,34 +906,14 @@ class MakeAdjoint : public IRVisitor { } current_block = old_current_block; } - insert_back(std::move(new_if)); - } - - void visit(PrintStmt *print_stmt) override { - // do nothing - } - - void visit(ConstStmt *const_stmt) override { - // do nothing - } - - void visit(WhileControlStmt *stmt) override { - TI_NOT_IMPLEMENTED - } - - void visit(ContinueStmt *stmt) override { - TI_NOT_IMPLEMENTED; - } - - void visit(WhileStmt *stmt) override { - TI_NOT_IMPLEMENTED + insert_ad_transform_stmt(std::move(new_if)); } void visit(RangeForStmt *for_stmt) override { auto new_for = for_stmt->clone(); auto new_for_ptr = new_for->as(); new_for_ptr->reversed = !new_for_ptr->reversed; - insert_back(std::move(new_for)); + insert_ad_transform_stmt(std::move(new_for)); const int len = new_for_ptr->body->size(); for (int i = 0; i < len; i++) { @@ -889,10 +948,6 @@ class MakeAdjoint : public IRVisitor { for_stmt->body->accept(this); } - void visit(GlobalPtrStmt *stmt) override { - // do nothing - } - // Equivalent to AdStackLoadTopStmt when no stack is needed void visit(LocalLoadStmt *stmt) override { // TI_ASSERT(!needs_grad(stmt->ret_type)); @@ -999,34 +1054,6 @@ class MakeAdjoint : public IRVisitor { } stmt->parent->erase(stmt); } - - void visit(ElementShuffleStmt *stmt) override { - TI_NOT_IMPLEMENTED - } - - void visit(AssertStmt *stmt) override { - // do nothing - } - - void visit(RangeAssumptionStmt *stmt) override { - // do nothing - } - - void visit(LinearizeStmt *stmt) override { - // do nothing - } - - void visit(BitExtractStmt *stmt) override { - // do nothing - } - - void visit(IntegerOffsetStmt *stmt) override { - // do nothing - } - - void visit(RandStmt *stmt) override { - TI_ERROR("RandStmt not supported in AutoDiff for now."); - } }; class BackupSSA : public BasicStmtVisitor { From bafb142f6a0705c2113cb78500e1dae4ee3815f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jun 2022 04:34:31 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/transforms/auto_diff.cpp | 107 ++++++++++++++++---------------- 1 file changed, 54 insertions(+), 53 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 31fb5b6f8f04f..51c196f648b23 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -488,73 +488,74 @@ class ReverseOuterLoops : public BasicStmtVisitor { } }; -// Base class for both reverse (make adjoint) and forward (make dual) mode autodiff -class ADTransform : public IRVisitor{ - protected: - Stmt *constant(float32 x) { - return insert(TypedConstant(x)); - } +// Base class for both reverse (make adjoint) and forward (make dual) mode +// autodiff +class ADTransform : public IRVisitor { + protected: + Stmt *constant(float32 x) { + return insert(TypedConstant(x)); + } - // utils - Stmt *sgn(Stmt *inp) { - return insert(UnaryOpType::sgn, load(inp)); - } + // utils + Stmt *sgn(Stmt *inp) { + return insert(UnaryOpType::sgn, load(inp)); + } - // utils - Stmt *negate(Stmt *inp) { - return insert(UnaryOpType::neg, load(inp)); - } + // utils + Stmt *negate(Stmt *inp) { + return insert(UnaryOpType::neg, load(inp)); + } - Stmt *sqrt(Stmt *inp) { - return insert(UnaryOpType::sqrt, load(inp)); - } + Stmt *sqrt(Stmt *inp) { + return insert(UnaryOpType::sqrt, load(inp)); + } - Stmt *mul(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::mul, load(op1), load(op2)); - } + Stmt *mul(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::mul, load(op1), load(op2)); + } - Stmt *sqr(Stmt *op1) { - return mul(op1, op1); - } + Stmt *sqr(Stmt *op1) { + return mul(op1, op1); + } - Stmt *add(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::add, load(op1), load(op2)); - } + Stmt *add(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::add, load(op1), load(op2)); + } - Stmt *cmp_lt(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::cmp_lt, load(op1), load(op2)); - } + Stmt *cmp_lt(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::cmp_lt, load(op1), load(op2)); + } - Stmt *sub(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::sub, load(op1), load(op2)); - } + Stmt *sub(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::sub, load(op1), load(op2)); + } - Stmt *div(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::div, load(op1), load(op2)); - } + Stmt *div(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::div, load(op1), load(op2)); + } - Stmt *sel(Stmt *op1, Stmt *op2, Stmt *op3) { - return insert(TernaryOpType::select, load(op1), load(op2), - load(op3)); - } + Stmt *sel(Stmt *op1, Stmt *op2, Stmt *op3) { + return insert(TernaryOpType::select, load(op1), load(op2), + load(op3)); + } - Stmt *cos(Stmt *op1) { - return insert(UnaryOpType::cos, load(op1)); - } + Stmt *cos(Stmt *op1) { + return insert(UnaryOpType::cos, load(op1)); + } - Stmt *sin(Stmt *op1) { - return insert(UnaryOpType::sin, load(op1)); - } + Stmt *sin(Stmt *op1) { + return insert(UnaryOpType::sin, load(op1)); + } - Stmt *log(Stmt *op1) { - return insert(UnaryOpType::log, load(op1)); - } + Stmt *log(Stmt *op1) { + return insert(UnaryOpType::log, load(op1)); + } - Stmt *pow(Stmt *op1, Stmt *op2) { - return insert(BinaryOpType::pow, load(op1), load(op2)); - } + Stmt *pow(Stmt *op1, Stmt *op2) { + return insert(BinaryOpType::pow, load(op1), load(op2)); + } -public: + public: virtual Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) = 0; template @@ -693,7 +694,7 @@ class MakeAdjoint : public ADTransform { } } - Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) override{ + Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) override { auto ptr = stmt.get(); current_block->insert(std::move(stmt), -1); return ptr; From 0a01554efa9230f20e6b76116f23b0f1235581b8 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 2 Jun 2022 12:45:06 +0800 Subject: [PATCH 3/4] update --- taichi/transforms/auto_diff.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 51c196f648b23..764b0c2de1d51 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -489,7 +489,6 @@ class ReverseOuterLoops : public BasicStmtVisitor { }; // Base class for both reverse (make adjoint) and forward (make dual) mode -// autodiff class ADTransform : public IRVisitor { protected: Stmt *constant(float32 x) { From cc5cd3f002a3125771381ccc222dca0f6d011b18 Mon Sep 17 00:00:00 2001 From: mingrui Date: Thu, 2 Jun 2022 19:50:24 +0800 Subject: [PATCH 4/4] change the name of the insert function --- taichi/transforms/auto_diff.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 764b0c2de1d51..798d00509a0ee 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -555,11 +555,11 @@ class ADTransform : public IRVisitor { } public: - virtual Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) = 0; + virtual Stmt *insert_grad_stmt(std::unique_ptr &&stmt) = 0; template Stmt *insert(Args &&...args) { - return insert_ad_transform_stmt(Stmt::make(args...)); + return insert_grad_stmt(Stmt::make(args...)); } void visit(AllocaStmt *alloca) override { @@ -693,7 +693,7 @@ class MakeAdjoint : public ADTransform { } } - Stmt *insert_ad_transform_stmt(std::unique_ptr &&stmt) override { + Stmt *insert_grad_stmt(std::unique_ptr &&stmt) override { auto ptr = stmt.get(); current_block->insert(std::move(stmt), -1); return ptr; @@ -906,14 +906,14 @@ class MakeAdjoint : public ADTransform { } current_block = old_current_block; } - insert_ad_transform_stmt(std::move(new_if)); + insert_grad_stmt(std::move(new_if)); } void visit(RangeForStmt *for_stmt) override { auto new_for = for_stmt->clone(); auto new_for_ptr = new_for->as(); new_for_ptr->reversed = !new_for_ptr->reversed; - insert_ad_transform_stmt(std::move(new_for)); + insert_grad_stmt(std::move(new_for)); const int len = new_for_ptr->body->size(); for (int i = 0; i < len; i++) {