|
14 | 14 | 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
|
15 | 15 | 2. Save a Mutable Torch TensorRT Module
|
16 | 16 | 3. Integration with Huggingface pipeline in LoRA use case
|
| 17 | +4. Usage of dynamic shape with Mutable Torch TensorRT Module |
17 | 18 | """
|
18 | 19 |
|
19 | 20 | import numpy as np
|
|
25 | 26 | torch.manual_seed(5)
|
26 | 27 | inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
|
27 | 28 |
|
28 |
| -# %% |
29 |
| -# Initialize the Mutable Torch TensorRT Module with settings. |
30 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
31 |
| -settings = { |
32 |
| - "use_python": False, |
33 |
| - "enabled_precisions": {torch.float32}, |
34 |
| - "immutable_weights": False, |
35 |
| -} |
| 29 | +# # %% |
| 30 | +# # Initialize the Mutable Torch TensorRT Module with settings. |
| 31 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 32 | +# settings = { |
| 33 | +# "use_python": False, |
| 34 | +# "enabled_precisions": {torch.float32}, |
| 35 | +# "immutable_weights": False, |
| 36 | +# } |
36 | 37 |
|
37 |
| -model = models.resnet18(pretrained=True).eval().to("cuda") |
38 |
| -mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
39 |
| -# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. |
40 |
| -mutable_module(*inputs) |
| 38 | +# model = models.resnet18(pretrained=True).eval().to("cuda") |
| 39 | +# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
| 40 | +# # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. |
| 41 | +# mutable_module(*inputs) |
41 | 42 |
|
42 |
| -# %% |
43 |
| -# Make modifications to the mutable module. |
44 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 43 | +# # %% |
| 44 | +# # Make modifications to the mutable module. |
| 45 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
45 | 46 |
|
46 |
| -# %% |
47 |
| -# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
48 |
| -model2 = models.resnet18(pretrained=False).eval().to("cuda") |
49 |
| -mutable_module.load_state_dict(model2.state_dict()) |
| 47 | +# # %% |
| 48 | +# # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
| 49 | +# model2 = models.resnet18(pretrained=False).eval().to("cuda") |
| 50 | +# mutable_module.load_state_dict(model2.state_dict()) |
50 | 51 |
|
51 | 52 |
|
52 |
| -# Check the output |
53 |
| -# The refit happens while you call the mutable module again. |
54 |
| -expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
55 |
| -for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
56 |
| - assert torch.allclose( |
57 |
| - expected_output, refitted_output, 1e-2, 1e-2 |
58 |
| - ), "Refit Result is not correct. Refit failed" |
| 53 | +# # Check the output |
| 54 | +# # The refit happens while you call the mutable module again. |
| 55 | +# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
| 56 | +# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
| 57 | +# assert torch.allclose( |
| 58 | +# expected_output, refitted_output, 1e-2, 1e-2 |
| 59 | +# ), "Refit Result is not correct. Refit failed" |
59 | 60 |
|
60 |
| -print("Refit successfully!") |
| 61 | +# print("Refit successfully!") |
61 | 62 |
|
62 |
| -# %% |
63 |
| -# Saving Mutable Torch TensorRT Module |
64 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 63 | +# # %% |
| 64 | +# # Saving Mutable Torch TensorRT Module |
| 65 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
65 | 66 |
|
66 |
| -# Currently, saving is only enabled for C++ runtime, not python runtime. |
67 |
| -torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
68 |
| -reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
| 67 | +# # Currently, saving is only when "use_python" = False in settings |
| 68 | +# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
| 69 | +# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
69 | 70 |
|
70 |
| -# %% |
71 |
| -# Stable Diffusion with Huggingface |
72 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 71 | +# # %% |
| 72 | +# # Stable Diffusion with Huggingface |
| 73 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
73 | 74 |
|
74 |
| -# The LoRA checkpoint is from https://civitai.com/models/12597/moxin |
| 75 | +# # The LoRA checkpoint is from https://civitai.com/models/12597/moxin |
75 | 76 |
|
76 |
| -from diffusers import DiffusionPipeline |
| 77 | +# from diffusers import DiffusionPipeline |
77 | 78 |
|
78 |
| -with torch.no_grad(): |
79 |
| - settings = { |
80 |
| - "use_python_runtime": True, |
81 |
| - "enabled_precisions": {torch.float16}, |
82 |
| - "debug": True, |
83 |
| - "immutable_weights": False, |
84 |
| - } |
| 79 | +# with torch.no_grad(): |
| 80 | +# settings = { |
| 81 | +# "use_python_runtime": True, |
| 82 | +# "enabled_precisions": {torch.float16}, |
| 83 | +# "debug": True, |
| 84 | +# "immutable_weights": False, |
| 85 | +# } |
85 | 86 |
|
86 |
| - model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
87 |
| - device = "cuda:0" |
| 87 | +# model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
| 88 | +# device = "cuda:0" |
88 | 89 |
|
89 |
| - prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
90 |
| - negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
| 90 | +# prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
| 91 | +# negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
91 | 92 |
|
92 |
| - pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
93 |
| - pipe.to(device) |
| 93 | +# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
| 94 | +# pipe.to(device) |
94 | 95 |
|
95 |
| - # The only extra line you need |
96 |
| - pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
| 96 | +# # The only extra line you need |
| 97 | +# pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
97 | 98 |
|
98 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
99 |
| - image.save("./without_LoRA_mutable.jpg") |
| 99 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 100 | +# image.save("./without_LoRA_mutable.jpg") |
100 | 101 |
|
101 |
| - # Standard Huggingface LoRA loading procedure |
102 |
| - pipe.load_lora_weights( |
103 |
| - "stablediffusionapi/load_lora_embeddings", |
104 |
| - weight_name="all-disney-princess-xl-lo.safetensors", |
105 |
| - adapter_name="lora1", |
106 |
| - ) |
107 |
| - pipe.set_adapters(["lora1"], adapter_weights=[1]) |
108 |
| - pipe.fuse_lora() |
109 |
| - pipe.unload_lora_weights() |
| 102 | +# # Standard Huggingface LoRA loading procedure |
| 103 | +# pipe.load_lora_weights( |
| 104 | +# "stablediffusionapi/load_lora_embeddings", |
| 105 | +# weight_name="all-disney-princess-xl-lo.safetensors", |
| 106 | +# adapter_name="lora1", |
| 107 | +# ) |
| 108 | +# pipe.set_adapters(["lora1"], adapter_weights=[1]) |
| 109 | +# pipe.fuse_lora() |
| 110 | +# pipe.unload_lora_weights() |
110 | 111 |
|
111 |
| - # Refit triggered |
112 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
113 |
| - image.save("./with_LoRA_mutable.jpg") |
| 112 | +# # Refit triggered |
| 113 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
| 114 | +# image.save("./with_LoRA_mutable.jpg") |
| 115 | + |
| 116 | + |
| 117 | +# %% |
| 118 | +# Use Mutable Torch TensorRT module with dynamic shape |
| 119 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 120 | +class Model(torch.nn.Module): |
| 121 | + def __init__(self): |
| 122 | + super().__init__() |
| 123 | + |
| 124 | + def forward(self, a, b, c={}): |
| 125 | + x = torch.matmul(a, b) |
| 126 | + x = torch.matmul(c["a"], c["b"].T) |
| 127 | + x = 2 * c["b"][1] |
| 128 | + return x |
| 129 | + |
| 130 | + |
| 131 | +model = Model().eval().cuda() |
| 132 | +inputs = (torch.rand(10, 3), torch.rand(3, 30)) |
| 133 | +kwargs = { |
| 134 | + "c": {"a": torch.rand(10, 30), "b": torch.rand(10, 30)}, |
| 135 | +} |
| 136 | + |
| 137 | +dim = torch.export.Dim("dim", min=1, max=50) |
| 138 | +dim2 = torch.export.Dim("dim2", min=1, max=50) |
| 139 | +args_dynamic_shapes = ({1: dim}, {0: dim}) |
| 140 | +kwarg_dynamic_shapes = { |
| 141 | + "c": {"a": {}, "b": {0: dim2}}, |
| 142 | +} |
| 143 | +# Export the model first with custom dynamic shape constraints |
| 144 | +# exp_program = torch.export.export(model, tuple(inputs), kwargs=k |
| 145 | +trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True) |
| 146 | +trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes) |
| 147 | +# Run inference |
| 148 | +trt_gm(*inputs, **kwargs) |
0 commit comments