Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] Fix int casts #4814

Merged
merged 2 commits into from
Apr 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions taichi/backends/device.h
Original file line number Diff line number Diff line change
@@ -9,6 +9,8 @@
namespace taichi {
namespace lang {

constexpr size_t kBufferSizeEntireSize = size_t(-1);

// For backend dependent code (e.g. codegen)
// Or the backend runtime itself
// Capabilities are per-device
9 changes: 5 additions & 4 deletions taichi/backends/vulkan/runtime.cpp
Original file line number Diff line number Diff line change
@@ -357,7 +357,7 @@ void CompiledTaichiKernel::generate_command_list(
if (bind.buffer.type == BufferType::ListGen) {
// FIXME: properlly support multiple list
cmdlist->buffer_fill(input_buffers_.at(bind.buffer)->get_ptr(0),
kListGenBufferSize,
kBufferSizeEntireSize,
/*data=*/0);
cmdlist->buffer_barrier(*input_buffers_.at(bind.buffer));
}
@@ -579,9 +579,9 @@ void VkRuntime::init_nonroot_buffers() {
Stream *stream = device_->get_compute_stream();
auto cmdlist = stream->new_command_list();

cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), kGtmpBufferSize,
cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), kBufferSizeEntireSize,
/*data=*/0);
cmdlist->buffer_fill(listgen_buffer_->get_ptr(0), kListGenBufferSize,
cmdlist->buffer_fill(listgen_buffer_->get_ptr(0), kBufferSizeEntireSize,
/*data=*/0);
stream->submit_synced(cmdlist.get());
}
@@ -597,7 +597,8 @@ void VkRuntime::add_root_buffer(size_t root_buffer_size) {
/*export_sharing=*/false, AllocUsage::Storage});
Stream *stream = device_->get_compute_stream();
auto cmdlist = stream->new_command_list();
cmdlist->buffer_fill(new_buffer->get_ptr(0), root_buffer_size, /*data=*/0);
cmdlist->buffer_fill(new_buffer->get_ptr(0), kBufferSizeEntireSize,
/*data=*/0);
stream->submit_synced(cmdlist.get());
root_buffers_.push_back(std::move(new_buffer));
// cache the root buffer size
3 changes: 2 additions & 1 deletion taichi/backends/vulkan/vulkan_device.cpp
Original file line number Diff line number Diff line change
@@ -855,7 +855,8 @@ void VulkanCommandList::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) {

void VulkanCommandList::buffer_fill(DevicePtr ptr, size_t size, uint32_t data) {
auto buffer = ti_device_->get_vkbuffer(ptr);
vkCmdFillBuffer(buffer_->buffer, buffer->buffer, ptr.offset, size, data);
vkCmdFillBuffer(buffer_->buffer, buffer->buffer, ptr.offset,
(size == kBufferSizeEntireSize) ? VK_WHOLE_SIZE : size, data);
buffer_->refs.push_back(buffer);
}

2 changes: 1 addition & 1 deletion taichi/backends/vulkan/vulkan_device_creator.cpp
Original file line number Diff line number Diff line change
@@ -428,7 +428,7 @@ void VulkanDeviceCreator::create_logical_device() {
ti_device_->set_cap(DeviceCapability::spirv_version, 0x10000);

if (physical_device_properties.apiVersion >= VK_API_VERSION_1_3) {
ti_device_->set_cap(DeviceCapability::spirv_version, 0x10600);
ti_device_->set_cap(DeviceCapability::spirv_version, 0x10500);
} else if (physical_device_properties.apiVersion >= VK_API_VERSION_1_2) {
ti_device_->set_cap(DeviceCapability::spirv_version, 0x10500);
} else if (physical_device_properties.apiVersion >= VK_API_VERSION_1_1) {
65 changes: 40 additions & 25 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
@@ -777,40 +777,55 @@ Value IRBuilder::cast(const SType &dst_type, Value value) {
to.to_string());
return Value();
}
} else if (is_integral(from) && is_signed(from) && is_integral(to) &&
is_signed(to)) { // Int -> Int
return make_value(spv::OpSConvert, dst_type, value);
} else if (is_integral(from) && is_unsigned(from) && is_integral(to) &&
is_unsigned(to)) { // UInt -> UInt
return make_value(spv::OpUConvert, dst_type, value);
} else if (is_integral(from) && is_unsigned(from) && is_integral(to) &&
is_signed(to)) { // UInt -> Int
if (data_type_bits(from) != data_type_bits(to)) {
auto to_signed = [](DataType dt) -> DataType {
TI_ASSERT(is_unsigned(dt));
if (dt->is_primitive(PrimitiveTypeID::u8))
} else if (is_integral(from) && is_integral(to)) {
auto ret = value;

if (data_type_bits(from) == data_type_bits(to)) {
// Same width conversion
ret = make_value(spv::OpBitcast, dst_type, ret);
} else {
// Different width
// Step 1. Sign extend / truncate value to width of `to`
// Step 2. Bitcast to signess of `to`
auto get_signed_type = [](DataType dt) -> DataType {
// Create a output signed type with the same width as `dt`
if (data_type_bits(dt) == 8)
return PrimitiveType::i8;
else if (dt->is_primitive(PrimitiveTypeID::u16))
else if (data_type_bits(dt) == 16)
return PrimitiveType::i16;
else if (dt->is_primitive(PrimitiveTypeID::u32))
else if (data_type_bits(dt) == 32)
return PrimitiveType::i32;
else if (dt->is_primitive(PrimitiveTypeID::u64))
else if (data_type_bits(dt) == 64)
return PrimitiveType::i64;
else
return PrimitiveType::unknown;
};
auto get_unsigned_type = [](DataType dt) -> DataType {
// Create a output unsigned type with the same width as `dt`
if (data_type_bits(dt) == 8)
return PrimitiveType::u8;
else if (data_type_bits(dt) == 16)
return PrimitiveType::u16;
else if (data_type_bits(dt) == 32)
return PrimitiveType::u32;
else if (data_type_bits(dt) == 64)
return PrimitiveType::u64;
else
return PrimitiveType::unknown;
};

value = make_value(spv::OpUConvert, get_primitive_type(to_signed(from)),
value);
}
return make_value(spv::OpBitcast, dst_type, value);
} else if (is_integral(from) && is_signed(from) && is_integral(to) &&
is_unsigned(to)) { // Int -> UInt
if (data_type_bits(from) != data_type_bits(to)) {
value = make_value(spv::OpSConvert, get_primitive_type(to_unsigned(from)),
value);
if (is_signed(from)) {
ret = make_value(spv::OpSConvert,
get_primitive_type(get_signed_type(to)), ret);
} else {
ret = make_value(spv::OpUConvert,
get_primitive_type(get_unsigned_type(to)), ret);
}

ret = make_value(spv::OpBitcast, dst_type, ret);
}
return make_value(spv::OpBitcast, dst_type, value);

return ret;
} else if (is_real(from) && is_integral(to) &&
is_signed(to)) { // Float -> Int
return make_value(spv::OpConvertFToS, dst_type, value);