[autodiff] Add forward mode pipeline for autodiff pass #5098

Merged (19 commits) on Jun 9, 2022
Changes from 6 commits
9 changes: 5 additions & 4 deletions cpp_examples/autograd.cpp
@@ -118,7 +118,7 @@ void autograd() {
std::make_unique<Kernel>(program, builder.extract_ir(), "init");
}

auto get_kernel_cal = [&](bool grad) -> Kernel * {
auto get_kernel_cal = [&](AutodiffMode autodiff_mode) -> Kernel * {
IRBuilder builder;
auto *loop = builder.create_struct_for(a, 0, 4);
{
@@ -132,10 +132,11 @@ void autograd() {
std::make_unique<AtomicOpStmt>(AtomicOpType::add, c_i, val));
}

return new Kernel(program, builder.extract_ir(), "cal", grad);
return new Kernel(program, builder.extract_ir(), "cal", autodiff_mode);
};
kernel_forward = std::unique_ptr<Kernel>(get_kernel_cal(false));
kernel_backward = std::unique_ptr<Kernel>(get_kernel_cal(true));
kernel_forward = std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kNone));
kernel_backward =
std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kReverseWithStack));

{
IRBuilder builder;
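The driver in cpp_examples now states the autodiff pipeline per kernel with an `AutodiffMode` value instead of a `bool grad`. Below is a minimal, self-contained sketch of the same call-site pattern; the `Kernel` type here is a simplified stand-in, not the real taichi::lang::Kernel.

```cpp
#include <memory>
#include <string>

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Simplified stand-in used only to show the call shape after this change.
struct Kernel {
  std::string name;
  AutodiffMode autodiff_mode;
  Kernel(std::string primal_name, AutodiffMode mode)
      : name(std::move(primal_name)), autodiff_mode(mode) {}
};

int main() {
  // The factory takes the mode directly, so the caller's intent is explicit.
  auto get_kernel_cal = [](AutodiffMode autodiff_mode) -> Kernel * {
    return new Kernel("cal", autodiff_mode);
  };
  auto kernel_forward =
      std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kNone));
  auto kernel_backward =
      std::unique_ptr<Kernel>(get_kernel_cal(AutodiffMode::kReverseWithStack));
  return 0;
}
```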
3 changes: 2 additions & 1 deletion python/taichi/lang/enums.py
@@ -1,5 +1,6 @@
from taichi._lib import core as _ti_core

Layout = _ti_core.Layout
AutodiffMode = _ti_core.AutodiffMode

__all__ = ['Layout']
__all__ = ['Layout', 'AutodiffMode']
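On the Python side, `AutodiffMode` is simply re-exported from `_ti_core`, so names such as `AutodiffMode.NONE` and `AutodiffMode.REVERSE_WITH_STACK` come from the C++ enum's binding. That binding is not part of the hunks shown here; the sketch below is a plausible pybind11 shape for it (the module name and value names are assumptions, and the enum is redefined locally so the example stands alone).

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Hypothetical module; the real export lives elsewhere in the Taichi sources.
PYBIND11_MODULE(example_core, m) {
  py::enum_<AutodiffMode>(m, "AutodiffMode")
      .value("NONE", AutodiffMode::kNone)
      .value("FORWARD", AutodiffMode::kForward)
      .value("REVERSE_WITH_STACK", AutodiffMode::kReverseWithStack)
      .value("REVERSE_WITHOUT_STACK", AutodiffMode::kReverseWithoutStack);
}
```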
30 changes: 18 additions & 12 deletions python/taichi/lang/kernel_impl.py
@@ -12,7 +12,7 @@
from taichi.lang.ast import (ASTTransformerContext, KernelSimplicityASTChecker,
transform_tree)
from taichi.lang.ast.ast_transformer_utils import ReturnStatus
from taichi.lang.enums import Layout
from taichi.lang.enums import AutodiffMode, Layout
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiRuntimeTypeError, TaichiSyntaxError,
handle_exception_from_cpp)
@@ -210,7 +210,8 @@ def __call__(self, *args, **kwargs):
return self.func(*args)

if self.is_real_function:
if impl.get_runtime().current_kernel.is_grad:
if impl.get_runtime(
).current_kernel.autodiff_mode != AutodiffMode.NONE:
raise TaichiSyntaxError(
"Real function in gradient kernels unsupported.")
instance_id, _ = self.mapper.lookup(args)
@@ -400,11 +401,11 @@ def _get_global_vars(_func):
class Kernel:
counter = 0

def __init__(self, _func, is_grad, _classkernel=False):
def __init__(self, _func, autodiff_mode, _classkernel=False):
self.func = _func
self.kernel_counter = Kernel.counter
Kernel.counter += 1
self.is_grad = is_grad
self.autodiff_mode = autodiff_mode
self.grad = None
self.arguments = []
self.return_type = None
@@ -422,7 +423,7 @@ def __init__(self, _func, is_grad, _classkernel=False):

def reset(self):
self.runtime = impl.get_runtime()
if self.is_grad:
if self.autodiff_mode != AutodiffMode.NONE:
self.compiled_functions = self.runtime.compiled_grad_functions
else:
self.compiled_functions = self.runtime.compiled_functions
@@ -485,7 +486,7 @@ def materialize(self, key=None, args=None, arg_features=None):
if key in self.compiled_functions:
return
grad_suffix = ""
if self.is_grad:
if self.autodiff_mode != AutodiffMode.NONE:
grad_suffix = "_grad"
kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}"
_logging.trace(f"Compiling kernel {kernel_name}...")
@@ -496,7 +497,7 @@
excluded_parameters=self.template_slot_locations,
arg_features=arg_features)

if self.is_grad:
if self.autodiff_mode != AutodiffMode.NONE:
KernelSimplicityASTChecker(self.func).visit(tree)

if impl.current_cfg().use_mesh:
@@ -526,7 +527,7 @@ def taichi_ast_generator(kernel_cxx):
self.runtime.current_kernel = None

taichi_kernel = impl.get_runtime().prog.create_kernel(
taichi_ast_generator, kernel_name, self.is_grad)
taichi_ast_generator, kernel_name, self.autodiff_mode)

self.kernel_cpp = taichi_kernel

@@ -714,7 +715,7 @@ def func__(*args):
# Both the class kernels and the plain-function kernels are unified now.
# In both cases, |self.grad| is another Kernel instance that computes the
# gradient. For class kernels, args[0] is always the kernel owner.
if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced:
if self.autodiff_mode == AutodiffMode.NONE and self.runtime.target_tape and not self.runtime.grad_replaced:
self.runtime.target_tape.insert(self, args)

if actual_argument_slot > 8 and (
@@ -786,7 +787,8 @@ def ensure_compiled(self, *args):
@_shell_pop_print
def __call__(self, *args, **kwargs):
args = _process_args(self, args, kwargs)
if self.is_grad and impl.current_cfg().opt_level == 0:
if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg(
).opt_level == 0:
_logging.warn(
"""opt_level = 1 is enforced to enable gradient computation."""
)
@@ -834,8 +836,12 @@ def _kernel_impl(_func, level_of_class_stackframe, verbose=False):

if verbose:
print(f'kernel={_func.__name__} is_classkernel={is_classkernel}')
primal = Kernel(_func, is_grad=False, _classkernel=is_classkernel)
adjoint = Kernel(_func, is_grad=True, _classkernel=is_classkernel)
primal = Kernel(_func,
autodiff_mode=AutodiffMode.NONE,
_classkernel=is_classkernel)
adjoint = Kernel(_func,
autodiff_mode=AutodiffMode.REVERSE_WITH_STACK,
_classkernel=is_classkernel)
# Having |primal| contains |grad| makes the tape work.
primal.grad = adjoint

3 changes: 2 additions & 1 deletion taichi/analysis/offline_cache_util.cpp
@@ -181,7 +181,8 @@ std::string get_hashed_offline_cache_key(CompileConfig *config,
hasher.finish();

auto res = picosha2::get_hash_hex_string(hasher);
res.insert(res.begin(), kernel->grad ? 'g' : 'n');
res.insert(res.begin(),
kernel->autodiff_mode != AutodiffMode::kNone ? 'g' : 'n');
return res;
}

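The offline-cache key keeps its one-character prefix: 'g' for any autodiff kernel and 'n' for a primal kernel, regardless of which autodiff mode is requested. A standalone sketch of that rule (stand-in types, not the real hashing code):

```cpp
#include <iostream>
#include <string>

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Mirrors the prefix rule above: 'g' for any autodiff kernel, 'n' otherwise.
std::string tag_cache_key(std::string hash_hex, AutodiffMode mode) {
  hash_hex.insert(hash_hex.begin(), mode != AutodiffMode::kNone ? 'g' : 'n');
  return hash_hex;
}

int main() {
  std::cout << tag_cache_key("deadbeef", AutodiffMode::kNone) << "\n";     // ndeadbeef
  std::cout << tag_cache_key("deadbeef", AutodiffMode::kForward) << "\n";  // gdeadbeef
  return 0;
}
```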
5 changes: 3 additions & 2 deletions taichi/backends/cc/codegen_cc.cpp
@@ -48,8 +48,9 @@ class CCTransformer : public IRVisitor {
auto ir = kernel_->ir.get();
auto config = kernel_->program->config;
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(ir, config, kernel_, kernel_->grad,
/*ad_use_stack=*/true, config.print_ir,
irpass::compile_to_executable(ir, config, kernel_,
/*autodiff_mode=*/kernel_->autodiff_mode,
config.print_ir,
/*lower_global_access*/ true);
}

8 changes: 6 additions & 2 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -1215,8 +1215,12 @@ void OpenglCodeGen::lower() {
auto ir = kernel_->ir.get();
auto &config = kernel_->program->config;
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(ir, config, kernel_, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir,
if (kernel_->autodiff_mode == AutodiffMode::kReverseWithStack) {
kernel_->autodiff_mode = AutodiffMode::kReverseWithoutStack;
}
irpass::compile_to_executable(ir, config, kernel_,
/*autodiff_mode=*/kernel_->autodiff_mode,
config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/config.make_thread_local);
#ifdef _GLSL_DEBUG
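Where the OpenGL backend used to pass `/*ad_use_stack=*/false`, it now normalizes the requested mode before lowering: a stack-based reverse request is downgraded to the stack-free variant, matching the old `ad_use_stack=false` behavior. A minimal sketch of that normalization (stand-in enum only):

```cpp
enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// The OpenGL backend falls back to the stack-free reverse pipeline;
// every other mode passes through unchanged.
AutodiffMode normalize_for_opengl(AutodiffMode mode) {
  if (mode == AutodiffMode::kReverseWithStack)
    return AutodiffMode::kReverseWithoutStack;
  return mode;
}
```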
4 changes: 2 additions & 2 deletions taichi/codegen/spirv/spirv_codegen.cpp
@@ -2197,8 +2197,8 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs,
void lower(Kernel *kernel) {
auto &config = kernel->program->config;
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel->ir.get(), config, kernel, kernel->grad,
/*ad_use_stack=*/false, config.print_ir,
irpass::compile_to_executable(kernel->ir.get(), config, kernel,
kernel->autodiff_mode, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/false);
}
7 changes: 7 additions & 0 deletions taichi/inc/constants.h
@@ -48,3 +48,10 @@ T taichi_union_cast(G g) {
}

enum class ExternalArrayLayout { kAOS, kSOA, kNull };

enum class AutodiffMode {
kForward,
kReverseWithStack,
kReverseWithoutStack,
kNone
};
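This enum replaces the `grad`/`ad_use_stack` boolean pair that was previously threaded through the compilation entry points. A sketch of how the legacy flag combinations map onto the new values (this helper is illustrative only and does not exist in the codebase):

```cpp
enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Illustrative mapping from the old (grad, ad_use_stack) flag pair.
// kForward has no legacy equivalent; it is the new mode this PR introduces.
AutodiffMode from_legacy_flags(bool grad, bool ad_use_stack) {
  if (!grad)
    return AutodiffMode::kNone;
  return ad_use_stack ? AutodiffMode::kReverseWithStack
                      : AutodiffMode::kReverseWithoutStack;
}
```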
10 changes: 4 additions & 6 deletions taichi/ir/transforms.h
@@ -80,7 +80,7 @@ bool lower_access(IRNode *root,
const LowerAccessPass::Args &args);
void auto_diff(IRNode *root,
const CompileConfig &config,
bool use_stack = false);
AutodiffMode autodiffMode);
/**
* Determine all adaptive AD-stacks' size. This pass is idempotent, i.e.,
* there are no side effects if called more than once or called when not needed.
@@ -147,8 +147,7 @@ void compile_to_offloads(IRNode *ir,
const CompileConfig &config,
Kernel *kernel,
bool verbose,
bool grad,
bool ad_use_stack,
AutodiffMode autodiff_mode,
bool start_from_ast);

void offload_to_executable(IRNode *ir,
@@ -164,8 +163,7 @@ void offload_to_executable(IRNode *ir,
void compile_to_executable(IRNode *ir,
const CompileConfig &config,
Kernel *kernel,
bool grad,
bool ad_use_stack,
AutodiffMode autodiff_mode,
bool verbose,
bool lower_global_access = true,
bool make_thread_local = false,
@@ -176,7 +174,7 @@ void compile_to_executable(IRNode *ir,
void compile_function(IRNode *ir,
const CompileConfig &config,
Function *func,
bool grad,
AutodiffMode autodiff_mode,
bool verbose,
bool start_from_ast);
} // namespace irpass
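With these header changes, `auto_diff`, `compile_to_offloads`, `compile_to_executable`, and `compile_function` each take a single `AutodiffMode` argument instead of separate `grad` and `ad_use_stack` booleans. The sketch below shows the assumed dispatch on that value; it is not the actual pass body.

```cpp
#include <string>

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Assumed dispatch shape only: one value now selects between no autodiff,
// the new forward-mode pipeline, and the two reverse-mode variants.
std::string describe_autodiff_pipeline(AutodiffMode mode) {
  switch (mode) {
    case AutodiffMode::kNone:
      return "primal kernel, no AD transform";
    case AutodiffMode::kForward:
      return "forward-mode (tangent) transform added by this PR";
    case AutodiffMode::kReverseWithStack:
      return "reverse-mode (adjoint) transform using the AD stack";
    case AutodiffMode::kReverseWithoutStack:
      return "reverse-mode (adjoint) transform without the AD stack";
  }
  return "unknown mode";
}
```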
4 changes: 2 additions & 2 deletions taichi/program/function.cpp
@@ -19,15 +19,15 @@ void Function::set_function_body(const std::function<void()> &func) {
func();
}
irpass::compile_function(ir.get(), program->config, this,
/*grad=*/false,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->config.print_ir,
/*start_from_ast=*/true);
}

void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
ir = std::move(func_body);
irpass::compile_function(ir.get(), program->config, this,
/*grad=*/false,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->config.print_ir,
/*start_from_ast=*/false);
}
44 changes: 25 additions & 19 deletions taichi/program/kernel.cpp
@@ -23,22 +23,22 @@ class Function;
Kernel::Kernel(Program &program,
const std::function<void()> &func,
const std::string &primal_name,
bool grad) {
this->init(program, func, primal_name, grad);
AutodiffMode autodiff_mode) {
this->init(program, func, primal_name, autodiff_mode);
}

Kernel::Kernel(Program &program,
const std::function<void(Kernel *)> &func,
const std::string &primal_name,
bool grad) {
this->init(program, std::bind(func, this), primal_name, grad);
AutodiffMode autodiff_mode) {
this->init(program, std::bind(func, this), primal_name, autodiff_mode);
}

Kernel::Kernel(Program &program,
std::unique_ptr<IRNode> &&ir,
const std::string &primal_name,
bool grad)
: grad(grad), lowered_(false) {
AutodiffMode autodiff_mode)
: autodiff_mode(autodiff_mode), lowered_(false) {
this->ir = std::move(ir);
this->program = &program;
is_accessor = false;
@@ -49,10 +49,13 @@ Kernel::Kernel(Program &program,

arch = program.config.arch;

if (!grad) {
if (autodiff_mode == AutodiffMode::kNone) {
name = primal_name;
} else {
name = primal_name + "_grad";
} else if (autodiff_mode == AutodiffMode::kForward) {
name = primal_name + "_forward_grad";
} else if (autodiff_mode == AutodiffMode::kReverseWithStack ||
autodiff_mode == AutodiffMode::kReverseWithoutStack) {
name = primal_name + "_reverse_grad";
}

if (!program.config.lazy_compilation)
@@ -89,16 +92,16 @@ void Kernel::lower(bool to_executable) {

if (to_executable) {
irpass::compile_to_executable(
ir.get(), config, this, grad,
/*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable,
ir.get(), config, this, /*autodiff_mode=*/autodiff_mode, verbose,
/*lower_global_access=*/to_executable,
/*make_thread_local=*/config.make_thread_local,
/*make_block_local=*/
is_extension_supported(config.arch, Extension::bls) &&
config.make_block_local,
/*start_from_ast=*/ir_is_ast_);
} else {
irpass::compile_to_offloads(ir.get(), config, this, verbose, grad,
/*ad_use_stack=*/true,
irpass::compile_to_offloads(ir.get(), config, this, verbose,
/*autodiff_mode=*/autodiff_mode,
/*start_from_ast=*/ir_is_ast_);
}

@@ -406,8 +409,8 @@ std::string Kernel::get_name() const {
void Kernel::init(Program &program,
const std::function<void()> &func,
const std::string &primal_name,
bool grad) {
this->grad = grad;
AutodiffMode autodiff_mode) {
this->autodiff_mode = autodiff_mode;
this->lowered_ = false;
this->program = &program;
#ifdef TI_WITH_LLVM
@@ -424,10 +427,13 @@ void Kernel::init(Program &program,

this->arch = program.config.arch;

if (!grad) {
this->name = primal_name;
} else {
this->name = primal_name + "_grad";
if (autodiff_mode == AutodiffMode::kNone) {
name = primal_name;
} else if (autodiff_mode == AutodiffMode::kForward) {
name = primal_name + "_forward_grad";
} else if (autodiff_mode == AutodiffMode::kReverseWithStack ||
autodiff_mode == AutodiffMode::kReverseWithoutStack) {
name = primal_name + "_reverse_grad";
}

{
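Kernel naming now distinguishes forward from reverse gradients: the primal keeps its name, forward-mode kernels get a `_forward_grad` suffix, and both reverse variants share `_reverse_grad`. The same branch appears in both the IR-based constructor and `Kernel::init`; an equivalent standalone helper (illustrative only, the real code inlines the logic):

```cpp
#include <string>

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Equivalent to the naming branches above, pulled into one place.
std::string kernel_name_for(const std::string &primal_name, AutodiffMode mode) {
  if (mode == AutodiffMode::kNone)
    return primal_name;
  if (mode == AutodiffMode::kForward)
    return primal_name + "_forward_grad";
  // kReverseWithStack and kReverseWithoutStack share one suffix.
  return primal_name + "_reverse_grad";
}
```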
10 changes: 5 additions & 5 deletions taichi/program/kernel.h
@@ -20,7 +20,7 @@ class TI_DLL_EXPORT Kernel : public Callable {

bool is_accessor{false};
bool is_evaluator{false};
bool grad{false};
AutodiffMode autodiff_mode{AutodiffMode::kNone};

class LaunchContextBuilder {
public:
@@ -69,17 +69,17 @@ class TI_DLL_EXPORT Kernel : public Callable {
Kernel(Program &program,
const std::function<void()> &func,
const std::string &name = "",
bool grad = false);
AutodiffMode autodiff_mode = AutodiffMode::kNone);

Kernel(Program &program,
const std::function<void(Kernel *)> &func,
const std::string &name = "",
bool grad = false);
AutodiffMode autodiff_mode = AutodiffMode::kNone);

Kernel(Program &program,
std::unique_ptr<IRNode> &&ir,
const std::string &name = "",
bool grad = false);
AutodiffMode autodiff_mode = AutodiffMode::kNone);

bool lowered() const {
return lowered_;
@@ -136,7 +136,7 @@ void init(Program &program,
void init(Program &program,
const std::function<void()> &func,
const std::string &name = "",
bool grad = false);
AutodiffMode autodiff_mode = AutodiffMode::kNone);

// True if |ir| is a frontend AST. False if it's already offloaded to CHI IR.
bool ir_is_ast_{false};
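All three constructors and `init` default the new parameter to `AutodiffMode::kNone`, so existing call sites that previously omitted `grad` keep compiling unchanged. A call-shape sketch with a simplified stand-in type (not the real `Kernel`):

```cpp
#include <functional>
#include <string>

enum class AutodiffMode { kForward, kReverseWithStack, kReverseWithoutStack, kNone };

// Stand-in with the same default-argument shape as the constructors above.
struct KernelLike {
  KernelLike(const std::function<void()> &func,
             const std::string &name = "",
             AutodiffMode autodiff_mode = AutodiffMode::kNone) {
    (void)func;
    (void)name;
    (void)autodiff_mode;
  }
};

int main() {
  auto body = [] {};
  KernelLike primal(body, "cal");                           // defaults to kNone
  KernelLike forward(body, "cal", AutodiffMode::kForward);  // opts into forward mode
  KernelLike reverse(body, "cal", AutodiffMode::kReverseWithStack);
  return 0;
}
```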