@@ -273,37 +273,22 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
273
273
}
274
274
275
275
void LlvmProgramImpl::compile_snode_tree_types (SNodeTree *tree) {
276
- compile_snode_tree_types_impl (tree);
277
- }
278
-
279
- static LlvmOfflineCache::FieldCacheData construct_filed_cache_data (
280
- const SNodeTree &tree,
281
- const StructCompiler &struct_compiler) {
282
- LlvmOfflineCache::FieldCacheData ret;
283
- ret.tree_id = tree.id ();
284
- ret.root_id = tree.root ()->id ;
285
- ret.root_size = struct_compiler.root_size ;
286
-
287
- const auto &snodes = struct_compiler.snodes ;
288
- for (size_t i = 0 ; i < snodes.size (); i++) {
289
- LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
290
- snode_cache_data.id = snodes[i]->id ;
291
- snode_cache_data.type = snodes[i]->type ;
292
- snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes ;
293
- snode_cache_data.chunk_size = snodes[i]->chunk_size ;
294
-
295
- ret.snode_metas .emplace_back (std::move (snode_cache_data));
296
- }
276
+ auto struct_compiler = compile_snode_tree_types_impl (tree);
277
+ int snode_tree_id = tree->id ();
278
+ int root_id = tree->root ()->id ;
297
279
298
- return ret;
280
+ // Add compiled result to Cache
281
+ cache_field (snode_tree_id, root_id, *struct_compiler);
299
282
}
300
283
301
284
void LlvmProgramImpl::materialize_snode_tree (SNodeTree *tree,
302
285
uint64 *result_buffer) {
303
- auto struct_compiler = compile_snode_tree_types_impl (tree);
286
+ compile_snode_tree_types (tree);
287
+ int snode_tree_id = tree->id ();
304
288
305
- auto field_cache_data = construct_filed_cache_data (*tree, *struct_compiler);
306
- initialize_llvm_runtime_snodes (field_cache_data, result_buffer);
289
+ TI_ASSERT (cache_data_.fields .find (snode_tree_id) != cache_data_.fields .end ());
290
+ initialize_llvm_runtime_snodes (cache_data_.fields .at (snode_tree_id),
291
+ result_buffer);
307
292
}
308
293
309
294
uint64 LlvmProgramImpl::fetch_result_uint64 (int i, uint64 *result_buffer) {
@@ -365,12 +350,12 @@ void LlvmProgramImpl::print_list_manager_info(void *list_manager,
365
350
366
351
std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder () {
367
352
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
368
- return std::make_unique<cpu::AotModuleBuilderImpl>();
353
+ return std::make_unique<cpu::AotModuleBuilderImpl>(this );
369
354
}
370
355
371
356
#if defined(TI_WITH_CUDA)
372
357
if (config->arch == Arch::cuda) {
373
- return std::make_unique<cuda::AotModuleBuilderImpl>();
358
+ return std::make_unique<cuda::AotModuleBuilderImpl>(this );
374
359
}
375
360
#endif
376
361
@@ -701,6 +686,33 @@ void LlvmProgramImpl::cache_kernel(
701
686
kernel_cache.offloaded_task_list = std::move (offloaded_task_list);
702
687
}
703
688
689
+ void LlvmProgramImpl::cache_field (int snode_tree_id,
690
+ int root_id,
691
+ const StructCompiler &struct_compiler) {
692
+ if (cache_data_.fields .find (snode_tree_id) != cache_data_.fields .end ()) {
693
+ // [TODO] check and update the Cache, instead of simply return.
694
+ return ;
695
+ }
696
+
697
+ LlvmOfflineCache::FieldCacheData ret;
698
+ ret.tree_id = snode_tree_id;
699
+ ret.root_id = root_id;
700
+ ret.root_size = struct_compiler.root_size ;
701
+
702
+ const auto &snodes = struct_compiler.snodes ;
703
+ for (size_t i = 0 ; i < snodes.size (); i++) {
704
+ LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
705
+ snode_cache_data.id = snodes[i]->id ;
706
+ snode_cache_data.type = snodes[i]->type ;
707
+ snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes ;
708
+ snode_cache_data.chunk_size = snodes[i]->chunk_size ;
709
+
710
+ ret.snode_metas .emplace_back (std::move (snode_cache_data));
711
+ }
712
+
713
+ cache_data_.fields [snode_tree_id] = std::move (ret);
714
+ }
715
+
704
716
void LlvmProgramImpl::dump_cache_data_to_disk () {
705
717
if (config->offline_cache && !cache_data_.kernels .empty ()) {
706
718
LlvmOfflineCacheFileWriter writer{};
0 commit comments