
Commit a8e0b48

Fixed the issue in comments
1 parent ec2d674 commit a8e0b48

5 files changed: +246 −85 lines

examples/dynamo/mutable_torchtrt_module_example.py

+44 −3
@@ -14,6 +14,7 @@
 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
 2. Save a Mutable Torch TensorRT Module
 3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """

 import numpy as np
@@ -63,16 +64,14 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# Currently, saving is only enabled for C++ runtime, not python runtime.
+# Currently, saving is only supported when "use_python_runtime" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

 # %%
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
-
 from diffusers import DiffusionPipeline

 with torch.no_grad():
@@ -111,3 +110,45 @@
 # Refit triggered
 image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
 image.save("./with_LoRA_mutable.jpg")
+
+
+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        print(c["b"][0])
+        x = 2 * c["b"]
+        return x
+
+
+device = "cuda:0"
+model = Model().eval().to(device)
+inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
+}
+dim_0 = torch.export.Dim("dim", min=1, max=50)
+dim_1 = torch.export.Dim("dim", min=1, max=50)
+dim_2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
+kwarg_dynamic_shapes = {
+    "c": {"a": {}, "b": {0: dim_2}},
+}
+# Export the model first with custom dynamic shape constraints
+model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Compile
+model(*inputs, **kwargs)
+# Change input shape
+inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
+kwargs_2 = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
+}
+# Run without recompiling
+model(*inputs_2, **kwargs_2)
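
Note on the updated comment in this file: saving a MutableTorchTensorRTModule is tied to the C++ runtime. Below is a minimal sketch of that workflow, not taken from this commit; the use_python_runtime setting, the throwaway Linear model, and the shapes are assumptions for illustration only.

import torch
import torch_tensorrt as torch_trt

# Sketch: keep the C++ runtime (use_python_runtime=False is assumed here,
# since saving is not supported with the Python runtime).
model = torch.nn.Linear(8, 8).eval().to("cuda:0")
mutable_module = torch_trt.MutableTorchTensorRTModule(
    model, use_python_runtime=False, min_block_size=1
)
mutable_module(torch.rand(2, 8).to("cuda:0"))  # first call triggers compilation
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reloaded = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")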

py/torch_tensorrt/dynamo/_refit.py

+6 −3
@@ -395,9 +395,12 @@ def refit_module_weights(
         try:
             weight_name_map = compiled_submodule.weight_name_map
         except AttributeError:
-            logger.warning(
-                "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-            )
+            if not isinstance(
+                compiled_submodule, torch.fx.graph_module.GraphModule
+            ):
+                logger.warning(
+                    "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                )
         if not weight_name_map:
             use_weight_map_cache = False
             logger.warning(
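
The change above suppresses the "old version" warning when the submodule missing weight_name_map is a plain torch.fx GraphModule (for example, a graph-break segment that stayed in PyTorch and never had a weight map). A self-contained sketch of the same guard, using a hypothetical helper name:

import logging
from typing import Any, Dict, Optional

import torch

logger = logging.getLogger(__name__)


def get_weight_name_map(compiled_submodule: torch.nn.Module) -> Optional[Dict[str, Any]]:
    # Hypothetical helper mirroring the guarded lookup in refit_module_weights:
    # only warn about an outdated compilation when the attribute is missing on a
    # module that is *not* a plain GraphModule, since graph-break segments never
    # carry a weight map and warning for them would be noise.
    try:
        return compiled_submodule.weight_name_map
    except AttributeError:
        if not isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
            logger.warning(
                "The module was compiled with an old version of Torch-TensorRT. "
                "Rebuilding the weight map."
            )
        return None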

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+12 −7
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+        weight_name: str,
+        np_map: dict[str, Any],
+        state_dict: dict[str, Any],
+        device: torch.device,
     ) -> str:
         """
         We need to build map from engine weight name to state_dict weight name.
@@ -385,19 +388,21 @@ def find_weight(
         np_map: the map from weight name to np values in INetworkDefinition
         state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
         for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                 del state_dict[sd_w_name]
                 return sd_w_name
         return ""

     @staticmethod
     def check_weight_equal(
-        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+        sd_weight: torch.tensor,
+        network_weight: Union[torch.Tensor, np.ndarray],
+        device: torch.device,
     ) -> Any:
         if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).cuda()
+            network_weight = torch.from_numpy(network_weight).to(device)
         try:
             return sd_weight.shape == network_weight.shape and torch.all(
                 torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name]
+                sd[sd_weight_name], np_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd
+                    engine_weight_name, np_map, sd, torch_device
                 )
                 if (
                     weight_name_map[engine_weight_name] != ""
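
These hunks thread an explicit device through the weight comparison instead of hard-coding .cuda(), so the weight-name map can be built against whatever device the interpreter targets. A standalone sketch of the device-aware check (the 0.01 tolerance and the shape comparison come from the diff above; the function name and the CPU usage example are illustrative):

from typing import Union

import numpy as np
import torch


def weights_match(
    sd_weight: torch.Tensor,
    network_weight: Union[torch.Tensor, np.ndarray],
    device: torch.device,
) -> bool:
    # Illustrative stand-in for TRTInterpreter.check_weight_equal: move the
    # engine-side weight to the caller-supplied device (instead of assuming
    # CUDA) before comparing shapes and values against the state_dict weight.
    if not isinstance(network_weight, torch.Tensor):
        network_weight = torch.from_numpy(network_weight).to(device)
    return bool(
        sd_weight.shape == network_weight.shape
        and torch.all(torch.abs(sd_weight.to(device) - network_weight) < 0.01)
    )


# With an explicit device, the same check also works on a CPU-only build.
cpu = torch.device("cpu")
print(weights_match(torch.ones(3), np.ones(3, dtype=np.float32), cpu))  # True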
