Skip to content

Commit 0ca42f3

Browse files
committed
Fixed the issue in comments
1 parent ec2d674 commit 0ca42f3

File tree

3 files changed

+192
-113
lines changed

3 files changed

+192
-113
lines changed

examples/dynamo/mutable_torchtrt_module_example.py

+102-67
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
1515
2. Save a Mutable Torch TensorRT Module
1616
3. Integration with Huggingface pipeline in LoRA use case
17+
4. Usage of dynamic shape with Mutable Torch TensorRT Module
1718
"""
1819

1920
import numpy as np
@@ -25,89 +26,123 @@
2526
torch.manual_seed(5)
2627
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
2728

28-
# %%
29-
# Initialize the Mutable Torch TensorRT Module with settings.
30-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
31-
settings = {
32-
"use_python": False,
33-
"enabled_precisions": {torch.float32},
34-
"immutable_weights": False,
35-
}
29+
# # %%
30+
# # Initialize the Mutable Torch TensorRT Module with settings.
31+
# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
32+
# settings = {
33+
# "use_python": False,
34+
# "enabled_precisions": {torch.float32},
35+
# "immutable_weights": False,
36+
# }
3637

37-
model = models.resnet18(pretrained=True).eval().to("cuda")
38-
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
39-
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
40-
mutable_module(*inputs)
38+
# model = models.resnet18(pretrained=True).eval().to("cuda")
39+
# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
40+
# # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
41+
# mutable_module(*inputs)
4142

42-
# %%
43-
# Make modifications to the mutable module.
44-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
43+
# # %%
44+
# # Make modifications to the mutable module.
45+
# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4546

46-
# %%
47-
# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
48-
model2 = models.resnet18(pretrained=False).eval().to("cuda")
49-
mutable_module.load_state_dict(model2.state_dict())
47+
# # %%
48+
# # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
49+
# model2 = models.resnet18(pretrained=False).eval().to("cuda")
50+
# mutable_module.load_state_dict(model2.state_dict())
5051

5152

52-
# Check the output
53-
# The refit happens while you call the mutable module again.
54-
expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
55-
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
56-
assert torch.allclose(
57-
expected_output, refitted_output, 1e-2, 1e-2
58-
), "Refit Result is not correct. Refit failed"
53+
# # Check the output
54+
# # The refit happens while you call the mutable module again.
55+
# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
56+
# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
57+
# assert torch.allclose(
58+
# expected_output, refitted_output, 1e-2, 1e-2
59+
# ), "Refit Result is not correct. Refit failed"
5960

60-
print("Refit successfully!")
61+
# print("Refit successfully!")
6162

62-
# %%
63-
# Saving Mutable Torch TensorRT Module
64-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63+
# # %%
64+
# # Saving Mutable Torch TensorRT Module
65+
# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6566

66-
# Currently, saving is only enabled for C++ runtime, not python runtime.
67-
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
68-
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
67+
# # Currently, saving is only enabled when "use_python" = False in settings
68+
# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
69+
# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
6970

70-
# %%
71-
# Stable Diffusion with Huggingface
72-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
71+
# # %%
72+
# # Stable Diffusion with Huggingface
73+
# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7374

74-
# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
75+
# # The LoRA checkpoint is from https://civitai.com/models/12597/moxin
7576

76-
from diffusers import DiffusionPipeline
77+
# from diffusers import DiffusionPipeline
7778

78-
with torch.no_grad():
79-
settings = {
80-
"use_python_runtime": True,
81-
"enabled_precisions": {torch.float16},
82-
"debug": True,
83-
"immutable_weights": False,
84-
}
79+
# with torch.no_grad():
80+
# settings = {
81+
# "use_python_runtime": True,
82+
# "enabled_precisions": {torch.float16},
83+
# "debug": True,
84+
# "immutable_weights": False,
85+
# }
8586

86-
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
87-
device = "cuda:0"
87+
# model_id = "stabilityai/stable-diffusion-xl-base-1.0"
88+
# device = "cuda:0"
8889

89-
prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
90-
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
90+
# prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
91+
# negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
9192

92-
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
93-
pipe.to(device)
93+
# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
94+
# pipe.to(device)
9495

95-
# The only extra line you need
96-
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
96+
# # The only extra line you need
97+
# pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
9798

98-
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
99-
image.save("./without_LoRA_mutable.jpg")
99+
# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
100+
# image.save("./without_LoRA_mutable.jpg")
100101

101-
# Standard Huggingface LoRA loading procedure
102-
pipe.load_lora_weights(
103-
"stablediffusionapi/load_lora_embeddings",
104-
weight_name="all-disney-princess-xl-lo.safetensors",
105-
adapter_name="lora1",
106-
)
107-
pipe.set_adapters(["lora1"], adapter_weights=[1])
108-
pipe.fuse_lora()
109-
pipe.unload_lora_weights()
102+
# # Standard Huggingface LoRA loading procedure
103+
# pipe.load_lora_weights(
104+
# "stablediffusionapi/load_lora_embeddings",
105+
# weight_name="all-disney-princess-xl-lo.safetensors",
106+
# adapter_name="lora1",
107+
# )
108+
# pipe.set_adapters(["lora1"], adapter_weights=[1])
109+
# pipe.fuse_lora()
110+
# pipe.unload_lora_weights()
110111

111-
# Refit triggered
112-
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
113-
image.save("./with_LoRA_mutable.jpg")
112+
# # Refit triggered
113+
# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
114+
# image.save("./with_LoRA_mutable.jpg")
115+
116+
117+
# %%
118+
# Use Mutable Torch TensorRT module with dynamic shape
119+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
120+
class Model(torch.nn.Module):
    """Toy module used to demonstrate dynamic shapes with positional and keyword inputs.

    ``forward`` takes two positional tensors plus a keyword argument ``c`` that
    is a dict holding two more tensors under the keys ``"a"`` and ``"b"``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a, b, c=None):
        """Return ``2 * c["b"][1]``.

        Args:
            a: Left operand of the first matrix product.
            b: Right operand of the first matrix product.
            c: Dict with tensor entries ``"a"`` and ``"b"``. Defaults to an
                empty dict, in which case the key lookups below raise
                ``KeyError``.

        Returns:
            Row 1 of ``c["b"]`` multiplied by 2.
        """
        # Use a None sentinel rather than a mutable ``{}`` default, which would
        # be a single dict object shared by every call.
        if c is None:
            c = {}
        # The two products below are overwritten and never returned; they are
        # kept so that every input takes part in an operation and the operand
        # shapes must be mutually compatible.
        x = torch.matmul(a, b)
        x = torch.matmul(c["a"], c["b"].T)
        x = 2 * c["b"][1]
        return x
129+
130+
131+
model = Model().eval().cuda()
# Create the sample inputs on the GPU so they live on the same device as the
# model (the ResNet inputs earlier in this file are moved to "cuda" as well).
inputs = (torch.rand(10, 3).to("cuda"), torch.rand(3, 30).to("cuda"))
kwargs = {
    "c": {
        "a": torch.rand(10, 30).to("cuda"),
        "b": torch.rand(10, 30).to("cuda"),
    },
}

# Symbolic dimensions, each allowed to range from 1 to 50.
dim = torch.export.Dim("dim", min=1, max=50)
dim2 = torch.export.Dim("dim2", min=1, max=50)
# The same ``dim`` is used for axis 1 of the first positional input and axis 0
# of the second, i.e. the two axes contracted by ``torch.matmul(a, b)``.
args_dynamic_shapes = ({1: dim}, {0: dim})
# The structure mirrors ``kwargs``. The empty dict for "a" marks no axis of
# that tensor as dynamic; only axis 0 of "b" is.
kwarg_dynamic_shapes = {
    "c": {"a": {}, "b": {0: dim2}},
}
# Wrap the model and register the expected dynamic shape ranges before the
# first call; no explicit torch.export step is needed here.
trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Run inference (compilation happens on the first call to the mutable module).
trt_gm(*inputs, **kwargs)

0 commit comments

Comments
 (0)