|
1 | 1 | from contextlib import contextmanager
|
| 2 | +from glob import glob |
2 | 3 | from pathlib import Path, PurePosixPath
|
| 4 | +from shutil import rmtree |
| 5 | +from tempfile import mkdtemp |
| 6 | +from zipfile import ZipFile |
3 | 7 |
|
4 | 8 | from taichi.aot.utils import (produce_injected_args,
|
5 | 9 | produce_injected_args_from_template)
|
|
8 | 12 | from taichi.lang.matrix import MatrixField
|
9 | 13 | from taichi.types.annotations import template
|
10 | 14 |
|
| 15 | +import taichi |
| 16 | + |
11 | 17 |
|
12 | 18 | class KernelTemplate:
|
13 | 19 | def __init__(self, kernel_fn, aot_module):
|
@@ -88,6 +94,7 @@ def __init__(self, arch):
|
88 | 94 | rtm = impl.get_runtime()
|
89 | 95 | rtm._finalize_root_fb_for_aot()
|
90 | 96 | self._aot_builder = rtm.prog.make_aot_module_builder(arch)
|
| 97 | + self._content = [] |
91 | 98 |
|
92 | 99 | def add_field(self, name, field):
|
93 | 100 | """Add a taichi field to the AOT module.
|
@@ -147,8 +154,11 @@ def add_kernel(self, kernel_fn, template_args=None, name=None):
|
147 | 154 | # kernel AOT
|
148 | 155 | self._kernels.append(kernel)
|
149 | 156 |
|
| 157 | + self._content += ["kernel:" + kernel_name] |
| 158 | + |
150 | 159 | def add_graph(self, name, graph):
|
151 | 160 | self._aot_builder.add_graph(name, graph._compiled_graph)
|
| 161 | + self._content += ["cgraph:" + name] |
152 | 162 |
|
153 | 163 | @contextmanager
|
154 | 164 | def add_kernel_template(self, kernel_fn):
|
@@ -193,3 +203,29 @@ def save(self, filepath, filename):
|
193 | 203 | """
|
194 | 204 | filepath = str(PurePosixPath(Path(filepath)))
|
195 | 205 | self._aot_builder.dump(filepath, filename)
|
| 206 | + |
| 207 | + def archive(self, filepath: str): |
| 208 | + """ |
| 209 | + Args: |
| 210 | + filepath (str): path to the stored archive of aot artifacts, MUST |
| 211 | + end with `.tcm`. |
| 212 | + """ |
| 213 | + assert filepath.endswith(".tcm"), \ |
| 214 | + "AOT module artifact archive must ends with .tcm" |
| 215 | + tcm_path = Path(filepath).absolute() |
| 216 | + assert tcm_path.parent.exists(), "Output directory doesn't exist" |
| 217 | + |
| 218 | + temp_dir = mkdtemp(prefix="tcm_") |
| 219 | + # Save first as usual. |
| 220 | + self.save(temp_dir, "") |
| 221 | + |
| 222 | + # Package all artifacts into a zip archive and attach contend data. |
| 223 | + with ZipFile(tcm_path, "w") as z: |
| 224 | + z.writestr("__content__", '\n'.join(self._content)) |
| 225 | + z.writestr("__version__", |
| 226 | + '.'.join(str(x) for x in taichi.__version__)) |
| 227 | + for path in glob(f"{temp_dir}/*", recursive=True): |
| 228 | + z.write(path, Path.relative_to(Path(path), temp_dir)) |
| 229 | + |
| 230 | + # Remove cached files |
| 231 | + rmtree(temp_dir) |
0 commit comments