
Commit a59d92d

Fixed the issue in comments
1 parent ec2d674 commit a59d92d

4 files changed: +234 -78 lines changed


examples/dynamo/mutable_torchtrt_module_example.py (+44 -3)
@@ -14,6 +14,7 @@
 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
 2. Save a Mutable Torch TensorRT Module
 3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """
 
 import numpy as np
@@ -63,16 +64,14 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-# Currently, saving is only enabled for C++ runtime, not python runtime.
+# Currently, saving is only supported when "use_python" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
 
 # %%
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
-
 from diffusers import DiffusionPipeline
 
 with torch.no_grad():
@@ -111,3 +110,45 @@
 # Refit triggered
 image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
 image.save("./with_LoRA_mutable.jpg")
+
+
+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        print(c["b"][0])
+        x = 2 * c["b"]
+        return x
+
+
+device = "cuda:0"
+model = Model().eval().to(device)
+inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
+}
+dim_0 = torch.export.Dim("dim", min=1, max=50)
+dim_1 = torch.export.Dim("dim", min=1, max=50)
+dim_2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
+kwarg_dynamic_shapes = {
+    "c": {"a": {}, "b": {0: dim_2}},
+}
+# Export the model first with custom dynamic shape constraints
+model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Compile
+model(*inputs, **kwargs)
+# Change input shape
+inputs_2 = (torch.rand(10, 5).to(device), torch.rand(5, 30).to(device))
+kwargs_2 = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
+}
+# Run without recompiling
+model(*inputs_2, **kwargs_2)
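
A quick illustration of the other half of this feature (a sketch, not part of the committed example): the dims above are declared with max=50, so if a later call exceeds that range, the updated _validate_inputs in _MutableTorchTensorRTModule.py below is expected to catch DynamicShapeOutOfRangeException, log a warning, and recompile the module with static shapes instead of raising. The name inputs_3 is hypothetical and reuses model, kwargs, and device from the example above.

    # Hypothetical out-of-range call; not part of the example above.
    # The contraction dimension (declared dynamic with max=50) becomes 60, so the module
    # is expected to fall back to a static recompile for these shapes rather than error out.
    inputs_3 = (torch.rand(10, 60).to(device), torch.rand(60, 30).to(device))
    model(*inputs_3, **kwargs)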

py/torch_tensorrt/dynamo/_refit.py (+6 -3)
@@ -395,9 +395,12 @@ def refit_module_weights(
            try:
                weight_name_map = compiled_submodule.weight_name_map
            except AttributeError:
-                logger.warning(
-                    "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-                )
+                if not isinstance(
+                    compiled_submodule, torch.fx.graph_module.GraphModule
+                ):
+                    logger.warning(
+                        "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                    )
            if not weight_name_map:
                use_weight_map_cache = False
                logger.warning(
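
In plain terms, the change above stops the "old version of Torch-TensorRT" warning from firing for submodules that are ordinary torch.fx GraphModules. A rough sketch of the intent, with a hypothetical helper name (the assumption being that graph-break submodules which keep running in PyTorch never carry a weight_name_map, so its absence there is expected rather than a sign of an old build):

    import torch

    def _should_warn_missing_weight_map(compiled_submodule: object) -> bool:
        # Hypothetical helper mirroring the new guard: only non-GraphModule (i.e. TRT engine)
        # submodules that lack weight_name_map suggest a module built by an older release.
        return not isinstance(compiled_submodule, torch.fx.graph_module.GraphModule)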

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+115 -37)
@@ -196,12 +196,11 @@ def __init__(
         }
         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
-        self.total_dynamic_shape: Optional[dict[Any, Any]] = None
 
         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()
 
         cls = self.__class__
         self.__class__ = type(
@@ -211,11 +210,31 @@ def __init__(
         )
         self.init_finished = True
 
-    def set_dynamic_shape_hint(
+    def set_expected_dynamic_shape_range(
         self,
         args_dynamic_shape: tuple[dict[Any, Any]],
         kwargs_dynamic_shape: dict[str, Any],
     ) -> None:
+        """
+        Set the dynamic shape range. The shape hint should EXACTLY follow the arg_inputs and kwarg_inputs passed to the forward function
+        and should not omit any entries. If a dynamic shape is not required for an input, an empty dictionary should be given
+        as the shape hint for that input.
+
+        Example:
+            def forward(a, b, c=0, d=0):
+                pass
+
+            seq_len = torch.export.Dim("seq_len", min=1, max=10)
+            args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
+            kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+            # Later when you call the function
+            forward(*(a, b), **{"c": ..., "d": ...})
+
+
+        Arguments:
+            args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs
+            kwargs_dynamic_shape (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+        """
         assert isinstance(
             args_dynamic_shape, tuple
         ), "args dynamic shape has to be a tuple"
@@ -224,19 +243,31 @@ def set_dynamic_shape_hint(
         ), "args dynamic shape has to be a dictionary"
         self.kwarg_dynamic_shapes = kwargs_dynamic_shape
         self.arg_dynamic_shapes = args_dynamic_shape
-        self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
-        signature = list(
-            inspect.signature(self.original_model.forward).parameters.keys()
-        )
-        for i, arg in enumerate(self.arg_dynamic_shapes):
-            self.total_dynamic_shape[signature[i]] = arg
-        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
 
         # Clear cached inputs
         self.arg_inputs = tuple()
         self.kwarg_inputs = {}
 
-    def store_state_dict_metadata(self) -> None:
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+        if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
+            return None
+        total_dynamic_shape = {}
+        if self.arg_dynamic_shapes:
+            signature = list(
+                inspect.signature(self.original_model.forward).parameters.keys()
+            )
+            for i, arg in enumerate(self.arg_dynamic_shapes):
+                total_dynamic_shape[signature[i]] = arg
+
+        if self.kwarg_dynamic_shapes:
+            for kwargs, kwargs_dynamic_shape in self.kwarg_dynamic_shapes.items():
+                total_dynamic_shape[kwargs] = kwargs_dynamic_shape
+
+        return total_dynamic_shape
+
+    def _store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape
@@ -328,7 +359,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
-            dynamic_shapes=self.total_dynamic_shape,
+            dynamic_shapes=self._get_total_dynamic_shapes(),
         )
         self.gm = dynamo_compile(
             self.exp_program,
@@ -340,40 +371,75 @@
         torch.cuda.empty_cache()
 
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
+
+        if not self.arg_inputs:
+            logger.info("First time compilation initiated. This may take some time.")
+            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+            self._store_inputs(args, kwargs)
+            if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                if not self._validates_dynamic_hints():
+                    logger.warning(
+                        "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                    )
+                    self.arg_dynamic_shapes = None
+                    self.kwarg_dynamic_shapes = None
+            return
+
+        # If input does not equal or does not fall into dynamic shape range, recompile the engine
         try:
-            if (
-                not self.arg_inputs
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
-                )
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
-                )
+            if not MutableTorchTensorRTModule._check_inputs_shape(
+                self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+            ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
             ):
                 logger.info("Input change detected.")
                 self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-                self.store_inputs(args, kwargs)
+                self._store_inputs(args, kwargs)
        except DynamicShapeOutOfRangeException as e:
            logger.info("Input change detected.")
            logger.warning(e)
-            logger.warning("Recompiling the engine with static shape")
+            logger.warning(
+                "Provided inputs are outside the set expected shape range, recompiling module for the provided input shapes (static)"
+            )
            self.arg_dynamic_shapes = None
            self.kwarg_dynamic_shapes = None
-            self.total_dynamic_shape = None
            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-            self.store_inputs(args, kwargs)
+            self._store_inputs(args, kwargs)
 
-    def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+    def _validates_dynamic_hints(self) -> bool:
+        if self.arg_dynamic_shapes is None:
+            if self.arg_inputs:
+                logger.warning("arg_dynamic_shape is not provided!")
+        else:
+            if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                logger.warning(
+                    f"Warning: The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                )
+                return False
+
+        if self.kwarg_dynamic_shapes is None:
+            if self.kwarg_inputs:
+                logger.warning("kwarg_dynamic_shape is not provided!")
+        else:
+            if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                logger.warning(
+                    f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}!"
+                )
+                return False
+
+        return True
+
+    def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
         self.arg_inputs = arg_inputs
         self.kwarg_inputs = kwarg_inputs
 
     @staticmethod
-    def process_kwarg_inputs(inputs: Any) -> Any:
+    def _process_kwarg_inputs(inputs: Any) -> Any:
         # Process kwarg inputs to be acceptable for Torch-TensorRT
         if isinstance(inputs, dict):
             # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
             return {
-                k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                 for k, v in inputs.items()
                 if (v is not None and not isinstance(v, bool))
             }
@@ -384,7 +450,10 @@ def process_kwarg_inputs(inputs: Any) -> Any:
         elif isinstance(inputs, (list, tuple)):
             if None not in inputs:
                 return type(inputs)(
-                    [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                    [
+                        MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                        for v in inputs
+                    ]
                 )
 
         raise ValueError(
@@ -394,7 +463,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 1: Check whether the input shape has changed
-        kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+        kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
         self._validate_inputs(*args, **kwargs)
 
         # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -406,7 +475,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
             self.compile()
-            self.store_state_dict_metadata()
+            self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)
 
         elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -417,7 +486,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
                 self.compile()
-                self.store_state_dict_metadata()
+                self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)
 
         result = self.gm(*args, **kwargs)
@@ -427,7 +496,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
     def to(self, device: str) -> None:
         logger.warning("Original PyTorch model is moved. CPU offload may failed.")
-        self.orignial_model.to(device)
+        self.original_model.to(device)
 
     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
@@ -479,7 +548,7 @@ def __setattr__(self, name: str, value: Any) -> None:
             object.__setattr__(self, name, value)
 
     @staticmethod
-    def check_inputs_equal(
+    def _check_inputs_shape(
         input1: Any,
         input2: Any,
         dynamic_shapes: Any = None,
@@ -495,10 +564,13 @@ def check_inputs_equal(
                 return False
             elif isinstance(a, torch.Tensor) and a.shape != b.shape:
                 if dynamic_shapes is None:
+                    logger.warning(
+                        "Dynamic shape is not properly set but the input shape is changed!"
+                    )
                     return False
                 else:
                     tensor_dynamic_shape = dynamic_shapes[i]
-                    if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                    if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                         a, b, tensor_dynamic_shape
                     ):
                         return False
@@ -513,28 +585,34 @@
                 return False
             elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
                 if dynamic_shapes is None:
+                    logger.warning(
+                        "Dynamic shape is not properly set but the input shape is changed!"
+                    )
                     return False
                 else:
                     tensor_dynamic_shape = dynamic_shapes[ka]
-                    if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                    if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                         va, vb, tensor_dynamic_shape
                     ):
                         return False
             elif isinstance(
                 va, (list, tuple, dict)
-            ) and not MutableTorchTensorRTModule.check_inputs_equal(
+            ) and not MutableTorchTensorRTModule._check_inputs_shape(
                 va, vb, dynamic_shapes[ka] if dynamic_shapes else None
             ):
                 return False
         return True
 
     @staticmethod
-    def check_tensor_shapes_with_dynamic_shapes(
+    def _check_tensor_shapes_with_dynamic_shapes(
         t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any]
     ) -> bool:
         for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
             if axis_0 != axis_1:
                 if i not in dynamic_shape:
+                    logger.warning(
+                        "Dynamic shape does not include the axis on which input changes!"
+                    )
                     return False
                 dyn = dynamic_shape[i]
                 if axis_1 > dyn.max or axis_1 < dyn.min:
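
The docstring added to set_expected_dynamic_shape_range above shows the expected hint structure against a stub forward. A self-contained sketch of the same pattern follows; the ToyModel, tensor shapes, and CUDA device below are invented for illustration and are not taken from this commit. Every positional argument gets one entry in the tuple and every keyword argument gets one entry in the dict, with {} marking an input that stays static.

    import torch
    import torch_tensorrt as torch_trt

    class ToyModel(torch.nn.Module):
        def forward(self, a, b, c=None, d=None):
            # Only `a` and `c` are allowed to vary along dim 0 in this sketch.
            return a.sum() + b.sum() + c.sum() + d.sum()

    seq_len = torch.export.Dim("seq_len", min=1, max=10)
    args_dynamic_shape = ({0: seq_len}, {})  # `b` stays static
    kwargs_dynamic_shape = {"c": {0: seq_len}, "d": {}}  # `d` stays static

    model = torch_trt.MutableTorchTensorRTModule(ToyModel().eval().cuda(), min_block_size=1)
    model.set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)

    a = torch.rand(4, 8).cuda()
    b = torch.rand(4, 8).cuda()
    model(a, b, c=torch.rand(4, 8).cuda(), d=torch.rand(4, 8).cuda())  # first call compiles
    model(torch.rand(7, 8).cuda(), b, c=torch.rand(7, 8).cuda(), d=torch.rand(4, 8).cuda())  # within range, no recompile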
