Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vulkan] [refactor] [bug] Redesign gfx::OfflineCacheManager to unify compilation of kernels on vulkan  #5889

Merged
merged 6 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ bool KernelCodeGen::maybe_read_compilation_from_cache(
return false;
}
data.swap(cache_data.compiled_data_list);
kernel->set_from_offline_cache();
kernel->mark_as_from_cache();
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void Kernel::operator()(LaunchContextBuilder &ctx_builder) {
compile();
}

if (!this->from_offline_cache_) {
if (!from_cache_) {
for (auto &offloaded : ir->as<Block>()->statements) {
account_for_offloaded(offloaded->as<OffloadedStmt>());
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ class TI_DLL_EXPORT Kernel : public Callable {
return task_counter_++;
}

void set_from_offline_cache() {
this->from_offline_cache_ = true;
void mark_as_from_cache() {
from_cache_ = true;
}

[[nodiscard]] std::string get_name() const override;
Expand Down Expand Up @@ -165,7 +165,7 @@ class TI_DLL_EXPORT Kernel : public Callable {
bool lowered_{false};
std::atomic<uint64> task_counter_{0};
std::string kernel_key_;
bool from_offline_cache_{false};
bool from_cache_{false};
};

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion taichi/runtime/gfx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ target_sources(gfx_runtime
snode_tree_manager.cpp
aot_module_builder_impl.cpp
aot_module_loader_impl.cpp
offline_cache_manager.cpp
cache_manager.cpp
)
#TODO 4832, some dependencies here should not be required as they
# are build requirements of other targets.
Expand Down
153 changes: 153 additions & 0 deletions taichi/runtime/gfx/cache_manager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#include "taichi/runtime/gfx/cache_manager.h"
#include "taichi/analysis/offline_cache_util.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/common/cleanup.h"
#include "taichi/program/kernel.h"
#include "taichi/runtime/gfx/aot_module_loader_impl.h"
#include "taichi/runtime/gfx/snode_tree_manager.h"
#include "taichi/util/lock.h"

namespace taichi {
namespace lang {
namespace gfx {

namespace {

// Name of the lock file, joined onto the cache directory, that is passed to
// lock_with_file() / unlock_with_file() around every read and write of the
// on-disk cache performed in this file.
constexpr char kMetadataFileLockName[] = "metadata.lock";
// File-local shorthand for the per-kernel data returned by the load/compile
// functions below.
using CompiledKernelData = gfx::GfxRuntime::RegisterParams;

} // namespace

/**
 * Builds a cache manager from @p init_params.
 *
 * `init_params.compiled_structs` is dereferenced in the member-initializer
 * list, i.e. before the assertions in the body run, so it must be non-null.
 * `runtime` and `target_device` are asserted to be non-null.
 *
 * If both "metadata.tcb" and "graphs.tcb" exist in the arch-specific cache
 * directory and the metadata lock file can be acquired, the previously dumped
 * AOT module is loaded into `cached_module_`. When the lock cannot be
 * acquired, `cached_module_` is simply left empty. The in-memory module
 * builder is always created and takes ownership of `target_device`.
 */
CacheManager::CacheManager(Params &&init_params)
    : mode_(init_params.mode),
      runtime_(init_params.runtime),
      compiled_structs_(*init_params.compiled_structs) {
  TI_ASSERT(init_params.runtime);
  TI_ASSERT(init_params.target_device);

  path_ = offline_cache::get_cache_path_by_arch(init_params.cache_path,
                                                init_params.arch);

  // `&&` keeps the second existence check conditional on the first.
  const bool cache_files_present =
      taichi::path_exists(taichi::join_path(path_, "metadata.tcb")) &&
      taichi::path_exists(taichi::join_path(path_, "graphs.tcb"));

  if (cache_files_present) {
    auto metadata_lock = taichi::join_path(path_, kMetadataFileLockName);
    if (lock_with_file(metadata_lock)) {
      // Release the file lock when leaving this scope, however we leave it.
      auto unlock_guard = make_cleanup([&metadata_lock]() {
        if (!unlock_with_file(metadata_lock)) {
          TI_WARN("Unlock {} failed", metadata_lock);
        }
      });

      gfx::AotModuleParams module_params;
      module_params.module_path = path_;
      module_params.runtime = runtime_;
      cached_module_ = gfx::make_aot_module(module_params, init_params.arch);
    }
  }

  caching_module_builder_ = std::make_unique<gfx::AotModuleBuilderImpl>(
      compiled_structs_, init_params.arch,
      std::move(init_params.target_device));
}

/**
 * Returns the compiled data for @p kernel, reusing a cached entry if one is
 * available.
 *
 * Evaluator kernels are lowered and compiled directly without touching the
 * cache. For all other kernels a key is derived via make_kernel_key(); when
 * caching is enabled the caches are consulted first, and on a miss (or when
 * caching is disabled) the kernel is compiled through
 * compile_and_cache_kernel().
 */
CompiledKernelData CacheManager::load_or_compile(CompileConfig *config,
                                                 Kernel *kernel) {
  if (!kernel->is_evaluator) {
    const std::string cache_key = make_kernel_key(config, kernel);
    if (mode_ > NotCache) {
      auto cached = this->try_load_cached_kernel(kernel, cache_key);
      if (cached) {
        return *cached;
      }
    }
    return this->compile_and_cache_kernel(cache_key, kernel);
  }

  // Evaluator kernels: compile directly, no cache lookup and no recording.
  spirv::lower(kernel);
  return gfx::run_codegen(kernel, runtime_->get_ti_device(),
                          compiled_structs_);
}

void CacheManager::dump_with_merging() const {
if (mode_ == MemAndDiskCache) {
taichi::create_directories(path_);
auto *cache_builder =
static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
cache_builder->mangle_aot_data();

auto lock_path = taichi::join_path(path_, kMetadataFileLockName);
if (lock_with_file(lock_path)) {
auto _ = make_cleanup([&lock_path]() {
if (!unlock_with_file(lock_path)) {
TI_WARN("Unlock {} failed", lock_path);
}
});
cache_builder->merge_with_old_meta_data(path_);
cache_builder->dump(path_, "");
}
}
}

/**
 * Looks up @p key first in the in-memory module builder and then, in
 * MemAndDiskCache mode with a loaded module, in the module read from disk.
 *
 * On a hit the kernel is marked as coming from the cache and the returned
 * data has `num_snode_trees` overwritten with the current number of compiled
 * SNode structs.
 *
 * @return the cached data, or std::nullopt on a miss or when the mode is
 *         NotCache.
 */
std::optional<CompiledKernelData> CacheManager::try_load_cached_kernel(
    Kernel *kernel,
    const std::string &key) {
  if (mode_ == NotCache) {
    return std::nullopt;
  }

  // 1) In-memory cache.
  auto *builder =
      static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
  auto in_memory = builder->try_get_kernel_register_params(key);
  if (in_memory.has_value()) {
    TI_DEBUG("Create kernel '{}' from in-memory cache (key='{}')",
             kernel->get_name(), key);
    kernel->mark_as_from_cache();
    // TODO: Support multiple SNodeTrees in AOT.
    in_memory->num_snode_trees = compiled_structs_.size();
    return in_memory;
  }

  // 2) Disk cache: only in MemAndDiskCache mode and only if a module was
  //    loaded at construction time.
  if (mode_ != MemAndDiskCache || !cached_module_) {
    return std::nullopt;
  }
  auto *aot_kernel = cached_module_->get_kernel(key);
  if (!aot_kernel) {
    return std::nullopt;
  }

  TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
           key);
  kernel->mark_as_from_cache();
  auto from_disk = static_cast<gfx::KernelImpl *>(aot_kernel)->params();
  // TODO: Support multiple SNodeTrees in AOT.
  from_disk.num_snode_trees = compiled_structs_.size();
  return from_disk;
}

/**
 * Compiles @p kernel by adding it to the in-memory module builder under
 * @p key, then reads the resulting register params back from the builder.
 *
 * Asserts that the builder exists and that the entry is present after the
 * add. The returned data has `num_snode_trees` overwritten with the current
 * number of compiled SNode structs.
 */
CompiledKernelData CacheManager::compile_and_cache_kernel(
    const std::string &key,
    Kernel *kernel) {
  TI_DEBUG_IF(mode_ == MemAndDiskCache, "Cache kernel '{}' (key='{}')",
              kernel->get_name(), key);

  auto *builder =
      static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
  TI_ASSERT(builder != nullptr);

  builder->add(key, kernel);
  auto registered = builder->try_get_kernel_register_params(key);
  TI_ASSERT(registered.has_value());

  auto &compiled = *registered;
  // TODO: Support multiple SNodeTrees in AOT.
  compiled.num_snode_trees = compiled_structs_.size();
  return compiled;
}

/**
 * Produces the cache key for @p kernel.
 *
 * In modes below MemAndDiskCache the kernel's name is used as the key.
 * Otherwise the key already stored on the kernel is reused; if it is empty, a
 * hashed offline-cache key is computed from @p config and @p kernel and
 * stored back on the kernel.
 */
std::string CacheManager::make_kernel_key(CompileConfig *config,
                                          Kernel *kernel) const {
  if (mode_ >= MemAndDiskCache) {
    auto stored_key = kernel->get_cached_kernel_key();
    if (!stored_key.empty()) {
      return stored_key;
    }
    // First request for this kernel: compute the key and remember it.
    stored_key = get_hashed_offline_cache_key(config, kernel);
    kernel->set_kernel_key_for_cache(stored_key);
    return stored_key;
  }
  return kernel->get_name();
}

} // namespace gfx
} // namespace lang
} // namespace taichi
49 changes: 49 additions & 0 deletions taichi/runtime/gfx/cache_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/runtime/gfx/runtime.h"

namespace taichi {
namespace lang {
namespace gfx {

/**
 * Manages compiled-kernel caching for the gfx runtime: an in-memory cache
 * backed by an AOT module builder and, optionally, an on-disk cache loaded
 * from and dumped to an arch-specific directory.
 */
class CacheManager {
 public:
  /**
   * Caching level.
   *  - NotCache:        cache lookups are skipped.
   *  - MemCache:        only the in-memory cache is consulted.
   *  - MemAndDiskCache: the on-disk cache is consulted as well, and
   *                     dump_with_merging() writes the cache to disk.
   */
  enum Mode { NotCache, MemCache, MemAndDiskCache };

  /// Construction parameters. `runtime`, `target_device` and
  /// `compiled_structs` must all be non-null when passed to the constructor.
  struct Params {
    Arch arch;
    Mode mode{MemCache};
    std::string cache_path;
    GfxRuntime *runtime{nullptr};
    std::unique_ptr<aot::TargetDevice> target_device;
    // Explicitly null by default: the constructor dereferences this pointer
    // unconditionally, so a default-constructed Params must not carry an
    // indeterminate value here.
    const std::vector<spirv::CompiledSNodeStructs> *compiled_structs{nullptr};
  };

  using CompiledKernelData = gfx::GfxRuntime::RegisterParams;

  /// Consumes `init_params` (ownership of `target_device` is taken).
  CacheManager(Params &&init_params);

  /// Returns compiled data for `kernel`, from cache when possible, otherwise
  /// by compiling it.
  CompiledKernelData load_or_compile(CompileConfig *config, Kernel *kernel);
  /// In MemAndDiskCache mode, merges the in-memory cache with the metadata
  /// already on disk and dumps the result; otherwise does nothing.
  void dump_with_merging() const;

 private:
  /// Looks `key` up in the in-memory cache, then in the on-disk cache.
  std::optional<CompiledKernelData> try_load_cached_kernel(
      Kernel *kernel,
      const std::string &key);
  /// Compiles `kernel` through the module builder and records it under `key`.
  CompiledKernelData compile_and_cache_kernel(const std::string &key,
                                              Kernel *kernel);
  /// Kernel name below MemAndDiskCache, otherwise a hashed offline-cache key.
  std::string make_kernel_key(CompileConfig *config, Kernel *kernel) const;

  Mode mode_{MemCache};
  std::string path_;  // Arch-specific cache directory.
  GfxRuntime *runtime_{nullptr};
  // Non-owning reference; the referenced vector must outlive this object.
  const std::vector<spirv::CompiledSNodeStructs> &compiled_structs_;
  std::unique_ptr<AotModuleBuilder> caching_module_builder_{nullptr};
  // Module loaded from disk at construction; stays null if nothing was loaded.
  std::unique_ptr<aot::Module> cached_module_{nullptr};
};

} // namespace gfx
} // namespace lang
} // namespace taichi
91 changes: 0 additions & 91 deletions taichi/runtime/gfx/offline_cache_manager.cpp

This file was deleted.

34 changes: 0 additions & 34 deletions taichi/runtime/gfx/offline_cache_manager.h

This file was deleted.

Loading