Skip to content

Commit 81995fe

Browse files
[bug] Fix crashing when loading old offline cache files (#6089)
Issue: #4401, fixes #6081 In future, if necessary, we should maintain a version number for offline cache instead of using the Taichi version directly. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5e5321e commit 81995fe

File tree

8 files changed

+265
-79
lines changed

8 files changed

+265
-79
lines changed

cmake/TaichiTests.cmake

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ file(GLOB_RECURSE TAICHI_TESTS_SOURCE
2020
"tests/cpp/llvm/*.cpp"
2121
"tests/cpp/program/*.cpp"
2222
"tests/cpp/struct/*.cpp"
23-
"tests/cpp/transforms/*.cpp")
23+
"tests/cpp/transforms/*.cpp"
24+
"tests/cpp/offline_cache/*.cpp")
2425

2526
if (TI_WITH_OPENGL OR TI_WITH_VULKAN)
2627
file(GLOB TAICHI_TESTS_GFX_UTILS_SOURCE

taichi/cache/gfx/cache_manager.cpp

+9-15
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,6 @@ struct CacheCleanerUtils<gfx::CacheManager::Metadata> {
4747
using MetadataType = gfx::CacheManager::Metadata;
4848
using KernelMetaData = MetadataType::KernelMetadata;
4949

50-
// To load metadata from file
51-
static bool load_metadata(const CacheCleanerConfig &config,
52-
MetadataType &result) {
53-
return read_from_binary_file(
54-
result, taichi::join_path(config.path, config.metadata_filename));
55-
}
56-
5750
// To save metadata as file
5851
static bool save_metadata(const CacheCleanerConfig &config,
5952
const MetadataType &data) {
@@ -81,13 +74,6 @@ struct CacheCleanerUtils<gfx::CacheManager::Metadata> {
8174
return true;
8275
}
8376

84-
// To check version
85-
static bool check_version(const CacheCleanerConfig &config,
86-
const Version &version) {
87-
return version[0] == TI_VERSION_MAJOR && version[1] == TI_VERSION_MINOR &&
88-
version[2] == TI_VERSION_PATCH;
89-
}
90-
9177
// To get cache files name
9278
static std::vector<std::string> get_cache_files(
9379
const CacheCleanerConfig &config,
@@ -106,6 +92,12 @@ struct CacheCleanerUtils<gfx::CacheManager::Metadata> {
10692
taichi::join_path(config.path, kDebuggingAotMetadataFilename));
10793
taichi::remove(taichi::join_path(config.path, kGraphMetadataFilename));
10894
}
95+
96+
// To check if a file is cache file
97+
static bool is_valid_cache_file(const CacheCleanerConfig &config,
98+
const std::string &name) {
99+
return filename_extension(name) == "spv";
100+
}
109101
};
110102

111103
} // namespace offline_cache
@@ -184,10 +176,12 @@ void CacheManager::dump_with_merging() const {
184176
cache_builder->dump(path_, "");
185177

186178
// Update offline_cache_metadata.tcb
179+
using offline_cache::load_metadata_with_checking;
180+
using Error = offline_cache::LoadMetadataError;
187181
Metadata old_data;
188182
const auto filename =
189183
taichi::join_path(path_, kOfflineCacheMetadataFilename);
190-
if (read_from_binary_file(old_data, filename)) {
184+
if (load_metadata_with_checking(old_data, filename) == Error::kNoError) {
191185
for (auto &[k, v] : offline_cache_metadata_.kernels) {
192186
auto iter = old_data.kernels.find(k);
193187
if (iter != old_data.kernels.end()) { // Update

taichi/common/serialization.h

+22-2
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,20 @@ class BinarySerializer : public Serializer {
345345
preserved = 0;
346346
}
347347

348+
template <bool writing_ = writing>
349+
typename std::enable_if<!writing_, std::size_t>::type retrieve_length() {
350+
return *reinterpret_cast<std::size_t *>(c_data);
351+
}
352+
348353
void finalize() {
349-
if (writing) {
354+
if constexpr (writing) {
350355
if (c_data) {
351356
*reinterpret_cast<std::size_t *>(&c_data[0]) = head;
352357
} else {
353358
*reinterpret_cast<std::size_t *>(&data[0]) = head;
354359
}
355360
} else {
356-
assert(head == *reinterpret_cast<std::size_t *>(c_data));
361+
assert(head == retrieve_length());
357362
}
358363
}
359364

@@ -880,6 +885,21 @@ operator<<(std::ostream &os, const T &t) {
880885
}
881886

882887
// Returns true if deserialization succeeded.
888+
template <typename T>
889+
bool read_from_binary(T &t,
890+
const void *bin,
891+
std::size_t len,
892+
bool match_all = true) {
893+
BinaryInputSerializer reader;
894+
reader.initialize(const_cast<void *>(bin));
895+
if (len != reader.retrieve_length()) {
896+
return false;
897+
}
898+
reader(t);
899+
auto head = reader.head;
900+
return match_all ? head == len : head <= len;
901+
}
902+
883903
template <typename T>
884904
bool read_from_binary_file(T &t, const std::string &file_name) {
885905
BinaryInputSerializer reader;

taichi/runtime/llvm/llvm_offline_cache.cpp

+11-37
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ using Format = LlvmOfflineCache::Format;
2727
constexpr char kMetadataFilename[] = "metadata";
2828
constexpr char kMetadataFileLockName[] = "metadata.lock";
2929

30-
static bool is_current_llvm_cache_version(
31-
const LlvmOfflineCache::Version &ver) {
32-
// TODO(PGZXB): Do more detailed checking
33-
return ver[0] == TI_VERSION_MAJOR && ver[1] == TI_VERSION_MINOR &&
34-
ver[2] == TI_VERSION_PATCH;
35-
}
36-
3730
static std::string get_llvm_cache_metadata_file_path(const std::string &dir) {
3831
return taichi::join_path(dir, std::string(kMetadataFilename) + ".tcb");
3932
}
@@ -60,13 +53,6 @@ struct CacheCleanerUtils<LlvmOfflineCache> {
6053
using MetadataType = LlvmOfflineCache;
6154
using KernelMetaData = typename MetadataType::KernelMetadata;
6255

63-
// To load metadata from file
64-
static bool load_metadata(const CacheCleanerConfig &config,
65-
MetadataType &result) {
66-
return read_from_binary_file(
67-
result, taichi::join_path(config.path, config.metadata_filename));
68-
}
69-
7056
// To save metadata as file
7157
static bool save_metadata(const CacheCleanerConfig &config,
7258
const MetadataType &data) {
@@ -84,12 +70,6 @@ struct CacheCleanerUtils<LlvmOfflineCache> {
8470
return true;
8571
}
8672

87-
// To check version
88-
static bool check_version(const CacheCleanerConfig &config,
89-
const Version &version) {
90-
return is_current_llvm_cache_version(version);
91-
}
92-
9373
// To get cache files name
9474
static std::vector<std::string> get_cache_files(
9575
const CacheCleanerConfig &config,
@@ -106,6 +86,13 @@ struct CacheCleanerUtils<LlvmOfflineCache> {
10686
static void remove_other_files(const CacheCleanerConfig &config) {
10787
// Do nothing
10888
}
89+
90+
// To check if a file is cache file
91+
static bool is_valid_cache_file(const CacheCleanerConfig &config,
92+
const std::string &name) {
93+
std::string ext = filename_extension(name);
94+
return ext == "ll" || ext == "bc";
95+
}
10996
};
11097

11198
} // namespace offline_cache
@@ -126,20 +113,12 @@ bool LlvmOfflineCacheFileReader::load_meta_data(
126113
LlvmOfflineCache &data,
127114
const std::string &cache_file_path,
128115
bool with_lock) {
116+
using offline_cache::load_metadata_with_checking;
117+
using Error = offline_cache::LoadMetadataError;
129118
const auto tcb_path = get_llvm_cache_metadata_file_path(cache_file_path);
130-
{
131-
// No the best way to check for filepath existence, but whatever... See
132-
// https://stackoverflow.com/questions/12774207/fastest-way-to-check-if-a-file-exists-using-standard-c-c11-14-17-c
133-
std::ifstream fs(tcb_path, std::ios::in | std::ios::binary);
134-
if (!fs.good()) {
135-
TI_DEBUG("LLVM cache {} does not exist", cache_file_path);
136-
return false;
137-
}
138-
}
139119

140120
if (!with_lock) {
141-
read_from_binary_file(data, tcb_path);
142-
return true;
121+
return Error::kNoError == load_metadata_with_checking(data, tcb_path);
143122
}
144123

145124
std::string lock_path =
@@ -150,8 +129,7 @@ bool LlvmOfflineCacheFileReader::load_meta_data(
150129
TI_WARN("Unlock {} failed", lock_path);
151130
}
152131
});
153-
read_from_binary_file(data, tcb_path);
154-
return true;
132+
return Error::kNoError == load_metadata_with_checking(data, tcb_path);
155133
}
156134
TI_WARN("Lock {} failed", lock_path);
157135
return false;
@@ -389,10 +367,6 @@ void LlvmOfflineCacheFileWriter::mangle_offloaded_task_name(
389367
for (auto &offload : compiled_data.tasks) {
390368
std::string mangled_name =
391369
offline_cache::mangle_name(offload.name, kernel_key);
392-
TI_DEBUG(
393-
"Mangle offloaded-task from internal name '{}' to offline cache "
394-
"key '{}'",
395-
offload.name, mangled_name);
396370
auto func = compiled_data.module->getFunction(offload.name);
397371
TI_ASSERT(func != nullptr);
398372
func->setName(mangled_name);

taichi/runtime/llvm/llvm_offline_cache.h

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ struct LlvmOfflineCache {
100100
std::unordered_map<std::string, KernelCacheData>
101101
kernels; // key = kernel_name
102102

103+
// NOTE: The "version" must be the first field to be serialized
103104
TI_IO_DEF(version, size, fields, kernels);
104105
};
105106

taichi/util/io.h

+46
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
#if defined(TI_PLATFORM_WINDOWS)
1616
#include <filesystem>
17+
#else // POSIX
18+
#include <unistd.h>
19+
#include <dirent.h>
20+
#include <sys/types.h>
21+
#include <sys/stat.h>
1722
#endif
1823

1924
TI_NAMESPACE_BEGIN
@@ -51,6 +56,47 @@ inline bool remove(const std::string &path) {
5156
return std::remove(path.c_str()) == 0;
5257
}
5358

59+
template <typename Visitor> // void(const std::string &name, bool is_dir)
60+
inline bool traverse_directory(const std::string &dir, Visitor v) {
61+
#if defined(TI_PLATFORM_WINDOWS)
62+
namespace fs = std::filesystem;
63+
std::error_code ec{};
64+
auto iter = fs::directory_iterator(dir, ec);
65+
if (ec) {
66+
return false;
67+
}
68+
for (auto &f : iter) {
69+
v(f.path().filename().string(), f.is_directory());
70+
}
71+
return true;
72+
#else // POSIX
73+
struct dirent *f = nullptr;
74+
DIR *directory = ::opendir(dir.c_str());
75+
if (!directory) {
76+
return false;
77+
}
78+
while ((f = ::readdir(directory))) {
79+
struct stat *stat_buf = nullptr;
80+
auto fullpath = join_path(dir, f->d_name);
81+
auto ret = ::stat(fullpath.c_str(), stat_buf);
82+
TI_ASSERT(ret == 0);
83+
v(f->d_name, S_ISDIR(stat_buf->st_mode));
84+
}
85+
auto ret = ::closedir(directory);
86+
TI_ASSERT(ret == 0);
87+
return true;
88+
#endif
89+
}
90+
91+
inline std::string filename_extension(const std::string &filename) {
92+
std::string postfix;
93+
auto pos = filename.find_last_of('.');
94+
if (pos != std::string::npos) {
95+
postfix = filename.substr(pos + 1);
96+
}
97+
return postfix;
98+
}
99+
54100
template <typename T>
55101
void write_to_disk(const T &dat, std::string fn) {
56102
FILE *f = fopen(fn.c_str(), "wb");

0 commit comments

Comments
 (0)