Skip to content

Commit eee7dad

Browse files
vkuzofacebook-github-bot
authored andcommittedSep 25, 2020
Add torch.do_assert, which is symbolically traceable (pytorch#45188)
Summary: Pull Request resolved: pytorch#45188 This is a symbolically traceable alternative to Python's `assert`. It should be useful to allow people who want to use FX to also be able to assert things. A bunch of TODO(before) land are inline - would love thoughts on where is the best place for this code to live, and what this function should be called (since `assert` is reserved). Test Plan: ``` python test/test_fx.py TestFX.test_symbolic_trace_assert ``` Imported from OSS Reviewed By: jamesr66a Differential Revision: D23861567 fbshipit-source-id: d9d6b9556140faccc0290eba1fabea401d7850de
1 parent 7c5436d commit eee7dad

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed
 

‎docs/source/torch.rst

+1
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,4 @@ Utilities
536536
set_deterministic
537537
is_deterministic
538538
vmap
539+
Assert

‎test/test_fx.py

+16
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,22 @@ def test_construct_root_dict(self):
603603
ref_out : torch.Tensor = linear_mod(x) + add_param
604604
self.assertEqual(out, ref_out)
605605

606+
def test_symbolic_trace_assert(self):
607+
message = "assert_foobar"
608+
609+
class AssertsTensorShape(torch.nn.Module):
610+
def forward(self, x):
611+
torch.Assert(x.shape[1] > 4, message)
612+
return x
613+
614+
m = AssertsTensorShape()
615+
# verify traceability
616+
traced = symbolic_trace(m)
617+
# verify assertion on traced model works correctly at runtime
618+
traced(torch.rand(4, 5))
619+
with self.assertRaisesRegex(AssertionError, message):
620+
traced(torch.rand(4, 3))
621+
606622

607623
if __name__ == '__main__':
608624
run_tests()

‎test/test_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -795,5 +795,13 @@ def test_fuzzer(self):
795795
x, torch.Tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
796796

797797

798+
class TestAssert(TestCase):
799+
def test_assert_true(self):
800+
# verify assertions work as expected
801+
torch.Assert(True, "foo")
802+
with self.assertRaisesRegex(AssertionError, "bar"):
803+
torch.Assert(False, "bar")
804+
805+
798806
if __name__ == '__main__':
799807
run_tests()

‎torch/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,12 @@ def compiled_with_cxx11_abi():
612612
# class usage. We add these lines here to preserve backward compatbility.
613613
quantized_lstm = torch.ops.aten.quantized_lstm
614614
quantized_gru = torch.ops.aten.quantized_gru
615+
616+
from .overrides import has_torch_function, handle_torch_function
617+
618+
def Assert(condition, message):
619+
r"""A wrapper around Python's assert which is symbolically traceable.
620+
"""
621+
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
622+
return handle_torch_function(Assert, (condition,), condition, message)
623+
assert condition, message

0 commit comments

Comments
 (0)