
Commit 3f5ea23

vkuzo authored and facebook-github-bot committed Aug 28, 2020
Adding a version serialization type to ConvPackedParam (pytorch#43086)
Summary: Pull Request resolved: pytorch#43086

This PR changes the format of `ConvPackedParams` in a nearly backwards-compatible way:

* a new format is introduced which has more flexibility and a lower on-disk size
* custom pickle functions are added to `ConvPackedParams` which know how to load the old format
* the custom pickle functions are **not** BC because the output type of `__getstate__` has changed; we expect this to be acceptable as no user flows are actually broken (loading a v1 model with v2 code works), which is why we whitelist the failure

Test plan (TODO finalize):

```
// adhoc testing of saving v1 and loading in v2:
https://gist.github.com/vkuzo/f3616c5de1b3109cb2a1f504feed69be

// test that loading models with v1 conv params format works and leads to the same numerics
python test/test_quantization.py TestSerialization.test_conv2d_graph
python test/test_quantization.py TestSerialization.test_conv2d_nobias_graph

// test that saving and loading models with v2 conv params format works and leads to same numerics
python test/test_quantization.py TestSerialization.test_conv2d_graph_v2
python test/test_quantization.py TestSerialization.test_conv2d_nobias_graph_v2

// TODO before land:
// test numerics for a real model
// test legacy ONNX path
```

Note: this is a newer copy of pytorch#40003

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D23347832

Pulled By: vkuzo

fbshipit-source-id: 06bbe4666421ebad25dc54004c3b49a481d3cc92
1 parent af4ecb3 · commit 3f5ea23

14 files changed: +474 -128 lines
 
@@ -0,0 +1,301 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/List.h>

#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>

#include <tuple>

/* Convolution prepacked parameters serialization.
 *
 * Version 1
 *
 * - Fields:
 *  1. weight
 *  2. bias
 *  3. stride x kSpatialDim
 *  4. padding x kSpatialDim
 *  5. dilation x kSpatialDim
 *  6. groups
 *
 * Version 2
 *
 * - Fields:
 *  0. version (string)
 *  1. list of non-optional tensors
 *    0: packed parameters (int16_t)
 *      - kSpatialDim
 *      - stride x kSpatialDim
 *      - padding x kSpatialDim
 *      - dilation x kSpatialDim
 *      - output_padding x kSpatialDim (unused)
 *      - groups
 *      - transpose (0 or 1, unused)
 *    1: weight
 *  2. list of optional tensors
 *    0: bias
 *
 * Note: version is a string and conv params are packed into a Tensor
 * to make ONNX happy (ints and containers of ints are not supported).
 */
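
To make the v2 layout concrete, here is an illustrative sketch (not part of the diff; names like `example_params` are hypothetical, and this assumes the header's own includes) of the packed-params tensor contents for a 2d conv with stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, in the field order documented above:

```cpp
// Illustration only: the v2 "packed parameters" int16 tensor contents.
std::vector<int16_t> example_params = {
    2,     // kSpatialDim
    1, 1,  // stride x kSpatialDim
    0, 0,  // padding x kSpatialDim
    1, 1,  // dilation x kSpatialDim
    0, 0,  // output_padding x kSpatialDim (unused)
    1,     // groups
    0,     // transpose (0 or 1, unused)
};
```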

// version 1
using ConvParamsSerializationTypeLegacy = std::tuple<
  // weight
  at::Tensor,
  // bias
  c10::optional<at::Tensor>,
  // stride x kSpatialDim
  torch::List<at::Tensor>,
  // padding x kSpatialDim
  torch::List<at::Tensor>,
  // dilation x kSpatialDim
  torch::List<at::Tensor>,
  // groups
  at::Tensor>;

// version 2
using ConvParamsSerializationType = std::tuple<
  // version, for versions 2 and up
  std::string,
  // non-optional tensors
  std::vector<at::Tensor>,
  // optional tensors
  std::vector<c10::optional<at::Tensor>>>;
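
As a sketch of how the v2 type is populated (illustrative only, e.g. inside a test; the float tensor stands in for a quantized weight, and the packed tensor is built the same way `serialize_conv` below builds it):

```cpp
std::vector<int16_t> params_vec = {2, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0};
// clone to own the data, mirroring the pattern used in this file
at::Tensor params_tensor = at::from_blob(
    params_vec.data(), {static_cast<int64_t>(params_vec.size())},
    at::TensorOptions().dtype(at::kShort)).clone();

ConvParamsSerializationType state = std::make_tuple(
    std::string("2"),
    std::vector<at::Tensor>{params_tensor, at::randn({4, 3, 3, 3})},
    std::vector<c10::optional<at::Tensor>>{c10::nullopt});
```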

// Parses any historical conv packed params format into
// the current format.
template <uint32_t kSpatialDim>
ConvParamsSerializationType parse_conv_serialized_state(c10::IValue v) {

  // determine the version based on IValue contents
  int version = -1;
  if (v.isTuple()) {
    auto elements = v.toTuple()->elements();
    if (elements.size() > 0) {
      auto firstElement = elements[0];
      if (firstElement.isTensor()) {
        version = 1;
      } else if (firstElement.isString()) {
        std::string version_str = firstElement.toStringRef();
        // note: not parsing the string to automatically handle bad
        // inputs
        if (version_str == "2") {
          version = 2;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");

  if (version == 1) {
    // version 1 - convert to version 2 manually

    auto elements = v.toTuple()->elements();

    at::Tensor weight = elements[0].toTensor();
    c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
    torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
    torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
    torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
    at::Tensor groups = elements[5].toTensor();

    std::string version = "2";
    std::vector<at::Tensor> non_optional;
    std::vector<c10::optional<at::Tensor>> optional;

    std::vector<int16_t> params_vec;
    params_vec.push_back(kSpatialDim);
    for (int i = 0; i < stride_x_kSpatialDim.size(); i++) {
      auto stride = stride_x_kSpatialDim.get(i);
      params_vec.push_back(stride[0].item<int16_t>());
    }
    for (int i = 0; i < padding_x_kSpatialDim.size(); i++) {
      auto padding = padding_x_kSpatialDim.get(i);
      params_vec.push_back(padding[0].item<int16_t>());
    }
    for (int i = 0; i < dilation_x_kSpatialDim.size(); i++) {
      auto dilation = dilation_x_kSpatialDim.get(i);
      params_vec.push_back(dilation[0].item<int16_t>());
    }
    // output_padding does not exist in v1, so we fill in a default value
    for (int i = 0; i < kSpatialDim; i++) {
      params_vec.push_back(0);
    }
    params_vec.push_back(groups[0].item<int16_t>());
    // transpose does not exist in v1, so we fill in a default value
    params_vec.push_back(0);
    int64_t vec_size = params_vec.size();
    at::Tensor params_tensor = at::from_blob(params_vec.data(),
        {vec_size}, at::TensorOptions().dtype(at::kShort))
        // clone to retain ownership of the data
        .clone();

    non_optional.emplace_back(std::move(params_tensor));
    non_optional.emplace_back(std::move(weight));
    optional.emplace_back(std::move(bias));

    return std::tie(version, non_optional, optional);
  } else if (version == 2) {
    // version 2
    return v.to<ConvParamsSerializationType>();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
        version);
  }
}
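
A usage sketch (not in the diff): hand-building a v1-shaped IValue and running it through the parser, e.g. in a test. In real use the IValue comes from the unpickler via `__setstate__` and the weight is a quantized tensor; a float tensor stands in here because the parser only repackages the fields:

```cpp
torch::List<at::Tensor> stride({at::tensor(1), at::tensor(1)});
torch::List<at::Tensor> padding({at::tensor(0), at::tensor(0)});
torch::List<at::Tensor> dilation({at::tensor(1), at::tensor(1)});

c10::IValue v1_state = c10::ivalue::Tuple::create(
    {at::randn({4, 3, 3, 3}),  // weight (stand-in)
     c10::IValue(),            // bias = None
     stride,
     padding,
     dilation,
     at::tensor(1)});          // groups

// Any historical format is normalized to the current (v2) format.
ConvParamsSerializationType state = parse_conv_serialized_state<2>(v1_state);
```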

template <uint32_t kSpatialDim>
ConvParamsSerializationType serialize_conv(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {

  std::string version = "2";
  std::vector<at::Tensor> non_optional;
  std::vector<c10::optional<at::Tensor>> optional;

  // create a packed int16_t tensor for conv params
  std::vector<int16_t> params_vec;
  params_vec.push_back(kSpatialDim);
  auto stride = params->stride().vec();
  params_vec.insert(params_vec.end(), stride.begin(), stride.end());
  auto padding = params->padding().vec();
  params_vec.insert(params_vec.end(), padding.begin(), padding.end());
  auto dilation = params->dilation().vec();
  params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
  // output_padding is not implemented yet, so we fill in a default value
  for (int i = 0; i < kSpatialDim; i++) {
    params_vec.push_back(0);
  }
  params_vec.push_back(params->groups());
  // transpose is not implemented yet, so we fill in a default value
  params_vec.push_back(0);
  int64_t vec_size = params_vec.size();
  at::Tensor params_tensor = at::from_blob(
      params_vec.data(), {vec_size},
      at::TensorOptions().dtype(at::kShort))
      // clone to retain ownership of the data
      .clone();

  at::Tensor weight;
  c10::optional<at::Tensor> bias;
  std::tie(weight, bias) = params->unpack();

  non_optional.emplace_back(std::move(params_tensor));
  non_optional.emplace_back(std::move(weight));
  optional.emplace_back(std::move(bias));

  return std::tie(version, non_optional, optional);
}

template <uint32_t kSpatialDim>
ConvParamsSerializationTypeLegacy serialize_conv_legacy(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
  at::Tensor weight;
  c10::optional<at::Tensor> bias;
  std::tie(weight, bias) = params->unpack();
  torch::List<at::Tensor> stride;
  torch::List<at::Tensor> padding;
  torch::List<at::Tensor> dilation;
  at::Tensor groups;
  for (int64_t s : params->stride()) {
    stride.emplace_back(at::tensor(s));
  }
  for (int64_t p : params->padding()) {
    padding.emplace_back(at::tensor(p));
  }
  for (int64_t d : params->dilation()) {
    dilation.emplace_back(at::tensor(d));
  }
  groups = at::tensor(params->groups());
  return std::make_tuple(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups);
}
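
For contrast, an illustrative v1 tuple (stand-in values only): each integer field is boxed into its own tensor, which is why the summary describes v2 as having a lower on-disk size:

```cpp
ConvParamsSerializationTypeLegacy legacy_state = std::make_tuple(
    at::randn({4, 3, 3, 3}),                  // weight (stand-in)
    c10::optional<at::Tensor>(c10::nullopt),  // bias
    torch::List<at::Tensor>({at::tensor(2), at::tensor(2)}),  // stride
    torch::List<at::Tensor>({at::tensor(0), at::tensor(0)}),  // padding
    torch::List<at::Tensor>({at::tensor(1), at::tensor(1)}),  // dilation
    at::tensor(1));                           // groups
```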

template <uint32_t kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
    ConvParamsSerializationType state) {

  std::string version;
  std::vector<at::Tensor> non_optional;
  std::vector<c10::optional<at::Tensor>> optional;

  std::tie(version, non_optional, optional) = state;
  TORCH_INTERNAL_ASSERT(version == "2", "Unexpected serialized qconv version: ",
      version);

  at::Tensor conv_params_packed = non_optional[0];
  at::Tensor weight = non_optional[1];
  c10::optional<at::Tensor> bias = optional[0];

  torch::List<int64_t> stride, padding, dilation;
  // skip kSpatialDim
  int idx = 1;
  for (int i = 0; i < kSpatialDim; ++i) {
    stride.emplace_back(conv_params_packed[idx].item<int64_t>());
    idx++;
  }
  for (int i = 0; i < kSpatialDim; ++i) {
    padding.emplace_back(conv_params_packed[idx].item<int64_t>());
    idx++;
  }
  for (int i = 0; i < kSpatialDim; ++i) {
    dilation.emplace_back(conv_params_packed[idx].item<int64_t>());
    idx++;
  }
  // output_padding is not implemented yet, so we skip the entries
  for (int i = 0; i < kSpatialDim; ++i) {
    // do nothing
    idx++;
  }
  int64_t groups = conv_params_packed[idx].item<int64_t>();
  idx++;
  // transpose is not implemented yet, so we skip the entry
  idx++;
  TORCH_INTERNAL_ASSERT(idx == conv_params_packed.numel(),
      "Unexpected length of conv_params_packed, expected ",
      idx,
      " got ",
      conv_params_packed.numel());

  auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
  if (ctx.qEngine() == at::QEngine::FBGEMM) {
    return PackedConvWeight<kSpatialDim>::prepack(
      weight,
      bias,
      stride,
      padding,
      dilation,
      groups
    );
  }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
  if (ctx.qEngine() == at::QEngine::QNNPACK) {
    TORCH_CHECK(
        kSpatialDim == 2,
        "prepack/__setstate__: QNNPACK only supports Conv2d "
        "now.");
    return PackedConvWeightsQnnp<kSpatialDim>::prepack(
      weight,
      bias,
      stride,
      padding,
      dilation,
      groups
    );
  }
#endif // USE_PYTORCH_QNNPACK
  TORCH_CHECK(
      false,
      "Didn't find engine when deserializing ConvPackedParams: ",
      toString(ctx.qEngine()));
}
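
Finally, a hedged round-trip sketch tying the helpers together (illustrative only; it assumes the FBGEMM engine is active, and `w`, `packed`, `state`, and `packed2` are hypothetical names):

```cpp
// Prepack a quantized 2d conv weight, serialize to the v2 format,
// then rebuild the packed params from that state.
at::Tensor w = at::quantize_per_tensor(
    at::randn({4, 3, 3, 3}), /*scale=*/0.1, /*zero_point=*/0, at::kQInt8);
auto packed = PackedConvWeight<2>::prepack(
    w, /*bias=*/c10::nullopt,
    /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*dilation=*/{1, 1}, /*groups=*/1);

ConvParamsSerializationType state = serialize_conv<2>(packed);
auto packed2 = deserialize_conv<2>(state);  // dispatches on the active qEngine
```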
