Skip to content

Commit 937c317

Browse files
committed
Update on "[autodiff] Support basic operations for forward mode autodiff"
Support cpu and gpu backends. The cc backend has an issue on FieldBuilder ref to #5143. The opengl backend currently does not support materializing multiple snode trees (see OpenglProgramImpl::compile_snode_tree_types), thus FieldBuilder is not supported. Related #5055 [ghstack-poisoned]
2 parents ce01e36 + ba19461 commit 937c317

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+2604
-550
lines changed

.pre-commit-config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ ci:
22
autoupdate_schedule: quarterly
33
autoupdate_commit_msg: '[misc] Update pre-commit hooks'
44

5-
exclude: ^((tests/python/test_exception|.*/examples/.*)\.py$|external/)
5+
exclude: ^((tests/python/test_exception)\.py$|external/)
66
repos:
77
- repo: https://github.com/google/yapf
88
rev: v0.32.0
@@ -33,3 +33,4 @@ repos:
3333
- id: pylint
3434
args: ['-rn', '-sn']
3535
files: ^python/taichi/
36+
exclude: ^python/taichi/examples/.*.py

c_api/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
!taichi.json

c_api/include/taichi/taichi.h

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
3+
#include "taichi/taichi_platform.h"
4+
5+
#include "taichi/taichi_core.h"
6+
7+
#if TI_WITH_VULKAN
8+
#define VK_NO_PROTOTYPES 1
9+
#include "taichi/taichi_vulkan.h"
10+
#endif // TI_WITH_VULKAN

c_api/include/taichi/taichi_core.h

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#pragma once
2+
#include <taichi/taichi_platform.h>
3+
4+
#ifdef __cplusplus
5+
extern "C" {
6+
#endif // __cplusplus
7+
8+
// alias.bool
9+
typedef uint32_t TiBool;
10+
11+
// definition.false
12+
#define FALSE 0
13+
14+
// definition.true
15+
#define TRUE 1
16+
17+
// alias.flags
18+
typedef uint32_t TiFlags;
19+
20+
// definition.null_handle
21+
#define NULL_HANDLE 0
22+
23+
// handle.runtime
24+
typedef struct TiRuntime_t *TiRuntime;
25+
26+
// handle.aot_module
27+
typedef struct TiAotModule_t *TiAotModule;
28+
29+
// handle.memory
30+
typedef struct TiMemory_t *TiMemory;
31+
32+
// handle.kernel
33+
typedef struct TiKernel_t *TiKernel;
34+
35+
// handle.compute_graph
36+
typedef struct TiComputeGraph_t *TiComputeGraph;
37+
38+
// enumeration.arch
39+
typedef enum TiArch {
40+
TI_ARCH_X64 = 0,
41+
TI_ARCH_ARM64 = 1,
42+
TI_ARCH_JS = 2,
43+
TI_ARCH_CC = 3,
44+
TI_ARCH_WASM = 4,
45+
TI_ARCH_CUDA = 5,
46+
TI_ARCH_METAL = 6,
47+
TI_ARCH_OPENGL = 7,
48+
TI_ARCH_DX11 = 8,
49+
TI_ARCH_OPENCL = 9,
50+
TI_ARCH_AMDGPU = 10,
51+
TI_ARCH_VULKAN = 11,
52+
TI_ARCH_MAX_ENUM = 0xffffffff,
53+
} TiArch;
54+
55+
// enumeration.argument_type
56+
typedef enum TiArgumentType {
57+
TI_ARGUMENT_TYPE_I32 = 0,
58+
TI_ARGUMENT_TYPE_F32 = 1,
59+
TI_ARGUMENT_TYPE_NDARRAY = 2,
60+
TI_ARGUMENT_TYPE_MAX_ENUM = 0xffffffff,
61+
} TiArgumentType;
62+
63+
// bit_field.memory_usage
64+
typedef enum TiMemoryUsageFlagBits {
65+
TI_MEMORY_USAGE_STORAGE_BIT = 0,
66+
TI_MEMORY_USAGE_UNIFORM_BIT = 1,
67+
TI_MEMORY_USAGE_VERTEX_BIT = 2,
68+
TI_MEMORY_USAGE_INDEX_BIT = 3,
69+
} TiMemoryUsageFlagBits;
70+
typedef TiFlags TiMemoryUsageFlags;
71+
72+
// structure.memory_allocate_info
73+
typedef struct TiMemoryAllocateInfo {
74+
uint64_t size;
75+
bool host_write;
76+
bool host_read;
77+
bool export_sharing;
78+
TiMemoryUsageFlagBits usage;
79+
} TiMemoryAllocateInfo;
80+
81+
// structure.nd_shape
82+
typedef struct TiNdShape {
83+
uint32_t dim_count;
84+
uint32_t dims[16];
85+
} TiNdShape;
86+
87+
// structure.nd_array
88+
typedef struct TiNdArray {
89+
TiMemory memory;
90+
TiNdShape shape;
91+
TiNdShape elem_shape;
92+
} TiNdArray;
93+
94+
// union.argument_value
95+
typedef union TiArgumentValue {
96+
int32_t i32;
97+
float f32;
98+
TiNdArray ndarray;
99+
} TiArgumentValue;
100+
101+
// structure.argument
102+
typedef struct TiArgument {
103+
TiArgumentType type;
104+
TiArgumentValue value;
105+
} TiArgument;
106+
107+
// structure.named_argument
108+
typedef struct TiNamedArgument {
109+
const char *name;
110+
TiArgument argument;
111+
} TiNamedArgument;
112+
113+
// function.create_runtime
114+
TI_DLL_EXPORT TiRuntime TI_API_CALL ti_create_runtime(TiArch arch);
115+
116+
// function.destroy_runtime
117+
TI_DLL_EXPORT void TI_API_CALL ti_destroy_runtime(TiRuntime runtime);
118+
119+
// function.allocate_memory
120+
TI_DLL_EXPORT TiMemory TI_API_CALL
121+
ti_allocate_memory(TiRuntime runtime,
122+
const TiMemoryAllocateInfo *allocate_info);
123+
124+
// function.free_memory
125+
TI_DLL_EXPORT void TI_API_CALL ti_free_memory(TiRuntime runtime,
126+
TiMemory memory);
127+
128+
// function.map_memory
129+
TI_DLL_EXPORT void *TI_API_CALL ti_map_memory(TiRuntime runtime,
130+
TiMemory memory);
131+
132+
// function.unmap_memory
133+
TI_DLL_EXPORT void TI_API_CALL ti_unmap_memory(TiRuntime runtime,
134+
TiMemory memory);
135+
136+
// function.launch_kernel
137+
TI_DLL_EXPORT void TI_API_CALL ti_launch_kernel(TiRuntime runtime,
138+
TiKernel kernel,
139+
uint32_t arg_count,
140+
const TiArgument *args);
141+
142+
// function.launch_compute_graph
143+
TI_DLL_EXPORT void TI_API_CALL
144+
ti_launch_compute_graph(TiRuntime runtime,
145+
TiComputeGraph compute_graph,
146+
uint32_t arg_count,
147+
const TiNamedArgument *args);
148+
149+
// function.submit
150+
TI_DLL_EXPORT void TI_API_CALL ti_submit(TiRuntime runtime);
151+
152+
// function.wait
153+
TI_DLL_EXPORT void TI_API_CALL ti_wait(TiRuntime runtime);
154+
155+
// function.load_aot_module
156+
TI_DLL_EXPORT TiAotModule TI_API_CALL
157+
ti_load_aot_module(TiRuntime runtime, const char *module_path);
158+
159+
// function.destroy_aot_module
160+
TI_DLL_EXPORT void TI_API_CALL ti_destroy_aot_module(TiAotModule aot_module);
161+
162+
// function.get_aot_module_kernel
163+
TI_DLL_EXPORT TiKernel TI_API_CALL
164+
ti_get_aot_module_kernel(TiAotModule aot_module, const char *name);
165+
166+
// function.get_aot_module_compute_graph
167+
TI_DLL_EXPORT TiComputeGraph TI_API_CALL
168+
ti_get_aot_module_compute_graph(TiAotModule aot_module, const char *name);
169+
170+
#ifdef __cplusplus
171+
} // extern "C"
172+
#endif // __cplusplus
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma once
2+
3+
// TO KEEP THE INCLUDE DEPENDENCY CLEAN, PLEASE DO NOT INCLUDE ANY OTHER
4+
// TAICHI HEADERS INTO THIS ONE.
5+
//
6+
// TODO(#2196): Once we can slim down "taichi/common/core.h", consider moving
7+
// the contents back to core.h and delete this file.
8+
#ifndef _CRT_SECURE_NO_WARNINGS
9+
#define _CRT_SECURE_NO_WARNINGS
10+
#endif
11+
12+
// https://gcc.gnu.org/wiki/Visibility
13+
#if defined _WIN32 || defined _WIN64 || defined __CYGWIN__
14+
#ifdef __GNUC__
15+
#define TI_DLL_EXPORT __attribute__((dllexport))
16+
#define TI_API_CALL
17+
#else
18+
#define TI_DLL_EXPORT __declspec(dllexport)
19+
#define TI_API_CALL __stdcall
20+
#endif // __GNUC__
21+
#else
22+
#define TI_DLL_EXPORT __attribute__((visibility("default")))
23+
#define TI_API_CALL
24+
#endif // defined _WIN32 || defined _WIN64 || defined __CYGWIN__
25+
26+
// Windows
27+
#if defined(_WIN64)
28+
#define TI_PLATFORM_WINDOWS
29+
#endif
30+
31+
#if defined(_WIN32) && !defined(_WIN64)
32+
static_assert(false, "32-bit Windows systems are not supported")
33+
#endif
34+
35+
// Linux
36+
#if defined(__linux__)
37+
#if defined(ANDROID)
38+
#define TI_PLATFORM_ANDROID
39+
#else
40+
#define TI_PLATFORM_LINUX
41+
#endif
42+
#endif
43+
44+
// OSX
45+
#if defined(__APPLE__)
46+
#define TI_PLATFORM_OSX
47+
#endif
48+
49+
#if (defined(TI_PLATFORM_LINUX) || defined(TI_PLATFORM_OSX) || \
50+
defined(__unix__))
51+
#define TI_PLATFORM_UNIX
52+
#endif
53+
54+
#include <stddef.h>
55+
#include <stdint.h>

c_api/include/taichi/taichi_vulkan.h

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#pragma once
2+
#include <taichi/taichi_core.h>
3+
#include <vulkan/vulkan.h>
4+
5+
#ifdef __cplusplus
6+
extern "C" {
7+
#endif // __cplusplus
8+
9+
// structure.vulkan_runtime_interop_info
10+
typedef struct TiVulkanRuntimeInteropInfo {
11+
uint32_t api_version;
12+
VkInstance instance;
13+
VkPhysicalDevice physical_device;
14+
VkDevice device;
15+
VkQueue compute_queue;
16+
uint32_t compute_queue_family_index;
17+
VkQueue graphics_queue;
18+
uint32_t graphics_queue_family_index;
19+
} TiVulkanRuntimeInteropInfo;
20+
21+
// structure.vulkan_memory_interop_info
22+
typedef struct TiVulkanMemoryInteropInfo {
23+
VkBuffer buffer;
24+
size_t size;
25+
VkBufferUsageFlags usage;
26+
} TiVulkanMemoryInteropInfo;
27+
28+
// function.create_vulkan_runtime
29+
TI_DLL_EXPORT TiRuntime TI_API_CALL
30+
ti_create_vulkan_runtime_ext(uint32_t api_version,
31+
uint32_t instance_extension_count,
32+
const char **instance_extensions,
33+
uint32_t device_extension_count,
34+
const char **device_extensions);
35+
36+
// function.import_vulkan_runtime
37+
TI_DLL_EXPORT TiRuntime TI_API_CALL
38+
ti_import_vulkan_runtime(const TiVulkanRuntimeInteropInfo *interop_info);
39+
40+
// function.export_vulkan_runtime
41+
TI_DLL_EXPORT void TI_API_CALL
42+
ti_export_vulkan_runtime(TiRuntime runtime,
43+
TiVulkanRuntimeInteropInfo *interop_info);
44+
45+
// function.import_vulkan_memory
46+
TI_DLL_EXPORT TiMemory TI_API_CALL
47+
ti_import_vulkan_memory(TiRuntime runtime,
48+
const TiVulkanMemoryInteropInfo *interop_info);
49+
50+
// function.export_vulkan_memory
51+
TI_DLL_EXPORT void TI_API_CALL
52+
ti_export_vulkan_memory(TiRuntime runtime,
53+
TiMemory memory,
54+
TiVulkanMemoryInteropInfo *interop_info);
55+
56+
#ifdef __cplusplus
57+
} // extern "C"
58+
#endif // __cplusplus

0 commit comments

Comments
 (0)