
Commit f183050

Added dynamic shape support to MutableTorchTensorRTModule
1 parent 1e356e4 commit f183050

File tree

2 files changed: +215 −20
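What the change enables, assembled from the new test below (a usage sketch, not official documentation; the toy MatMul module and the concrete shapes are illustrative assumptions):

    import torch
    import torch_tensorrt as torch_trt

    # Hypothetical two-input module; any nn.Module with tensor inputs works the same way.
    class MatMul(torch.nn.Module):
        def forward(self, a, b):
            return torch.matmul(a, b)

    model = MatMul().eval().cuda()
    dim = torch.export.Dim("dim", min=1, max=50)

    # One hint per input, {tensor axis: Dim}; an empty dict marks a fully static input.
    mutable_module = torch_trt.MutableTorchTensorRTModule(model)
    mutable_module.set_dynamic_shape_hint(({1: dim},), {"b": {0: dim}})

    mutable_module(torch.rand(10, 3), b=torch.rand(3, 30))      # first call compiles the engine
    mutable_module(torch.rand(10, 9), b=torch.rand(9, 30))      # in-range change: engine is reused
    mutable_module(torch.rand(10, 900), b=torch.rand(900, 30))  # out of range: recompiled with static shapes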

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+95 −20)
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,6 +194,9 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
+        self.total_dynamic_shape: Optional[dict[Any, Any]] = None

         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
@@ -203,6 +211,27 @@ def __init__(
         )
         self.init_finished = True

+    def set_dynamic_shape_hint(
+        self,
+        args_dynamic_shape: tuple[dict[Any, Any]],
+        kwargs_dynamic_shape: dict[str, Any],
+    ) -> None:
+        assert isinstance(
+            args_dynamic_shape, tuple
+        ), "args dynamic shape has to be a tuple"
+        assert isinstance(
+            kwargs_dynamic_shape, dict
+        ), "kwargs dynamic shape has to be a dictionary"
+        self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+        self.arg_dynamic_shapes = args_dynamic_shape
+        self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
+        signature = list(
+            inspect.signature(self.original_model.forward).parameters.keys()
+        )
+        for i, arg in enumerate(self.arg_dynamic_shapes):
+            self.total_dynamic_shape[signature[i]] = arg
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
     def store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape
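set_dynamic_shape_hint re-keys positional hints by the names of forward()'s parameters, so args and kwargs collapse into the single total_dynamic_shape dict that compile() later hands to torch.export. A minimal standalone sketch of that merge (the Model class and hints here are assumptions for illustration):

    import inspect
    import torch

    class Model(torch.nn.Module):
        def forward(self, a, b, c=None):
            return a

    dim = torch.export.Dim("dim", min=1, max=50)
    args_dynamic_shapes = ({1: dim},)       # hint for positional arg `a`
    kwarg_dynamic_shapes = {"b": {0: dim}}  # hint for keyword arg `b`

    # Same merge as above: start from the kwarg hints, then map each
    # positional hint onto the matching forward() parameter name.
    total = dict(kwarg_dynamic_shapes)
    params = list(inspect.signature(Model().forward).parameters)  # ['a', 'b', 'c']
    for i, hint in enumerate(args_dynamic_shapes):
        total[params[i]] = hint
    assert total == {"b": {0: dim}, "a": {1: dim}}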
@@ -295,6 +324,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self.total_dynamic_shape,
         )
         self.gm = dynamo_compile(
             self.exp_program,
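The merged hints flow into torch.export.export, whose dynamic_shapes parameter accepts a dict keyed by forward() argument names. Roughly equivalent standalone code (model and shapes assumed, continuing the MatMul sketch above):

    import torch

    class MatMul(torch.nn.Module):
        def forward(self, a, b):
            return torch.matmul(a, b)

    dim = torch.export.Dim("dim", min=1, max=50)
    exp_program = torch.export.export(
        MatMul().eval(),
        (torch.rand(10, 3),),            # example positional inputs
        kwargs={"b": torch.rand(3, 30)},
        dynamic_shapes={"a": {1: dim}, "b": {0: dim}},  # keyed by parameter name
    )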
@@ -306,14 +336,26 @@
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+        try:
+            if (
+                not self.arg_inputs
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+                )
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+                )
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self.store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning("Recompiling the engine with static shape")
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
+            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
             self.store_inputs(args, kwargs)
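The net effect: an input that falls outside the declared Dim range no longer raises to the caller; the hints are dropped and the module recompiles with static shapes. The per-dimension check that feeds this path can be exercised on its own (shapes and the Dim range are assumptions):

    import torch
    import torch_tensorrt as torch_trt

    dim = torch.export.Dim("dim", min=1, max=50)
    stored = torch.rand(10, 3)

    # Axis 0 differs (10 vs 20) but is declared dynamic and 20 lies inside [1, 50].
    assert torch_trt.MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
        stored, torch.rand(20, 3), {0: dim}
    )

    # 900 exceeds max=50: raises DynamicShapeOutOfRangeException, which
    # _validate_inputs catches to trigger the static-shape fallback above.
    try:
        torch_trt.MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
            stored, torch.rand(900, 3), {0: dim}
        )
    except Exception as e:
        print(type(e).__name__)  # DynamicShapeOutOfRangeException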

@@ -436,33 +478,66 @@ def __setattr__(self, name: str, value: Any) -> None:
     def check_inputs_equal(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule.check_inputs_equal(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
+                    return False
+        return True
+
+    @staticmethod
+    def check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
                     return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in the dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+
         return True

     @staticmethod
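Note the recursion convention: the dynamic_shapes pytree mirrors the input pytree, with dict keys indexing dicts, positions indexing lists/tuples, and an {axis: Dim} dict at each tensor leaf. A small sketch mirroring the first new unit test below (values are illustrative):

    import torch
    import torch_tensorrt as torch_trt

    dim = torch.export.Dim("dim", min=1, max=50)
    stored = {"a": torch.rand(10, 3), "b": [torch.rand(5, 5), torch.rand(2, 2)]}
    incoming = {"a": torch.rand(10, 30), "b": [torch.rand(5, 5), torch.rand(2, 2)]}
    hints = {"a": {1: dim}, "b": [{}, {}]}

    torch_trt.MutableTorchTensorRTModule.check_inputs_equal(stored, incoming)         # False
    torch_trt.MutableTorchTensorRTModule.check_inputs_equal(stored, incoming, hints)  # True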

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py (+120 −0)
@@ -36,6 +36,126 @@ def test_check_output_equal():
     )


+@pytest.mark.unit
+def test_check_input_shape_dynamic():
+    torch.manual_seed(0)
+    a = {
+        "a": torch.rand(10, 3),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    torch.manual_seed(0)
+    b = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
+    assertions.assertFalse(
+        torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b),
+        msg="test_check_input_shape_dynamic is not correct.",
+    )
+    assertions.assertTrue(
+        torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b, dynamic_shape),
+        msg="test_check_input_shape_dynamic is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_model_complex_dynamic_shape():
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b, c=None):
+            x = torch.matmul(a, b)
+            x = torch.matmul(c["a"], c["b"][0].T)
+            x = 2 * c["b"][1]
+            return x
+
+    model = Model().eval().cuda()
+    inputs = [torch.rand(10, 3)]
+    kwargs = {
+        "b": torch.rand(3, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 3)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dim2 = torch.export.Dim("dim2", min=1, max=50)
+    args_dynamic_shapes = ({1: dim},)
+    kwarg_dynamic_shapes = {
+        "b": {0: dim},
+        "c": {"a": {}, "b": [{}, {1: dim2}]},
+    }
+    # Export the model first with custom dynamic shape constraints
+    # exp_program = torch.export.export(model, tuple(inputs), kwargs=k
+    trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
+    trt_gm.set_dynamic_shape_hint(args_dynamic_shapes, kwarg_dynamic_shapes)
+    # Run inference
+    trt_gm(*inputs, **kwargs)
+
+    inputs_2 = [torch.rand(10, 9)]
+    kwargs_2 = {
+        "b": torch.rand(9, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_2)
+    trt_gm._validate_inputs(*inputs_2, **kwargs_2)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_2, **kwargs_2)
+
+    # Change does not align with the dynamic shape hint
+    inputs_3 = [torch.rand(7, 9)]
+    kwargs_3 = {
+        "b": torch.rand(9, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_3)
+    trt_gm._validate_inputs(*inputs_3, **kwargs_3)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_3, **kwargs_3)
+
+    # Stored input is changed (inputs first dimension is 7)
+    inputs_4 = [torch.rand(7, 20)]
+    kwargs_4 = {
+        "b": torch.rand(20, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_4)
+    trt_gm._validate_inputs(*inputs_4, **kwargs_4)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_4, **kwargs_4)
+
+    # Change outside of the dynamic range limit
+    inputs_5 = [torch.rand(7, 900)]
+    kwargs_5 = {
+        "b": torch.rand(900, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_5)
+    trt_gm._validate_inputs(*inputs_5, **kwargs_5)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_5, **kwargs_5)
+
+
 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
     "TorchScript Frontend is not available",
