Skip to content

Commit 98a4ac5

Browse files
committed
Moved image tracking registration to create_image
1 parent b1096e9 commit 98a4ac5

File tree

8 files changed

+29
-2
lines changed

8 files changed

+29
-2
lines changed

taichi/program/program.h

+3
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ class TI_DLL_EXPORT Program {
300300
uint64 *result_buffer) {
301301
return program_impl_->allocate_memory_ndarray(alloc_size, result_buffer);
302302
}
303+
DeviceAllocation allocate_texture(const ImageParams& params) {
304+
return program_impl_->allocate_texture(params);
305+
}
303306

304307
Ndarray *create_ndarray(
305308
const DataType type,

taichi/program/program_impl.h

+5
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ class ProgramImpl {
9898
uint64 *result_buffer) {
9999
return kDeviceNullAllocation;
100100
}
101+
102+
virtual DeviceAllocation allocate_texture(const ImageParams& params) {
103+
return kDeviceNullAllocation;
104+
}
105+
101106
virtual ~ProgramImpl() {
102107
}
103108

taichi/program/texture.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Texture::Texture(Program *prog,
3535
img_params.y = height;
3636
img_params.z = depth;
3737
img_params.initial_layout = ImageLayout::undefined;
38-
texture_alloc_ = device->create_image(img_params);
38+
texture_alloc_ = prog_->allocate_texture(img_params);
3939

4040
format_ = img_params.format;
4141

taichi/rhi/device.h

+4
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,10 @@ class TI_DLL_EXPORT GraphicsDevice : public Device {
642642

643643
virtual std::unique_ptr<Surface> create_surface(
644644
const SurfaceConfig &config) = 0;
645+
// You are not expected to call this directly. If you want to use this image
646+
// in a taichi kernel, you usually want to create the image via
647+
// `GfxRuntime::create_image`. `GfxRuntime` is available in `ProgramImpl`
648+
// of GPU backends.
645649
virtual DeviceAllocation create_image(const ImageParams &params) = 0;
646650
virtual void destroy_image(DeviceAllocation handle) = 0;
647651

taichi/runtime/gfx/runtime.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,17 @@ void GfxRuntime::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) {
558558
current_cmdlist_->buffer_copy(dst, src, size);
559559
submit_current_cmdlist_if_timeout();
560560
}
561+
562+
DeviceAllocation GfxRuntime::create_image(const ImageParams& params) {
563+
GraphicsDevice* gfx_device = dynamic_cast<GraphicsDevice*>(device_);
564+
TI_ERROR_IF(gfx_device == nullptr, "Image can only be created on a graphics device");
565+
DeviceAllocation image = gfx_device->create_image(params);
566+
last_image_layouts_[image.alloc_id] = params.initial_layout;
567+
return image;
568+
}
569+
561570
void GfxRuntime::transition_image(DeviceAllocation image, ImageLayout layout) {
562-
ImageLayout &last_layout = last_image_layouts_[image.alloc_id];
571+
ImageLayout &last_layout = last_image_layouts_.at(image.alloc_id);
563572
ensure_current_cmdlist();
564573
current_cmdlist_->image_transition(image, last_layout, layout);
565574
last_layout = layout;

taichi/runtime/gfx/runtime.h

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class TI_DLL_EXPORT GfxRuntime {
102102

103103
void buffer_copy(DevicePtr dst, DevicePtr src, size_t size);
104104

105+
DeviceAllocation create_image(const ImageParams& params);
105106
void transition_image(DeviceAllocation image, ImageLayout layout);
106107

107108
void signal_event(DeviceEvent *event);

taichi/runtime/program_impls/vulkan/vulkan_program.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ DeviceAllocation VulkanProgramImpl::allocate_memory_ndarray(
193193
/*export_sharing=*/false});
194194
}
195195

196+
DeviceAllocation VulkanProgramImpl::allocate_texture(const ImageParams& params) {
197+
return vulkan_runtime_->create_image(params);
198+
}
199+
196200
std::unique_ptr<aot::Kernel> VulkanProgramImpl::make_aot_kernel(
197201
Kernel &kernel) {
198202
spirv::lower(&kernel);

taichi/runtime/program_impls/vulkan/vulkan_program.h

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class VulkanProgramImpl : public ProgramImpl {
6565

6666
DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size,
6767
uint64 *result_buffer) override;
68+
DeviceAllocation allocate_texture(const ImageParams& params) override;
6869

6970
Device *get_compute_device() override {
7071
if (embedded_device_) {

0 commit comments

Comments
 (0)