Commit d3b94c2

Fixed the comments

1 parent 3375f49 commit d3b94c2

File tree: 3 files changed, +11 -16 lines changed

  examples/dynamo/mutable_torchtrt_module_example.py
  py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
  tests/py/dynamo/runtime/test_mutable_torchtrt_module.py

examples/dynamo/mutable_torchtrt_module_example.py (+5 -11)
@@ -65,7 +65,7 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# Currently, saving is only when "use_python" = False in settings
+# Currently, saving is only enabled when "use_python_runtime" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

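Note: the corrected comment above implies a save/load round trip like the following sketch. The ResNet18 model and input shape are illustrative assumptions; save, load, and the use_python_runtime=False requirement come from the example file itself.

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Compile with the C++ runtime, since saving is only enabled when
# use_python_runtime=False (per the comment fixed above).
model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=False,
    immutable_weights=False,
)
inputs = (torch.randn((1, 3, 224, 224)).to("cuda"),)
mutable_module(*inputs)  # first call triggers compilation

torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
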
@@ -201,24 +201,18 @@ def forward(self, a, b, c={}):
 from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH

 model = models.resnet18(pretrained=True).eval().to("cuda")
-enabled_precisions = {torch.float}
-debug = False
-min_block_size = 1
-use_python_runtime = True

 times = []
 start = torch.cuda.Event(enable_timing=True)
 end = torch.cuda.Event(enable_timing=True)

-
 example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
-# Mark the dim0 of inputs as dynamic
 model = torch_trt.MutableTorchTensorRTModule(
     model,
-    use_python_runtime=use_python_runtime,
-    enabled_precisions=enabled_precisions,
-    debug=debug,
-    min_block_size=min_block_size,
+    use_python_runtime=True,
+    enabled_precisions={torch.float},
+    debug=True,
+    min_block_size=1,
     immutable_weights=False,
     cache_built_engines=True,
     reuse_cached_engines=True,
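
The example keeps CUDA timing events (times, start, end) as context around this hunk. As a hedged sketch, a benchmark loop using them typically completes like this; the loop body and iteration count are assumptions, only the event setup appears in the diff:

for _ in range(10):  # iteration count is an assumption
    start.record()
    model(*example_inputs)
    end.record()
    torch.cuda.synchronize()  # wait until both recorded events have completed
    times.append(start.elapsed_time(end))  # elapsed time in milliseconds
print(f"Average latency: {sum(times) / len(times):.2f} ms")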

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+4 -3)
@@ -228,20 +228,21 @@ def forward(a, b, c=0, d=0):
     seq_len = torch.export.Dim("seq_len", min=1, max=10)
     args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
     kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+    set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)
     # Later when you call the function
     forward(*(a, b), **{c:..., d:...})

-
+    Reference: https://pytorch.org/docs/stable/export.html#expressing-dynamism
     Arguments:
         args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs,
         kwargs_dynamic_shape: (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
     """
     assert isinstance(
         args_dynamic_shape, tuple
-    ), "args dynamic shape has to be a tuple"
+    ), f"args dynamic shape has to be a tuple, but got {type(args_dynamic_shape)}"
     assert isinstance(
         kwargs_dynamic_shape, dict
-    ), "args dynamic shape has to be a dictionary"
+    ), f"kwargs dynamic shape has to be a dictionary, but got {type(kwargs_dynamic_shape)}"
     self.kwarg_dynamic_shapes = kwargs_dynamic_shape
     self.arg_dynamic_shapes = args_dynamic_shape

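To make the docstring's example concrete, here is a minimal sketch of calling set_expected_dynamic_shape_range on a module. The toy forward signature and tensor shapes are assumptions; the method name and the tuple/dict hint structure come from the docstring above.

import torch
import torch_tensorrt as torch_trt

class Toy(torch.nn.Module):
    # Mirrors the docstring's forward(a, b, c=0, d=0) example.
    def forward(self, a, b, c=0, d=0):
        return a + b + c + d

mutable_module = torch_trt.MutableTorchTensorRTModule(Toy().eval().to("cuda"))

seq_len = torch.export.Dim("seq_len", min=1, max=10)
args_dynamic_shape = ({0: seq_len}, {})              # b has no dynamic shape
kwargs_dynamic_shape = {"c": {0: seq_len}, "d": {}}  # d has no dynamic shape
mutable_module.set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)

# Later calls may vary dim 0 of a and c within [1, 10].
a, b = torch.randn(4, 3).to("cuda"), torch.randn(4, 3).to("cuda")
out = mutable_module(a, b, c=torch.randn(4, 3).to("cuda"), d=torch.randn(4, 3).to("cuda"))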

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py (+2 -2)
@@ -66,11 +66,11 @@ def test_check_input_shape_dynamic():
     dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
     assertions.assertFalse(
         torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b),
-        msg=f"test_check_output_equal is not correct.",
+        msg=f"test_check_input_shape_dynamic is not correct.",
     )
     assertions.assertTrue(
         torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape),
-        msg=f"test_check_output_equal is not correct.",
+        msg=f"test_check_input_shape_dynamic is not correct.",
     )
