@@ -12344,10 +12344,9 @@ def foo(cond):
12344
12344
''')
12345
12345
12346
12346
cu.foo(torch.tensor(0))
12347
- with self.assertRaisesRegex(torch.jit.Error, "Exception "):
12347
+ with self.assertRaisesRegex(torch.jit.Error, "3 "):
12348
12348
cu.foo(torch.tensor(1))
12349
12349
12350
- @torch.jit.script
12351
12350
def foo(cond):
12352
12351
a = 3
12353
12352
if bool(cond):
@@ -12356,24 +12355,19 @@ def foo(cond):
12356
12355
raise ArbitraryError
12357
12356
return a
12358
12357
12359
- foo(torch.tensor(0))
12360
- # we don't currently validate the name of the exception
12361
- with self.assertRaisesRegex(torch.jit.Error, "Exception"):
12362
- foo(torch.tensor(1))
12358
+ with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
12359
+ torch.jit.script(foo)
12363
12360
12364
- @torch.jit.script
12365
- def foo_except_used():
12361
+ def exception_as_value():
12366
12362
a = Exception()
12367
12363
print(a)
12368
- raise a
12369
12364
12370
- # a not DCEd
12371
- with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
12372
- foo_except_used()
12365
+ with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
12366
+ torch.jit.script(exception_as_value)
12373
12367
12374
12368
@torch.jit.script
12375
12369
def foo_no_decl_always_throws():
12376
- raise "Hi"
12370
+ raise RuntimeError( "Hi")
12377
12371
12378
12372
# function that has no declared type but always throws set to None
12379
12373
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
@@ -12387,11 +12381,12 @@ def foo_decl_always_throws():
12387
12381
output_type = next(foo_decl_always_throws.graph.outputs()).type()
12388
12382
self.assertTrue(str(output_type) == "Tensor")
12389
12383
12390
- # We don't validate the expr following raise
12391
- @torch.jit.script
12392
12384
def foo():
12393
12385
raise 3 + 4
12394
12386
12387
+ with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
12388
+ torch.jit.script(foo)
12389
+
12395
12390
# a escapes scope
12396
12391
@torch.jit.script
12397
12392
def foo():
@@ -12405,6 +12400,20 @@ def foo():
12405
12400
return a
12406
12401
self.assertEqual(foo(), 1)
12407
12402
12403
+ @torch.jit.script
12404
+ def tuple_fn():
12405
+ raise RuntimeError("hello", "goodbye")
12406
+
12407
+ with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
12408
+ tuple_fn()
12409
+
12410
+ @torch.jit.script
12411
+ def no_message():
12412
+ raise RuntimeError
12413
+
12414
+ with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
12415
+ no_message()
12416
+
12408
12417
def test_assertions(self):
12409
12418
cu = torch.jit.CompilationUnit('''
12410
12419
def foo(cond):
@@ -12413,7 +12422,7 @@ def foo(cond):
12413
12422
''')
12414
12423
12415
12424
cu.foo(torch.tensor(1))
12416
- with self.assertRaisesRegex(torch.jit.Error, "Exception "):
12425
+ with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi "):
12417
12426
cu.foo(torch.tensor(0))
12418
12427
12419
12428
@torch.jit.script
@@ -12422,7 +12431,7 @@ def foo(cond):
12422
12431
12423
12432
foo(torch.tensor(1))
12424
12433
# we don't currently validate the name of the exception
12425
- with self.assertRaisesRegex(torch.jit.Error, "Exception "):
12434
+ with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi "):
12426
12435
foo(torch.tensor(0))
12427
12436
12428
12437
def test_python_op_exception(self):
@@ -13211,7 +13220,7 @@ def no_guard_ifs_added(x):
13211
13220
def no_ifs_added(x):
13212
13221
# type: (int) -> int
13213
13222
if x < 0:
13214
- raise RunTimeError ("hi")
13223
+ raise RuntimeError ("hi")
13215
13224
return x
13216
13225
13217
13226
self.checkScript(no_ifs_added, (1,))
@@ -13226,7 +13235,7 @@ def test_if_might(x):
13226
13235
else:
13227
13236
a = 2
13228
13237
else:
13229
- raise RunTimeError ("hi")
13238
+ raise RuntimeError ("hi")
13230
13239
return a + 2
13231
13240
13232
13241
self.checkScript(test_if_might, (1,))
@@ -13238,7 +13247,7 @@ def test_loop_no_escape(x):
13238
13247
# type: (int)
13239
13248
if x >= 0:
13240
13249
for i in range(x):
13241
- raise RunTimeError ("hi")
13250
+ raise RuntimeError ("hi")
13242
13251
else:
13243
13252
return 5
13244
13253
return x + 3
@@ -13255,7 +13264,7 @@ def test_loop_exception_with_continue(x):
13255
13264
i = 0
13256
13265
for i in range(5):
13257
13266
if i == x:
13258
- raise RunTimeError ("hi")
13267
+ raise RuntimeError ("hi")
13259
13268
else:
13260
13269
continue
13261
13270
print(i)
@@ -13272,7 +13281,7 @@ def no_return_func(self):
13272
13281
# type: (Tensor) -> Tensor
13273
13282
output = torch.tanh(self)
13274
13283
def backward(grad_output):
13275
- raise "Hi"
13284
+ raise RuntimeError( "Hi")
13276
13285
''')
13277
13286
with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13278
13287
cu = torch.jit.CompilationUnit(code)
@@ -13283,7 +13292,7 @@ def test_exit_pair_reset(x):
13283
13292
if x > 0:
13284
13293
a = 0
13285
13294
def backward(grad_output):
13286
- raise "Hi"
13295
+ raise RuntimeError( "Hi")
13287
13296
a = a + 1
13288
13297
else:
13289
13298
return x
0 commit comments