
Commit bdcf320

gmagogsfm authored and facebook-github-bot committed Aug 1, 2020
Support custom exception message (pytorch#41907)
Summary:
Raise and assert used to produce a hard-coded error message, "Exception"; the user-provided error message was ignored. This PR adds support for representing the user's error message in TorchScript.

This breaks backward compatibility because we now actually need to script the user's error message, which can potentially contain unscriptable expressions. Such programs can break when scripted, but saved models continue to work.

An op count in test_mobile_optimizer.py is increased because aten::format is now needed to form the actual exception message.

This is built upon a WIP PR: pytorch#34112 by driazati

Pull Request resolved: pytorch#41907

Reviewed By: ngimel

Differential Revision: D22778301

Pulled By: gmagogsfm

fbshipit-source-id: 2b94f0db4ae9fe70c4cd03f4048e519ea96323ad
1 parent 5769b06 commit bdcf320

21 files changed: +232 -58 lines
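To make the change concrete before the per-file diffs, here is a minimal sketch (not part of the commit; the function name and message are illustrative) of what scripting now preserves:

import torch

@torch.jit.script
def checked_relu(x: torch.Tensor) -> torch.Tensor:
    if bool((x < 0).all()):
        # The user-provided message below is now scripted and preserved.
        raise ValueError("checked_relu expects at least one non-negative element")
    return torch.relu(x)

try:
    checked_relu(torch.tensor([-1.0]))
except torch.jit.Error as err:
    print(err)  # the error text now contains the custom message

Previously the caught error's text was always the hard-coded "Exception"; with this change it carries the user's message.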
 

test/expect/TestScript.test_python_frontend_py3.expect
+4 -5

@@ -3,8 +3,7 @@
   (decl (list) (option))
   (list
     (raise
-      (option
-        (apply
-          (variable (ident Exception))
-          (list (string_literal hello))
-          (list))))))
+      (apply
+        (variable (ident Exception))
+        (list (string_literal hello))
+        (list)))))

test/test_jit.py
+32 -23

@@ -12344,10 +12344,9 @@ def foo(cond):
         ''')

         cu.foo(torch.tensor(0))
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+        with self.assertRaisesRegex(torch.jit.Error, "3"):
             cu.foo(torch.tensor(1))

-        @torch.jit.script
         def foo(cond):
             a = 3
             if bool(cond):
@@ -12356,24 +12355,19 @@ def foo(cond):
                 raise ArbitraryError
             return a

-        foo(torch.tensor(0))
-        # we don't currently validate the name of the exception
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
-            foo(torch.tensor(1))
+        with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
+            torch.jit.script(foo)

-        @torch.jit.script
-        def foo_except_used():
+        def exception_as_value():
             a = Exception()
             print(a)
-            raise a

-        # a not DCEd
-        with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
-            foo_except_used()
+        with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
+            torch.jit.script(exception_as_value)

         @torch.jit.script
         def foo_no_decl_always_throws():
-            raise "Hi"
+            raise RuntimeError("Hi")

         # function that has no declared type but always throws set to None
         output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
@@ -12387,11 +12381,12 @@ def foo_decl_always_throws():
         output_type = next(foo_decl_always_throws.graph.outputs()).type()
         self.assertTrue(str(output_type) == "Tensor")

-        # We don't validate the expr following raise
-        @torch.jit.script
         def foo():
             raise 3 + 4

+        with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
+            torch.jit.script(foo)
+
         # a escapes scope
         @torch.jit.script
         def foo():
@@ -12405,6 +12400,20 @@ def foo():
             return a
         self.assertEqual(foo(), 1)

+        @torch.jit.script
+        def tuple_fn():
+            raise RuntimeError("hello", "goodbye")
+
+        with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
+            tuple_fn()
+
+        @torch.jit.script
+        def no_message():
+            raise RuntimeError
+
+        with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
+            no_message()
+
     def test_assertions(self):
         cu = torch.jit.CompilationUnit('''
         def foo(cond):
@@ -12413,7 +12422,7 @@ def foo(cond):
         ''')

         cu.foo(torch.tensor(1))
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
             cu.foo(torch.tensor(0))

         @torch.jit.script
@@ -12422,7 +12431,7 @@ def foo(cond):

         foo(torch.tensor(1))
         # we don't currently validate the name of the exception
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
             foo(torch.tensor(0))

     def test_python_op_exception(self):
@@ -13211,7 +13220,7 @@ def no_guard_ifs_added(x):
         def no_ifs_added(x):
             # type: (int) -> int
             if x < 0:
-                raise RunTimeError("hi")
+                raise RuntimeError("hi")
             return x

         self.checkScript(no_ifs_added, (1,))
@@ -13226,7 +13235,7 @@ def test_if_might(x):
                 else:
                     a = 2
             else:
-                raise RunTimeError("hi")
+                raise RuntimeError("hi")
             return a + 2

         self.checkScript(test_if_might, (1,))
@@ -13238,7 +13247,7 @@ def test_loop_no_escape(x):
             # type: (int)
             if x >= 0:
                 for i in range(x):
-                    raise RunTimeError("hi")
+                    raise RuntimeError("hi")
             else:
                 return 5
             return x + 3
@@ -13255,7 +13264,7 @@ def test_loop_exception_with_continue(x):
             i = 0
             for i in range(5):
                 if i == x:
-                    raise RunTimeError("hi")
+                    raise RuntimeError("hi")
                 else:
                     continue
             print(i)
@@ -13272,7 +13281,7 @@ def no_return_func(self):
             # type: (Tensor) -> Tensor
             output = torch.tanh(self)
             def backward(grad_output):
-                raise "Hi"
+                raise RuntimeError("Hi")
         ''')
         with self.assertRaisesRegex(RuntimeError, "does not return along all"):
             cu = torch.jit.CompilationUnit(code)
@@ -13283,7 +13292,7 @@ def test_exit_pair_reset(x):
             if x > 0:
                 a = 0
                 def backward(grad_output):
-                    raise "Hi"
+                    raise RuntimeError("Hi")
                 a = a + 1
             else:
                 return x

test/test_mobile_optimizer.py
+1 -1

@@ -112,7 +112,7 @@ def forward(self, x):
         bn_test_module = BNTestModule()
         bn_scripted_module = torch.jit.script(bn_test_module)
         bn_scripted_module.eval()
-        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 13)
+        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
         FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
             .run(str(get_forward(bn_scripted_module._c).graph))

torch/_jit_internal.py
+14 -1

@@ -126,6 +126,8 @@ def __getattr__(self, key):
                 return f_locals[key]
             elif key in f_globals:
                 return f_globals[key]
+            elif key in dir(builtins):
+                return getattr(builtins, key)

     return createResolutionCallbackFromEnv(env())

@@ -229,7 +231,13 @@ def createResolutionCallbackForClassMethods(cls):
     for fn in fns:
         captures.update(get_closure(fn))

-    return lambda key: captures.get(key, None)
+    def lookup_in_class(key):
+        if key in captures:
+            return captures[key]
+        else:
+            return getattr(builtins, key, None)
+
+    return lookup_in_class


 def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
@@ -820,3 +828,8 @@ def __enter__(self):

     def __exit__(self, *args):
         torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
+
+
+def _is_exception(obj):
+    if not inspect.isclass(obj):
+        return False
+    return issubclass(obj, Exception)
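A sketch of what the builtins fallback above enables; the function is illustrative, and it assumes the same fallback applies to the resolution path used by torch.jit.script:

import torch

def safe_div(a: int, b: int) -> int:
    if b == 0:
        # ZeroDivisionError lives only in builtins, so it is found via the
        # new getattr(builtins, key) fallback in the resolution callbacks.
        raise ZeroDivisionError("b must be non-zero")
    return a // b

scripted = torch.jit.script(safe_div)
print(scripted(6, 3))  # 3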

torch/_linalg_utils.py
+4 -1

@@ -12,8 +12,11 @@ def is_sparse(A):
     """Check if tensor A is a sparse tensor"""
     if isinstance(A, torch.Tensor):
         return A.layout == torch.sparse_coo
-    raise TypeError("expected Tensor but got %s" % (type(A).__name__))

+    error_str = "expected Tensor"
+    if not torch.jit.is_scripting():
+        error_str += " but got {}".format(type(A))
+    raise TypeError(error_str)

 def get_floating_dtype(A):
     """Return the floating point dtype of tensor A.

torch/_lobpcg.py
+7 -4

@@ -377,9 +377,8 @@ def update_converged_count(self):
                 # strict ordering of eigenpairs
                 break
             count += 1
-        assert count >= prev_count, (
-            'the number of converged eigenpairs '
-            '(was %s, got %s) cannot decrease' % (prev_count, count))
+        assert count >= prev_count, 'the number of converged eigenpairs ' \
+            '(was {}, got {}) cannot decrease'.format(prev_count, count)
         self.ivars['converged_count'] = count
         self.tvars['rerr'] = rerr
         return count
@@ -723,10 +722,14 @@ def _get_ortho(self, U, V):
             if rerr < tau_ortho:
                 break
         if m < U.shape[-1] + V.shape[-1]:
+            # TorchScript needs the class var to be assigned to a local to
+            # do optional type refinement
+            B = self.B
+            assert B is not None
             raise ValueError(
                 'Overdetermined shape of U:'
                 ' #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold'
-                .format(self.B.shape[-1], U.shape[-1], V.shape[-1]))
+                .format(B.shape[-1], U.shape[-1], V.shape[-1]))
         self.ivars['ortho_i'] = i
         self.ivars['ortho_j'] = j
         return U
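The `B = self.B` assignment above is a general TorchScript idiom: Optional type refinement applies to locals, not attributes. A standalone sketch of the pattern (hypothetical module, not from the diff):

import torch
from typing import Optional

class Solver(torch.nn.Module):
    B: Optional[torch.Tensor]

    def __init__(self):
        super().__init__()
        self.B = torch.zeros(2, 3)

    def forward(self) -> int:
        # Copy the Optional attribute into a local so the `is not None`
        # check refines its type from Optional[Tensor] to Tensor.
        b = self.B
        assert b is not None
        return b.shape[-1]

print(torch.jit.script(Solver())())  # 3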

torch/csrc/jit/frontend/ir_emitter.cpp
+44 -6

@@ -494,6 +494,10 @@ struct Environment {
          std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
         {"sorted",
          std::make_shared<BuiltinFunction>(aten::sorted, at::nullopt)},
+        // Only AssertionError is bound so that we can use it from emitAssert,
+        // all other exceptions should be resolved at the Python level
+        {"AssertionError",
+         std::make_shared<ExceptionValue>("AssertionError")},
     };
     auto it = globals.find(ident);
     if (it != globals.end()) {
@@ -1024,7 +1028,7 @@ struct to_ir {
         emitSugaredExpr(expr, 0);
       } break;
       case TK_RAISE:
-        emitRaise(Raise(stmt).range());
+        emitRaise(Raise(stmt));
        break;
      case TK_ASSERT:
        emitAssert(Assert(stmt));
@@ -1838,19 +1842,53 @@ struct to_ir {
  //   raise a
  //
  // We ignore the expression following raise
-  void emitRaise(const SourceRange& loc) {
-    const std::string exception = "Exception";
-    auto string_input = insertConstant(*graph, exception, loc);
-    graph->insert(prim::RaiseException, {string_input}, {}, loc);
+  void emitRaise(const Raise& raise) {
+    auto sv = emitSugaredExpr(raise.expr(), 1);
+    Value* error_message = nullptr;
+
+    if (auto exception_instance =
+            std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
+      // The typical case, an instance of the exception class was thrown:
+      //    raise RuntimeError("error")
+      error_message = exception_instance->getValue();
+    } else if (
+        auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
+      // A bare exception was thrown so add an empty message. e.g.
+      //    raise RuntimeError
+      error_message = insertConstant(*graph, "", raise.range());
+    } else {
+      // The raise was not followed by an exception (i.e. it was something like
+      // `raise "error"` instead of `raise RuntimeError("error")`)
+      throw ErrorReport(raise.range())
+          << "exceptions must derive from BaseException";
+    }
+
+    if (!error_message->type()->isSubtypeOf(StringType::get())) {
+      error_message = graph->insert(aten::str, {error_message});
+    }
+
+    graph->insert(prim::RaiseException, {error_message}, {}, raise.range());
    exit_blocks.insert(environment_stack->block());
  }

  // emit assserions as an if branch so that assertions will reuse the
  void emitAssert(const Assert& stmt) {
    CondValue cond_value = emitCondExpr(stmt.test());
    List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
+    // Create an `AssertionError("the_message")` call
+    auto message = (stmt.msg().present())
+        ? stmt.msg().get()
+        : StringLiteral::create(stmt.range(), "");
+    auto callee = Var::create(
+        stmt.range(), Ident::create(stmt.range(), "AssertionError"));
+    auto apply = Apply::create(
+        stmt.range(),
+        callee,
+        List<Expr>::create(stmt.range(), {message}),
+        List<Attribute>::create(stmt.range(), {}));
+
    List<Stmt> false_branch =
-        List<Stmt>::create(stmt.range(), {Raise::create(stmt.range())});
+        List<Stmt>::create(stmt.range(), {Raise::create(stmt.range(), apply)});
    emitIfElseBlocks(stmt.range(), cond_value, true_branch, false_branch);
  }
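Per emitAssert above, `assert cond, msg` is now compiled as if the user had written `if not cond: raise AssertionError(msg)`, so the runtime text carries the exception name and message (behavior pinned down by the updated test_assertions):

import torch

@torch.jit.script
def check(cond: bool):
    # Desugared into: if cond: pass / else: raise AssertionError("hi")
    assert cond, "hi"

try:
    check(False)
except torch.jit.Error as err:
    print(err)  # the message contains "AssertionError: hi"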

torch/csrc/jit/frontend/sugared_value.cpp
+2 -1

@@ -597,7 +597,8 @@ std::shared_ptr<SugaredValue> ClassValue::attr(
     Function& m,
     const std::string& field) {
   if (field != "__new__") {
-    throw ErrorReport(loc) << "Tried to lookup unknown attribute on class";
+    throw ErrorReport(loc) << "Tried to lookup unknown attribute on class "
+                           << type_->annotation_str();
   }
   return SpecialFormValue::create(prim::CreateObject);
 }

torch/csrc/jit/frontend/sugared_value.h
+47

@@ -3,6 +3,7 @@
 #include <memory>
 #include <string>

+#include <ATen/core/interned_strings.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/frontend/schema_matching.h>
@@ -704,5 +705,51 @@ struct SimpleSelf : public Self {
  private:
  ClassTypePtr classType_;
};
+
+// This is not a SimpleValue so it can not pass through the code paths that
+// expect a SimpleValue as a sugared value.
+struct TORCH_API ExceptionMessageValue : public SugaredValue {
+  explicit ExceptionMessageValue(Value* value) : value_(value) {}
+
+  std::string kind() const override {
+    return "exception message";
+  }
+
+  Value* getValue() {
+    return value_;
+  }
+
+  Value* value_;
+};
+
+struct TORCH_API ExceptionValue : public SugaredValue {
+  explicit ExceptionValue(const std::string& message) : message_(message) {}
+
+  std::string kind() const override {
+    return "exception";
+  }
+
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Function& m,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> /*attributes*/,
+      size_t /*n_binders*/) override {
+    auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc);
+    for (auto& input : inputs) {
+      auto input_str = input.value(*m.graph());
+      if (!input_str->type()->isSubtypeOf(StringType::get())) {
+        input_str =
+            emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {});
+      }
+      exception_message = emitBuiltinCall(
+          loc, *m.graph(), aten::add, {exception_message, input_str}, {});
+    }
+    return std::make_shared<ExceptionMessageValue>(exception_message);
+  }
+
+  std::string message_;
+};
+
 } // namespace jit
 } // namespace torch
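In Python terms, the loop in ExceptionValue::call assembles the message roughly as follows (a plain-Python sketch of the emitted aten::str/aten::add computation, not a real API):

def build_exception_message(name, args):
    # Start with "Name: ", stringify any non-string argument with str(),
    # and concatenate, mirroring the emitted graph calls above.
    message = name + ": "
    for arg in args:
        message += arg if isinstance(arg, str) else str(arg)
    return message

print(build_exception_message("AssertionError", ["hi"]))  # AssertionError: hi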

torch/csrc/jit/frontend/tree_views.h
+3 -7

@@ -645,16 +645,12 @@ struct Raise : public Stmt {
   explicit Raise(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_RAISE);
   }
-  Maybe<Expr> expr() const {
-    return Maybe<Expr>(subtree(0));
+  Expr expr() const {
+    return Expr(subtree(0));
   }
-  static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
+  static Raise create(const SourceRange& range, const Expr& expr) {
     return Raise(Compound::create(TK_RAISE, range, {expr}));
   }
-  static Raise create(const SourceRange& range) {
-    return Raise(
-        Compound::create(TK_RAISE, range, {Maybe<Expr>::create(range)}));
-  }
 };

 struct Assert : public Stmt {

torch/csrc/jit/passes/constant_propagation.cpp
+8

@@ -55,6 +55,14 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
       isinstance(stack, n->tys(attr::types));
     } break;
     default: {
+      const auto& the_operator = n->getOperator();
+      if (the_operator.schema().is_vararg()) {
+        // vararg schemas require the number of inputs at the top of the stack
+        // but this is broken in other places in constant prop, so disable it
+        // for now
+        return c10::nullopt;
+      }
+
       auto op = n->getOperation();
       try {
         op(&stack);

torch/csrc/jit/passes/graph_fuser.cpp
+2 -1

@@ -164,7 +164,8 @@ struct GraphFuser {
       : block_(block),
         aliasDb_(aliasDb),
         callback_(std::move(callback)),
-        kind_(kind) {}
+        kind_(kind),
+        strict_fuser_check_(false) {}

   void setInputArgLimit(size_t limit) {
     subgraph_arg_limit_ = limit;

torch/csrc/jit/python/python_sugared_value.cpp
+35

@@ -696,6 +696,36 @@ std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
   return value->call(loc, caller, inputs, attributes, n_binders);
 }

+std::shared_ptr<SugaredValue> PythonExceptionValue::call(
+    const SourceRange& loc,
+    Function& caller,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t /*n_binders*/) {
+  Value* error_message = nullptr;
+  if (inputs.size() == 0) {
+    error_message = insertConstant(*caller.graph(), "", loc);
+  } else if (inputs.size() == 1) {
+    error_message = inputs.at(0).value(*caller.graph());
+  } else {
+    std::vector<Value*> message_values;
+    message_values.reserve(inputs.size() + attributes.size());
+
+    for (auto inp : inputs) {
+      message_values.push_back(inp.value(*caller.graph()));
+    }
+    for (auto kwarg_inp : attributes) {
+      message_values.push_back(kwarg_inp.value(*caller.graph()));
+    }
+    error_message =
+        caller.graph()
+            ->insertNode(caller.graph()->createTuple(message_values))
+            ->output();
+  }
+
+  return std::make_shared<ExceptionMessageValue>(error_message);
+}
+
 bool isNamedTupleClass(const py::object& obj) {
   auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
   return PyObject_IsSubclass(obj.ptr(), tuple_type) &&
@@ -859,6 +889,11 @@ std::shared_ptr<SugaredValue> toSugaredValue(
         Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
   }

+  if (py::cast<bool>(py::module::import("torch._jit_internal")
+                         .attr("_is_exception")(obj))) {
+    return std::make_shared<PythonExceptionValue>(obj);
+  }
+
   if (py::isinstance<py::function>(obj)) {
     if (typeString(obj) == "builtin_function_or_method") {
       throw ErrorReport(loc) << "Python builtin " << py::str(obj)
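The multi-argument branch above packs the arguments into a tuple whose string form becomes the message; the updated tests pin down the observable behavior:

import torch

@torch.jit.script
def tuple_fn():
    # Two arguments take the createTuple path, so the runtime text
    # contains "hello, goodbye" (per the new test_jit.py case).
    raise RuntimeError("hello", "goodbye")

try:
    tuple_fn()
except torch.jit.Error as err:
    print(err)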

torch/csrc/jit/python/python_sugared_value.h
+17

@@ -318,5 +318,22 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
   py::object py_type_;
 };

+struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
+  explicit PythonExceptionValue(const py::object& exception_class)
+      : ExceptionValue(
+            py::str(py::getattr(exception_class, "__name__", py::str("")))) {}
+
+  std::string kind() const override {
+    return "Python exception";
+  }
+
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Function& caller,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override;
+};
+
 } // namespace jit
 } // namespace torch

torch/csrc/jit/python/python_tree_views.cpp
+2 -2

@@ -207,8 +207,8 @@ void initTreeViewBindings(PyObject* module) {
         range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
       }));
   py::class_<Raise, Stmt>(m, "Raise")
-      .def(py::init([](const SourceRange& range, Expr* expr) {
-        return Raise::create(range, wrap_maybe(range, expr));
+      .def(py::init([](const SourceRange& range, const Expr& expr) {
+        return Raise::create(range, expr);
       }));
   py::class_<Assert, Stmt>(m, "Assert")
       .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {

torch/jit/annotations.py
+1 -1

@@ -320,7 +320,7 @@ def try_ann_to_type(ann, loc):
     if inspect.isclass(ann):
         if hasattr(ann, "__torch_script_class__"):
             return ClassType(_qualified_name(ann))
-        ignored_builtin_classes = (torch.nn.Module, tuple, list)
+        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
         if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
             torch.jit._script._recursive_compile_class(ann, loc)
             return ClassType(_qualified_name(ann))

torch/jit/quantized.py
+1 -1

@@ -332,7 +332,7 @@ def get_expected_hidden_size(self, input, batch_sizes):
     def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
         # type: (Tensor, Tuple[int, int, int], str) -> None
         if hx.size() != expected_hidden_size:
-            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
+            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

     @torch.jit.script_method
     def check_forward_args(self, input, hidden, batch_sizes):

torch/nn/functional.py
+5 -1

@@ -1949,10 +1949,14 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,

     if input.dim() == 2:
         if offsets is not None:
+            type_str = "<unknown>"
+            # TODO: Remove this once script supports type() calls
+            if not torch.jit.is_scripting():
+                type_str = str(type(offsets))
             raise ValueError("if input is 2D, then offsets has to be None"
                              ", as input is treated is a mini-batch of"
                              " fixed length sequences. However, found "
-                             "offsets of type {}".format(type(offsets)))
+                             "offsets of type {}".format(type_str))
         offsets = torch.arange(0, input.numel(), input.size(1),
                                dtype=torch.long, device=input.device)
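Several files in this commit apply the same guard: message detail that TorchScript cannot compute (here a type() call) is produced only in eager mode. A minimal sketch of the idiom (hypothetical helper, for illustration):

import torch

def offsets_error_detail(offsets) -> str:
    # type() is not scriptable, so build the detailed text only when
    # running eagerly; scripted code keeps the placeholder.
    type_str = "<unknown>"
    if not torch.jit.is_scripting():
        type_str = str(type(offsets))
    return "offsets of type {}".format(type_str)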

torch/nn/modules/rnn.py
+1 -1

@@ -192,7 +192,7 @@ def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor])
     def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
                           msg: str = 'Expected hidden size {}, got {}') -> None:
         if hx.size() != expected_hidden_size:
-            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
+            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

     def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]):
         self.check_input(input, batch_sizes)

torch/nn/quantized/dynamic/modules/rnn.py
+1 -1

@@ -177,7 +177,7 @@ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size
         # type: (Tensor, Tuple[int, int, int], str) -> None
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(
-                expected_hidden_size, tuple(hx.size())))
+                expected_hidden_size, list(hx.size())))

     def check_forward_args(self, input, hidden, batch_sizes):
         # type: (Tensor, Tensor, Optional[Tensor]) -> None

torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+1 -1

@@ -754,7 +754,7 @@ def rpc_async_call_remote_raising_torchscript_in_torchscript(
             ret = fut.wait()
             return ret

-        with self.assertRaisesRegex(RuntimeError, "Exception"):
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
             ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
                 dst_worker_name
             )
