Skip to content

Commit 3852215

Browse files
IvanKobzarev authored and facebook-github-bot committed on Jun 20, 2020
[vulkan] jit passes for vulkan conv2 prepack and fuse with clamp (pytorch#39282)
Summary: Pull Request resolved: pytorch#39282
Test Plan: Imported from OSS
Differential Revision: D21962424
Pulled By: IvanKobzarev
fbshipit-source-id: 2d20e827d2c3836b7e6b443293377c68dc1ffa5a
1 parent f69460d commit 3852215

16 files changed, +540 -64 lines changed
 

‎aten/src/ATen/native/vulkan/VulkanAten.cpp

+3-2
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ at::Tensor vulkan_convolution(
179179
voutput,
180180
vinput,
181181
weight.data_ptr<float>(),
182-
bias.defined() ? c10::make_optional<float*>(bias.data_ptr<float>())
182+
bias.defined() ? c10::make_optional<const float*>(bias.data_ptr<float>())
183183
: c10::nullopt,
184184
params);
185185
return new_with_vtensor_vulkan(std::move(voutput), input.options());
@@ -242,7 +242,8 @@ at::Tensor vulkan_convolution_prepacked(
242242
voutput,
243243
vinput,
244244
vweight,
245-
hasBias ? c10::make_optional((*bias).data_ptr<float>()) : c10::nullopt,
245+
hasBias ? c10::make_optional<const float*>((*bias).data_ptr<float>())
246+
: c10::nullopt,
246247
params,
247248
output_min,
248249
output_max);

‎aten/src/ATen/native/vulkan/VulkanConvolution.cpp

+5-6
Original file line number | Diff line number | Diff line change
@@ -66,14 +66,13 @@ ContextConv2D create(
6666
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
6767
const auto dilation_expanded =
6868
expand_param_if_needed(dilation, "dilation", 2);
69-
const Tensor weight_nchw = weight.contiguous();
69+
Tensor weight_nchw = weight.contiguous();
70+
auto ws = weight_nchw.sizes();
7071
return ContextConv2D{
71-
at::native::vulkan_convolution_prepack_weights(weight),
72+
groups == 1 ? at::native::vulkan_convolution_prepack_weights(weight_nchw)
73+
: weight_nchw.vulkan(),
7274
bias.has_value() ? c10::make_optional((*bias).vulkan()) : c10::nullopt,
73-
{weight_nchw.sizes()[0],
74-
weight_nchw.sizes()[1],
75-
weight_nchw.sizes()[2],
76-
weight_nchw.sizes()[3]},
75+
{{ws[0], ws[1], ws[2], ws[3]}},
7776
{padding_expanded[0], padding_expanded[1]},
7877
{stride_expanded[0], stride_expanded[1]},
7978
{dilation_expanded[0], dilation_expanded[1]},

‎aten/src/ATen/native/vulkan/VulkanOps.cpp

+71-14
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ VBuffer kernelNCHW_OCHW_repack_O4C4HWi4o4(
176176
}
177177

178178
VBuffer bufferFromOptionalHostData(
179-
c10::optional<float*> data,
179+
c10::optional<const float*> data,
180180
const uint32_t size) {
181181
const auto sizeAligned =
182182
ROUND_UP(size, context().limits().minStorageBufferOffsetAlignment);
@@ -202,17 +202,15 @@ uint32_t conv2d_biasBufferSize(uint32_t oc) {
202202
void conv2d_depthwise(
203203
VulkanTensor& output,
204204
const VulkanTensor& input,
205-
const float* weight,
206-
const c10::optional<float*> bias,
207-
const Conv2DParams params,
205+
const VulkanTensor& weight,
206+
const VBuffer& biasBuffer,
207+
const Conv2DParams& params,
208208
c10::optional<float> output_min,
209209
c10::optional<float> output_max) {
210210
TORCH_INTERNAL_ASSERT(params.G == params.C);
211211
auto osizes = output.sizes();
212212
TORCH_INTERNAL_ASSERT(osizes[2] == params.OH);
213213
TORCH_INTERNAL_ASSERT(osizes[3] == params.OW);
214-
auto biasBuffer =
215-
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
216214
struct ConstBlock {
217215
int32_t padding[2];
218216
int32_t kernelSize[2];
@@ -234,9 +232,6 @@ void conv2d_depthwise(
234232
output_max ? *output_max : std::numeric_limits<float>::infinity()};
235233
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
236234

237-
VulkanTensor kernel{{params.OC, params.KH, params.KW}};
238-
kernel.set_data_from_host(weight);
239-
240235
VkDescriptorSetLayout descriptorSetLayout{};
241236
VkDescriptorPool descriptorPool{};
242237
VkDescriptorSet descriptorSet{};
@@ -256,7 +251,7 @@ void conv2d_depthwise(
256251

257252
output.image()->bindStorageImage(descriptorSet, 0);
258253
input.image()->bindShaderRead(descriptorSet, 1);
259-
kernel.image()->bindShaderRead(descriptorSet, 2);
254+
weight.image()->bindShaderRead(descriptorSet, 2);
260255
biasBuffer.bind(descriptorSet, 3);
261256
constBuffer.bind(descriptorSet, 4);
262257

@@ -269,7 +264,7 @@ void conv2d_depthwise(
269264
auto commandBuffer = computeUnit.commandBuffer();
270265
output.image()->addImageMemoryBarrierToGeneral(commandBuffer);
271266
input.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
272-
kernel.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
267+
weight.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
273268
computeUnit.dispatchCommandBuffer(
274269
params.OW, params.OH, params.OC_4, workGroupSize);
275270
computeUnit.endCommandBuffer();
@@ -279,6 +274,44 @@ void conv2d_depthwise(
279274
vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
280275
}
281276

277+
void conv2d_depthwise(
278+
VulkanTensor& output,
279+
const VulkanTensor& input,
280+
const VulkanTensor& weight,
281+
const c10::optional<const float*> bias,
282+
const Conv2DParams params,
283+
c10::optional<float> output_min,
284+
c10::optional<float> output_max) {
285+
conv2d_depthwise(
286+
output,
287+
input,
288+
weight,
289+
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
290+
params,
291+
output_min,
292+
output_max);
293+
}
294+
295+
void conv2d_depthwise(
296+
VulkanTensor& output,
297+
const VulkanTensor& input,
298+
const float* weight,
299+
const c10::optional<const float*> bias,
300+
const Conv2DParams params,
301+
c10::optional<float> output_min,
302+
c10::optional<float> output_max) {
303+
VulkanTensor weightTensor{{params.OC, params.KH, params.KW}};
304+
weightTensor.set_data_from_host(weight);
305+
conv2d_depthwise(
306+
output,
307+
input,
308+
weightTensor,
309+
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
310+
params,
311+
output_min,
312+
output_max);
313+
}
314+
282315
ImageSizes conv2d_prepack_weights_image_sizes(
283316
int64_t OC,
284317
int64_t C,
@@ -463,7 +496,7 @@ void conv2d(
463496
VulkanTensor& output,
464497
const VulkanTensor& input,
465498
const VImage& kernelImage,
466-
const c10::optional<float*> bias,
499+
const c10::optional<const float*> bias,
467500
const Conv2DParams& params,
468501
c10::optional<float> output_min,
469502
c10::optional<float> output_max) {
@@ -483,10 +516,22 @@ void conv2d(
483516
VulkanTensor& output,
484517
const VulkanTensor& input,
485518
const VulkanTensor& weight_prepacked,
486-
c10::optional<float*> bias,
519+
c10::optional<const float*> bias,
487520
const Conv2DParams params,
488521
c10::optional<float> output_min,
489522
c10::optional<float> output_max) {
523+
if (params.G > 1) {
524+
conv2d_depthwise(
525+
output,
526+
input,
527+
weight_prepacked,
528+
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
529+
params,
530+
output_min,
531+
output_max);
532+
return;
533+
}
534+
490535
conv2d(
491536
output,
492537
input,
@@ -505,6 +550,18 @@ void conv2d(
505550
const Conv2DParams params,
506551
c10::optional<float> output_min,
507552
c10::optional<float> output_max) {
553+
if (params.G > 1) {
554+
conv2d_depthwise(
555+
output,
556+
input,
557+
weight_prepacked,
558+
*(bias.buffer()),
559+
params,
560+
output_min,
561+
output_max);
562+
return;
563+
}
564+
508565
conv2d(
509566
output,
510567
input,
@@ -519,7 +576,7 @@ void conv2d(
519576
VulkanTensor& output,
520577
const VulkanTensor& input,
521578
const float* weight,
522-
const c10::optional<float*> bias,
579+
const c10::optional<const float*> bias,
523580
const Conv2DParams params,
524581
c10::optional<float> output_min,
525582
c10::optional<float> output_max) {

‎aten/src/ATen/native/vulkan/VulkanOps.h

+2-2
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ void conv2d(
3737
VulkanTensor& output,
3838
const VulkanTensor& input,
3939
const float* weight,
40-
const c10::optional<float*> bias,
40+
const c10::optional<const float*> bias,
4141
const Conv2DParams params,
4242
c10::optional<float> output_min = c10::nullopt,
4343
c10::optional<float> output_max = c10::nullopt);
@@ -46,7 +46,7 @@ void conv2d(
4646
VulkanTensor& output,
4747
const VulkanTensor& input,
4848
const VulkanTensor& weight_prepacked,
49-
const c10::optional<float*> bias,
49+
const c10::optional<const float*> bias,
5050
const Conv2DParams params,
5151
c10::optional<float> output_min = c10::nullopt,
5252
c10::optional<float> output_max = c10::nullopt);

‎aten/src/ATen/test/vulkan_test.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -496,7 +496,7 @@ TEST(VulkanTest, conv2dPrepack) {
496496
ASSERT_TRUE(no_prepack_check);
497497

498498
auto prepack = callOpByName(
499-
"vulkan::conv2d_clamp_prepack",
499+
"vulkan_prepack::conv2d_clamp_prepack",
500500
"",
501501
t_w,
502502
t_b,
@@ -507,7 +507,7 @@ TEST(VulkanTest, conv2dPrepack) {
507507
output_min,
508508
output_max);
509509
auto tv_out_prepack_ivalues =
510-
callOpByName("vulkan::conv2d_clamp_run", "", tv_in, prepack[0]);
510+
callOpByName("vulkan_prepack::conv2d_clamp_run", "", tv_in, prepack[0]);
511511
auto tv_out_prepack = tv_out_prepack_ivalues[0].toTensor();
512512
auto t_out_prepack = tv_out_prepack.cpu();
513513
const auto prepack_check = almostEqual(t_out_prepack, t_out_expected);

‎binaries/CMakeLists.txt

+1
Original file line number | Diff line number | Diff line change
@@ -103,3 +103,4 @@ endif()
103103
caffe2_binary_target("tutorial_blob.cc")
104104

105105
caffe2_binary_target("dump_operator_names.cc")
106+
caffe2_binary_target("optimize_for_mobile.cc")

‎binaries/optimize_for_mobile.cc

+6-1
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818

1919
#include "torch/csrc/jit/api/module.h"
20+
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
2021
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
2122
#include "torch/csrc/jit/serialization/import.h"
2223

@@ -29,6 +30,7 @@ C10_DEFINE_bool(
2930
save_for_mobile,
3031
false,
3132
"Save the model with bytecode format compatible with lite inteprter.");
33+
C10_DEFINE_bool(vulkan, false, "Vulkan optimize_for_mobile");
3234

3335
int main(int argc, char** argv) {
3436
c10::SetUsageMessage(
@@ -52,7 +54,10 @@ int main(int argc, char** argv) {
5254
}
5355

5456
auto module = torch::jit::load(FLAGS_model);
55-
auto optimized_module = torch::jit::optimizeForMobile(module);
57+
58+
auto optimized_module = FLAGS_vulkan
59+
? torch::jit::vulkanOptimizeForMobile(module)
60+
: torch::jit::optimizeForMobile(module);
5661

5762
if (FLAGS_save_for_mobile) {
5863
optimized_module._save_for_mobile(output_model_name);

‎test/run_test.py

+1
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
'test_optim',
4949
'test_mobile_optimizer',
5050
'test_xnnpack_integration',
51+
'test_vulkan',
5152
'test_quantization',
5253
'test_sparse',
5354
'test_serialization',

0 commit comments

Comments
 (0)
Please sign in to comment.