Skip to content

Commit fae94a2

Browse files
authored
[aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTModuleBuilder to support Fields (#5120)
* [aot] [llvm] Implemented FieldCacheData and refactored initialize_llvm_runtime_snodes() * Addressed compilation errors * [aot] [llvm] LLVM AOT Field #1: Adjust serialization/deserialization logic for FieldCacheData * [llvm] [aot] Added Field support for LLVM AOT * [aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTModuleBuilder to support Fields * Fixed merge issues * Stopped abusing Program*
1 parent 6349d60 commit fae94a2

13 files changed

+190
-46
lines changed

taichi/backends/cpu/aot_module_builder_impl.h

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ namespace lang {
99
namespace cpu {
1010

1111
class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
12+
public:
13+
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
14+
: LlvmAotModuleBuilder(prog) {
15+
}
16+
1217
private:
1318
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
1419
};

taichi/backends/cpu/aot_module_loader_impl.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
4444
TI_NOT_IMPLEMENTED;
4545
return nullptr;
4646
}
47-
48-
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
49-
TI_NOT_IMPLEMENTED;
50-
return nullptr;
51-
}
5247
};
5348

5449
} // namespace

taichi/backends/cuda/aot_module_builder_impl.h

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ namespace lang {
99
namespace cuda {
1010

1111
class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
12+
public:
13+
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
14+
: LlvmAotModuleBuilder(prog) {
15+
}
16+
1217
private:
1318
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
1419
};

taichi/backends/cuda/aot_module_loader_impl.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
4444
TI_NOT_IMPLEMENTED;
4545
return nullptr;
4646
}
47-
48-
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
49-
TI_NOT_IMPLEMENTED;
50-
return nullptr;
51-
}
5247
};
5348

5449
} // namespace

taichi/ir/snode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ void SNode::set_snode_tree_id(int id) {
326326
snode_tree_id_ = id;
327327
}
328328

329-
int SNode::get_snode_tree_id() {
329+
int SNode::get_snode_tree_id() const {
330330
return snode_tree_id_;
331331
}
332332

taichi/ir/snode.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class SNode {
354354

355355
void set_snode_tree_id(int id);
356356

357-
int get_snode_tree_id();
357+
int get_snode_tree_id() const;
358358

359359
static void reset_counter() {
360360
counter = 0;

taichi/llvm/llvm_aot_module_builder.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <algorithm>
44
#include "taichi/llvm/launch_arg_info.h"
5+
#include "taichi/llvm/llvm_program.h"
56

67
namespace taichi {
78
namespace lang {
@@ -34,5 +35,37 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
3435
cache_.kernels[identifier] = std::move(kcache);
3536
}
3637

38+
void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier,
39+
const SNode *rep_snode,
40+
bool is_scalar,
41+
DataType dt,
42+
std::vector<int> shape,
43+
int row_num,
44+
int column_num) {
45+
// Field refers to a leaf node(Place SNode) in a SNodeTree.
46+
// It makes no sense to just serialize the leaf node or its corresponding
47+
// branch. Instead, the minimal unit we have to serialize is the entire
48+
// SNodeTree. Note that SNodeTree's uses snode_tree_id as its identifier,
49+
// rather than the field's name. (multiple fields may end up referring to the
50+
// same SNodeTree)
51+
52+
// 1. Find snode_tree_id
53+
int snode_tree_id = rep_snode->get_snode_tree_id();
54+
55+
// 2. Fetch Cache from the Program
56+
// Kernel compilation is not allowed until all the Fields are finalized,
57+
// so we finished SNodeTree compilation during AOTModuleBuilder construction.
58+
//
59+
// By the time "add_field_per_backend()" is called,
60+
// SNodeTrees should have already been finalized,
61+
// with compiled info stored in LlvmProgramImpl::cache_data_.
62+
TI_ASSERT(prog_ != nullptr);
63+
LlvmOfflineCache::FieldCacheData field_cache =
64+
prog_->get_cached_field(snode_tree_id);
65+
66+
// 3. Update AOT Cache
67+
cache_.fields[snode_tree_id] = std::move(field_cache);
68+
}
69+
3770
} // namespace lang
3871
} // namespace taichi

taichi/llvm/llvm_aot_module_builder.h

+12
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,27 @@ namespace lang {
99

1010
class LlvmAotModuleBuilder : public AotModuleBuilder {
1111
public:
12+
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
13+
}
14+
1215
void dump(const std::string &output_dir,
1316
const std::string &filename) const override;
1417

1518
protected:
1619
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
1720
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0;
1821

22+
void add_field_per_backend(const std::string &identifier,
23+
const SNode *rep_snode,
24+
bool is_scalar,
25+
DataType dt,
26+
std::vector<int> shape,
27+
int row_num,
28+
int column_num) override;
29+
1930
private:
2031
mutable LlvmOfflineCache cache_;
32+
LlvmProgramImpl *prog_ = nullptr;
2133
};
2234

2335
} // namespace lang

taichi/llvm/llvm_aot_module_loader.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ class KernelImpl : public aot::Kernel {
1717
FunctionType fn_;
1818
};
1919

20+
class FieldImpl : public aot::Field {
21+
public:
22+
explicit FieldImpl(const LlvmOfflineCache::FieldCacheData &field)
23+
: field_(field) {
24+
}
25+
26+
explicit FieldImpl(LlvmOfflineCache::FieldCacheData &&field)
27+
: field_(std::move(field)) {
28+
}
29+
30+
LlvmOfflineCache::FieldCacheData get_field() const {
31+
return field_;
32+
}
33+
34+
private:
35+
LlvmOfflineCache::FieldCacheData field_;
36+
};
37+
2038
} // namespace
2139

2240
LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache(
@@ -37,5 +55,42 @@ std::unique_ptr<aot::Kernel> LlvmAotModule::make_new_kernel(
3755
return std::make_unique<KernelImpl>(fn);
3856
}
3957

58+
std::unique_ptr<aot::Field> LlvmAotModule::make_new_field(
59+
const std::string &name) {
60+
// Check if "name" represents snode_tree_id.
61+
// Avoid using std::atoi due to its poor error handling.
62+
char *end;
63+
int snode_tree_id = static_cast<int>(strtol(name.c_str(), &end, 10 /*base*/));
64+
65+
TI_ASSERT(end != name.c_str());
66+
TI_ASSERT(*end == '\0');
67+
68+
// Load FieldCache
69+
LlvmOfflineCache::FieldCacheData loaded;
70+
auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id);
71+
TI_ERROR_IF(!ok, "Failed to load field with id={}", snode_tree_id);
72+
73+
return std::make_unique<FieldImpl>(std::move(loaded));
74+
}
75+
76+
// Initializes the runtime SNodes behind `aot_field`, at most once per
// SNodeTree of `aot_module`.
//
// `aot_module` must be an LlvmAotModule and `aot_field` a FieldImpl; both
// casts are asserted. `result_buffer` is passed through unchanged to
// LlvmProgramImpl::initialize_llvm_runtime_snodes().
void finalize_aot_field(aot::Module *aot_module,
                        aot::Field *aot_field,
                        uint64 *result_buffer) {
  auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module);
  auto *aot_field_impl = dynamic_cast<FieldImpl *>(aot_field);

  TI_ASSERT(llvm_aot_module != nullptr);
  TI_ASSERT(aot_field_impl != nullptr);

  const auto &cached_tree = aot_field_impl->get_field();
  const int tree_id = cached_tree.tree_id;

  // Several fields can share one SNodeTree; trees already set up are skipped.
  if (llvm_aot_module->is_snode_tree_initialized(tree_id)) {
    return;
  }

  llvm_aot_module->get_program()->initialize_llvm_runtime_snodes(
      cached_tree, result_buffer);
  llvm_aot_module->set_initialized_snode_tree(tree_id);
}
94+
4095
} // namespace lang
4196
} // namespace taichi

taichi/llvm/llvm_aot_module_loader.h

+21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
namespace taichi {
77
namespace lang {
88

9+
TI_DLL_EXPORT void finalize_aot_field(aot::Module *aot_module,
10+
aot::Field *aot_field,
11+
uint64 *result_buffer);
12+
913
class LlvmAotModule : public aot::Module {
1014
public:
1115
explicit LlvmAotModule(const std::string &module_path,
@@ -27,6 +31,18 @@ class LlvmAotModule : public aot::Module {
2731
return 0;
2832
}
2933

34+
// Returns the LlvmProgramImpl pointer held by this module.
// The return type is a plain pointer: a top-level `const` on a by-value
// return has no effect for callers and only triggers compiler warnings.
LlvmProgramImpl *get_program() {
  return program_;
}
37+
38+
// Marks the SNodeTree with the given id as initialized. Marking an id that is
// already recorded leaves the set unchanged.
void set_initialized_snode_tree(int snode_tree_id) {
  initialized_snode_tree_ids.emplace(snode_tree_id);
}
41+
42+
bool is_snode_tree_initialized(int snode_tree_id) {
43+
return initialized_snode_tree_ids.count(snode_tree_id);
44+
}
45+
3046
protected:
3147
virtual FunctionType convert_module_to_function(
3248
const std::string &name,
@@ -38,8 +54,13 @@ class LlvmAotModule : public aot::Module {
3854
std::unique_ptr<aot::Kernel> make_new_kernel(
3955
const std::string &name) override;
4056

57+
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override;
58+
4159
LlvmProgramImpl *const program_{nullptr};
4260
std::unique_ptr<LlvmOfflineCacheFileReader> cache_reader_{nullptr};
61+
62+
// To prevent repeated SNodeTree initialization
63+
std::unordered_set<int> initialized_snode_tree_ids;
4364
};
4465

4566
} // namespace lang

taichi/llvm/llvm_offline_cache.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ struct LlvmOfflineCache {
9595
std::unordered_map<std::string, KernelCacheData>
9696
kernels; // key = kernel_name
9797

98-
TI_IO_DEF(kernels);
98+
TI_IO_DEF(fields, kernels);
9999
};
100100

101101
class LlvmOfflineCacheFileReader {

taichi/llvm/llvm_program.cpp

+39-27
Original file line numberDiff line numberDiff line change
@@ -273,37 +273,22 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
273273
}
274274

275275
void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
276-
compile_snode_tree_types_impl(tree);
277-
}
278-
279-
static LlvmOfflineCache::FieldCacheData construct_filed_cache_data(
280-
const SNodeTree &tree,
281-
const StructCompiler &struct_compiler) {
282-
LlvmOfflineCache::FieldCacheData ret;
283-
ret.tree_id = tree.id();
284-
ret.root_id = tree.root()->id;
285-
ret.root_size = struct_compiler.root_size;
286-
287-
const auto &snodes = struct_compiler.snodes;
288-
for (size_t i = 0; i < snodes.size(); i++) {
289-
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
290-
snode_cache_data.id = snodes[i]->id;
291-
snode_cache_data.type = snodes[i]->type;
292-
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
293-
snode_cache_data.chunk_size = snodes[i]->chunk_size;
294-
295-
ret.snode_metas.emplace_back(std::move(snode_cache_data));
296-
}
276+
auto struct_compiler = compile_snode_tree_types_impl(tree);
277+
int snode_tree_id = tree->id();
278+
int root_id = tree->root()->id;
297279

298-
return ret;
280+
// Add compiled result to Cache
281+
cache_field(snode_tree_id, root_id, *struct_compiler);
299282
}
300283

301284
void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
302285
uint64 *result_buffer) {
303-
auto struct_compiler = compile_snode_tree_types_impl(tree);
286+
compile_snode_tree_types(tree);
287+
int snode_tree_id = tree->id();
304288

305-
auto field_cache_data = construct_filed_cache_data(*tree, *struct_compiler);
306-
initialize_llvm_runtime_snodes(field_cache_data, result_buffer);
289+
TI_ASSERT(cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end());
290+
initialize_llvm_runtime_snodes(cache_data_.fields.at(snode_tree_id),
291+
result_buffer);
307292
}
308293

309294
uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) {
@@ -365,12 +350,12 @@ void LlvmProgramImpl::print_list_manager_info(void *list_manager,
365350

366351
std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder() {
367352
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
368-
return std::make_unique<cpu::AotModuleBuilderImpl>();
353+
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
369354
}
370355

371356
#if defined(TI_WITH_CUDA)
372357
if (config->arch == Arch::cuda) {
373-
return std::make_unique<cuda::AotModuleBuilderImpl>();
358+
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
374359
}
375360
#endif
376361

@@ -701,6 +686,33 @@ void LlvmProgramImpl::cache_kernel(
701686
kernel_cache.offloaded_task_list = std::move(offloaded_task_list);
702687
}
703688

689+
void LlvmProgramImpl::cache_field(int snode_tree_id,
690+
int root_id,
691+
const StructCompiler &struct_compiler) {
692+
if (cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end()) {
693+
// [TODO] check and update the Cache, instead of simply return.
694+
return;
695+
}
696+
697+
LlvmOfflineCache::FieldCacheData ret;
698+
ret.tree_id = snode_tree_id;
699+
ret.root_id = root_id;
700+
ret.root_size = struct_compiler.root_size;
701+
702+
const auto &snodes = struct_compiler.snodes;
703+
for (size_t i = 0; i < snodes.size(); i++) {
704+
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
705+
snode_cache_data.id = snodes[i]->id;
706+
snode_cache_data.type = snodes[i]->type;
707+
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
708+
snode_cache_data.chunk_size = snodes[i]->chunk_size;
709+
710+
ret.snode_metas.emplace_back(std::move(snode_cache_data));
711+
}
712+
713+
cache_data_.fields[snode_tree_id] = std::move(ret);
714+
}
715+
704716
void LlvmProgramImpl::dump_cache_data_to_disk() {
705717
if (config->offline_cache && !cache_data_.kernels.empty()) {
706718
LlvmOfflineCacheFileWriter writer{};

taichi/llvm/llvm_program.h

+17-6
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,34 @@ class LlvmProgramImpl : public ProgramImpl {
118118
std::vector<LlvmOfflineCache::OffloadedTaskCacheData>
119119
&&offloaded_task_list);
120120

121+
void cache_field(int snode_tree_id,
122+
int root_id,
123+
const StructCompiler &struct_compiler);
124+
125+
// Returns a copy of the cached FieldCacheData stored under `snode_tree_id`.
// Asserts that such an entry exists.
LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const {
  TI_ASSERT(cache_data_.fields.find(snode_tree_id) !=
            cache_data_.fields.end());

  const auto &cached_entry = cache_data_.fields.at(snode_tree_id);
  return cached_entry;
}
130+
121131
Device *get_compute_device() override {
122132
return device_.get();
123133
}
124134

135+
/**
136+
* Initializes the SNodes for LLVM based backends.
137+
*/
138+
void initialize_llvm_runtime_snodes(
139+
const LlvmOfflineCache::FieldCacheData &field_cache_data,
140+
uint64 *result_buffer);
141+
125142
private:
126143
std::unique_ptr<llvm::Module> clone_struct_compiler_initial_context(
127144
bool has_multiple_snode_trees,
128145
TaichiLLVMContext *tlctx);
129146

130147
std::unique_ptr<StructCompiler> compile_snode_tree_types_impl(
131148
SNodeTree *tree);
132-
/**
133-
* Initializes the SNodes for LLVM based backends.
134-
*/
135-
void initialize_llvm_runtime_snodes(
136-
const LlvmOfflineCache::FieldCacheData &field_cache_data,
137-
uint64 *result_buffer);
138149

139150
uint64 fetch_result_uint64(int i, uint64 *result_buffer);
140151

0 commit comments

Comments
 (0)