
Commit eace053 (1 parent: e22dd56)

ezyang authored and facebook-github-bot committed on Jun 11, 2020
Move all torch.nn.modules type annotations inline (pytorch#38211)
Summary: Pull Request resolved: pytorch#38211

Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors, and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which had accumulated ignore annotations that weren't actually relevant anymore).

For the most part the translation was completely mechanical, but there were two hairy issues. First, I needed to work around a bug in Python 3.6 and earlier where Generic has a nontrivial metaclass; this fix is in torch/jit/__init__.py. Second, in module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient to get mypy to not contravariantly check input arguments.

Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later.

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Differential Revision: D21497397

Pulled By: ezyang

fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
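For context on the first issue: on Python 3.6 and earlier, typing.Generic was built on the nontrivial GenericMeta metaclass, so mixing Generic[T] into a hierarchy that already uses a custom metaclass raised a metaclass conflict. The sketch below illustrates the general failure mode and the conventional combined-metaclass workaround; the names WatchedMeta, WatchedGenericMeta, and Watched are hypothetical, and this is an illustration of the bug class, not the actual fix that landed in torch/jit/__init__.py.

import typing

T = typing.TypeVar("T")

class WatchedMeta(type):
    # Stand-in for any custom metaclass a library might already use.
    pass

# On Python <= 3.6 the class statement below, written with
# metaclass=WatchedMeta, raised:
#   TypeError: metaclass conflict: the metaclass of a derived class must be
#   a (non-strict) subclass of the metaclasses of all its bases
# because Generic's metaclass there was GenericMeta, not plain `type`.
# The conventional workaround was a metaclass deriving from both:
class WatchedGenericMeta(WatchedMeta, type(typing.Generic)):
    pass

class Watched(typing.Generic[T], metaclass=WatchedGenericMeta):
    pass

print(type(Watched))  # <class '__main__.WatchedGenericMeta'>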
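The second workaround can be shown concretely, since the commit message describes the trick itself: typing forward as a Callable attribute means subclass definitions of forward are checked as assignments to that attribute rather than as method overrides, so mypy does not contravariantly reject narrower parameter types. A minimal sketch, assuming a stripped-down stand-in for torch.nn.modules.module.Module (the real class carries much more machinery, and Doubler is a hypothetical subclass):

from typing import Any, Callable

class Module:
    # Declared as a Callable-typed variable, not defined with `def`:
    # mypy checks subclass `forward` definitions for assignability to
    # Callable[..., Any], which accepts any parameter list, instead of
    # applying the usual contravariant override check on a method's
    # input arguments.
    forward: Callable[..., Any]

class Doubler(Module):
    # Narrower, concrete argument types; mypy would flag this as a
    # Liskov violation if `forward` were an ordinary base-class method.
    def forward(self, x: float) -> float:
        return 2.0 * x

print(Doubler().forward(3.0))  # 6.0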


58 files changed, +1048 −2141 lines
 

aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp (+20 −5)
@@ -42,14 +42,19 @@ void compute_fused_params(
 template <bool ReluFused>
 Tensor q_batch_norm2d_impl(
     Tensor qx,
-    Tensor weight,
-    Tensor bias,
+    c10::optional<Tensor> mb_weight,
+    c10::optional<Tensor> mb_bias,
     Tensor mean,
     Tensor var,
     double eps,
     double output_scale,
     int64_t output_zero_point) {

+  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
+  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
+  const auto& weight = *mb_weight;
+  const auto& bias = *mb_bias;
+
   if (qx.numel() == 0) {
     auto out = qx.clone();
     return out;
@@ -131,14 +136,20 @@ Tensor q_batch_norm2d_impl(
 template <bool ReluFused>
 Tensor q_batch_norm3d_impl(
     Tensor qx,
-    Tensor weight,
-    Tensor bias,
+    c10::optional<Tensor> mb_weight,
+    c10::optional<Tensor> mb_bias,
     Tensor mean,
     Tensor var,
     double eps,
     double output_scale,
     int64_t output_zero_point) {

+  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
+  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
+
+  const auto& weight = *mb_weight;
+  const auto& bias = *mb_bias;
+
   if (qx.numel() == 0) {
     auto out = qx.clone();
     return out;
@@ -231,8 +242,12 @@ Tensor quantized_batch_norm(
     double output_scale,
     int64_t output_zero_point) {
   Tensor qy;
+  // TODO: this should arguably support 3d as well
   qy = q_batch_norm2d_impl<false>(
-      qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+      qx,
+      weight.defined() ? c10::make_optional(weight) : c10::nullopt,
+      bias.defined() ? c10::make_optional(bias) : c10::nullopt,
+      mean, var, eps, output_scale, output_zero_point);
   return qy;
 }

aten/src/ATen/native/quantized/cpu/qnormalization.cpp (+19 −9)
@@ -123,33 +123,43 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("layer_norm", [](
       Tensor input,
       std::vector<int64_t> normalized_shape, // because IntArrayRef doesn't work
-      Tensor weight /* optional */,
-      Tensor bias /* optional */,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias,
       double eps,
       double output_scale,
       int64_t output_zero_point) {
-    return quantized_layer_norm_impl(input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
+    return quantized_layer_norm_impl(
+        input, normalized_shape,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
   m.impl("group_norm", [](
       Tensor qx,
       int64_t num_groups,
-      Tensor weight,
-      Tensor bias,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias,
       double eps,
       double output_scale,
       int64_t output_zero_point) {
     return quantized_group_norm_impl(
-        qx, num_groups, weight, bias, eps, output_scale, output_zero_point);
+        qx, num_groups,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
   m.impl("instance_norm", [](
       Tensor qx,
-      Tensor weight,
-      Tensor bias,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias,
       double eps,
       double output_scale,
       int64_t output_zero_point) {
     return quantized_instance_norm_impl(
-        qx, weight, bias, eps, output_scale, output_zero_point);
+        qx,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
 }
