Skip to content

Commit 752a843

Browse files
[AOT] Save AOT module artifacts as zip archive (#5316)
* Save AOT module artifacts as zip archive * Add version mark * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linting * Use tempfile for temp dir management * Added test for AOT module archiving * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 728d001 commit 752a843

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

python/taichi/aot/module.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from contextlib import contextmanager
2+
from glob import glob
23
from pathlib import Path, PurePosixPath
4+
from shutil import rmtree
5+
from tempfile import mkdtemp
6+
from zipfile import ZipFile
37

48
from taichi.aot.utils import (produce_injected_args,
59
produce_injected_args_from_template)
@@ -8,6 +12,8 @@
812
from taichi.lang.matrix import MatrixField
913
from taichi.types.annotations import template
1014

15+
import taichi
16+
1117

1218
class KernelTemplate:
1319
def __init__(self, kernel_fn, aot_module):
@@ -88,6 +94,7 @@ def __init__(self, arch):
8894
rtm = impl.get_runtime()
8995
rtm._finalize_root_fb_for_aot()
9096
self._aot_builder = rtm.prog.make_aot_module_builder(arch)
97+
self._content = []
9198

9299
def add_field(self, name, field):
93100
"""Add a taichi field to the AOT module.
@@ -147,8 +154,11 @@ def add_kernel(self, kernel_fn, template_args=None, name=None):
147154
# kernel AOT
148155
self._kernels.append(kernel)
149156

157+
self._content += ["kernel:" + kernel_name]
158+
150159
def add_graph(self, name, graph):
151160
self._aot_builder.add_graph(name, graph._compiled_graph)
161+
self._content += ["cgraph:" + name]
152162

153163
@contextmanager
154164
def add_kernel_template(self, kernel_fn):
@@ -193,3 +203,29 @@ def save(self, filepath, filename):
193203
"""
194204
filepath = str(PurePosixPath(Path(filepath)))
195205
self._aot_builder.dump(filepath, filename)
206+
207+
def archive(self, filepath: str):
208+
"""
209+
Args:
210+
filepath (str): path to the stored archive of aot artifacts, MUST
211+
end with `.tcm`.
212+
"""
213+
assert filepath.endswith(".tcm"), \
214+
"AOT module artifact archive must ends with .tcm"
215+
tcm_path = Path(filepath).absolute()
216+
assert tcm_path.parent.exists(), "Output directory doesn't exist"
217+
218+
temp_dir = mkdtemp(prefix="tcm_")
219+
# Save first as usual.
220+
self.save(temp_dir, "")
221+
222+
# Package all artifacts into a zip archive and attach contend data.
223+
with ZipFile(tcm_path, "w") as z:
224+
z.writestr("__content__", '\n'.join(self._content))
225+
z.writestr("__version__",
226+
'.'.join(str(x) for x in taichi.__version__))
227+
for path in glob(f"{temp_dir}/*", recursive=True):
228+
z.write(path, Path.relative_to(Path(path), temp_dir))
229+
230+
# Remove cached files
231+
rmtree(temp_dir)

tests/python/test_aot.py

+22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
import tempfile
5+
import zipfile
56

67
import numpy as np
78
import pytest
@@ -555,3 +556,24 @@ def run(arr: ti.types.ndarray(), val1: ti.f32, val2: ti.template()):
555556
res = json.load(json_file)
556557
args_count = res['aot_data']['kernels']['run']['args_count']
557558
assert args_count == 2, res # `arr` and `val1`
559+
560+
561+
@test_utils.test(arch=[ti.vulkan])
562+
def test_archive():
563+
density = ti.field(float, shape=(4, 4))
564+
565+
@ti.kernel
566+
def init():
567+
for i, j in density:
568+
density[i, j] = 1
569+
570+
with tempfile.TemporaryDirectory() as tmpdir:
571+
# note ti.aot.Module(ti.opengl) is no-op according to its docstring.
572+
m = ti.aot.Module(ti.lang.impl.current_cfg().arch)
573+
m.add_field('density', density)
574+
m.add_kernel(init)
575+
tcm_path = f"{tmpdir}/x.tcm"
576+
m.archive(tcm_path)
577+
with zipfile.ZipFile(tcm_path, 'r') as z:
578+
assert z.read("__version__") == bytes(
579+
'.'.join(str(x) for x in ti.__version__), 'utf-8')

0 commit comments

Comments
 (0)