Mutable module improvement #3394

Merged
16 commits merged on Feb 28, 2025
2 changes: 1 addition & 1 deletion examples/dynamo/README.rst
@@ -21,4 +21,4 @@ Model Zoo
* :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
166 changes: 150 additions & 16 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -11,11 +11,13 @@
The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.

In this tutorial, we are going to walk through:
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
2. Save a Mutable Torch TensorRT Module
3. Integration with Huggingface pipeline in LoRA use case
4. Usage of dynamic shape with Mutable Torch TensorRT Module
"""

# %%
import numpy as np
import torch
import torch_tensorrt as torch_trt
@@ -63,16 +65,14 @@
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Currently, saving is only supported with the C++ runtime, i.e. when "use_python_runtime" = False in the settings
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
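# A minimal sketch of compile settings that allow saving (illustrative; the
# exact dict is an assumption beyond what this example shows). The C++ runtime
# is selected by passing use_python_runtime=False:
#
# settings = {
#     "use_python_runtime": False,  # required for save/load
#     "immutable_weights": False,
# }
# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)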

# %%
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


from diffusers import DiffusionPipeline

with torch.no_grad():
@@ -83,33 +83,167 @@
"immutable_weights": False,
}

model_id = "runwayml/stable-diffusion-v1-5"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda:0"

prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(device)

# The only extra line you need
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)
_HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
_WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
HEIGHT = 4 * _HEIGHT
WIDTH = 4 * _WIDTH
args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
kwargs_dynamic_shapes = {
"encoder_hidden_states": {0: BATCH},
"added_cond_kwargs": {
"text_embeds": {0: BATCH},
"time_ids": {0: BATCH},
},
}
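# A few notes on the ranges above (explanatory assumptions, not from the
# original example): the UNet operates on 8x-downsampled latents, so image
# sizes of 1024x768 map to latent sizes of 128x96, inside the derived ranges
# 4 * [16, 32] = [64, 128]; BATCH spans 2-24 because classifier-free guidance
# doubles the effective batch (2 * num_images_per_prompt); and the empty dict
# in args_dynamic_shapes is the hint for the UNet's second positional
# argument (the timestep), which has no dynamic dimensions.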
pipe.unet.set_expected_dynamic_shape_range(
args_dynamic_shapes, kwargs_dynamic_shapes
)
image = pipe(
prompt,
negative_prompt=negative,
num_inference_steps=30,
height=1024,
width=768,
num_images_per_prompt=2,
).images[0]
image.save("./without_LoRA_mutable.jpg")

# Standard Huggingface LoRA loading procedure
pipe.load_lora_weights(
"stablediffusionapi/load_lora_embeddings",
weight_name="moxin.safetensors",
weight_name="all-disney-princess-xl-lo.safetensors",
adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()
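# (Added note: fusing the LoRA weights mutates the UNet parameters in place,
# so the next forward call detects the change and refits the TensorRT engine
# instead of recompiling it.)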

# Refit triggered
image = pipe(
prompt,
negative_prompt=negative,
num_inference_steps=30,
height=1024,
width=1024,
num_images_per_prompt=1,
).images[0]
image.save("./with_LoRA_mutable.jpg")


# %%
# Use Mutable Torch TensorRT module with dynamic shape
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of the arg_inputs and kwarg_inputs passed to the forward function
# and should not omit any entries (except None values in kwarg_inputs). If there is a nested dict/list in the input, the dynamic shape for that entry should also be a nested dict/list.
# If dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input, as the example below demonstrates.
# Note that you should exclude keyword arguments with value None, as those will be filtered out.


class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c={}):
x = torch.matmul(a, b)
x = torch.matmul(c["a"], c["b"].T)
print(c["b"][0])
x = 2 * c["b"]
return x


device = "cuda:0"
model = Model().eval().to(device)
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
"c": {
"a": {},
"b": {0: dim_2},
}, # a's shape does not change so we give it an empty dict
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
model(*inputs, **kwargs)
# Change input shape
inputs_2 = (torch.rand(10, 5).to(device), torch.rand(5, 30).to(device))
kwargs_2 = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
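
# A minimal sketch of the None-filtering rule above (the `ModelWithNone` class
# and its `b=None` keyword are hypothetical additions, not part of the original
# example): keyword arguments whose value is None are filtered out before
# tracing, so they must be omitted from the shape hints.


class ModelWithNone(torch.nn.Module):
    def forward(self, a, b=None):
        return torch.matmul(a, a.T)


model_with_none = torch_trt.MutableTorchTensorRTModule(
    ModelWithNone().eval().to(device), min_block_size=1
)
dim_a = torch.export.Dim("dim_a", min=1, max=50)
# Only "a" gets a hint; "b" is None and must not appear in the kwarg hints.
model_with_none.set_expected_dynamic_shape_range(({0: dim_a},), {})
model_with_none(torch.rand(10, 3).to(device), b=None)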

# %%
# Use Mutable Torch TensorRT module with persistent cache
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# By leveraging engine caching, we can shortcut engine compilation for previously built engines and save significant time.
import os

from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = True

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)


example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
# Build the mutable module with engine caching enabled
model = torch_trt.MutableTorchTensorRTModule(
model,
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
immutable_weights=False,
cache_built_engines=True,
reuse_cached_engines=True,
engine_cache_size=1 << 30, # 1GB
)
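# (Added note on the caching knobs, based on our understanding of engine
# caching: cache_built_engines persists newly built engines to the disk cache,
# reuse_cached_engines loads matching engines back on later compilations, and
# engine_cache_size caps the cache at 1 GB here.)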


def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)


remove_timing_cache()

for i in range(4):
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]

start.record()
model(*inputs) # Recompile
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
print("With engine caching used:", times[1], "ms")
print("With engine caching used:", times[2], "ms")
print("With engine caching used:", times[3], "ms")
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -395,9 +395,12 @@ def refit_module_weights(
try:
weight_name_map = compiled_submodule.weight_name_map
except AttributeError:
if not isinstance(
compiled_submodule, torch.fx.graph_module.GraphModule
):
logger.warning(
"The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
)
if not weight_name_map:
use_weight_map_cache = False
logger.warning(
19 changes: 12 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

@staticmethod
def find_weight(
weight_name: str,
np_map: dict[str, Any],
state_dict: dict[str, Any],
device: torch.device,
) -> str:
"""
We need to build a map from engine weight name to state_dict weight name.
@@ -385,19 +388,21 @@ def find_weight(
np_map: the map from weight name to np values in INetworkDefinition
state_dict: state of the graph module
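device: the device on which the weight comparison is performed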
"""
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""

@staticmethod
def check_weight_equal(
sd_weight: torch.tensor,
network_weight: Union[torch.Tensor, np.ndarray],
device: torch.device,
) -> Any:
if not isinstance(network_weight, torch.Tensor):
network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
# There is no direct connection in batch_norm layer. So skip it
pass
elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
sd[sd_weight_name], np_map[engine_weight_name], torch_device
):
weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
engine_weight_name, np_map, sd, torch_device
)
if (
weight_name_map[engine_weight_name] != ""