Description
Is it possible to refactor the Flux positional embeddings so that we can fully make use of CUDAGRAPHs?
Compiling the Flux transformer currently skips CUDA graphs with this message:

skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from :
  File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 469, in forward
    image_rotary_emb = self.pos_embed(ids)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/embeddings.py", line 630, in forward
    self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
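The skip happens because part of the rotary-embedding computation runs on the CPU, so inductor sees a CPU-to-GPU copy (device_put) inside the compiled region. A minimal sketch of the kind of refactor being asked for, with an illustrative helper name rather than the actual diffusers function, would be to build the frequencies directly on the device of the position ids:

import torch

def rope_1d(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # Hypothetical helper, not the diffusers API: creating the frequency
    # tensor on pos.device keeps the whole computation on the GPU, so no
    # device_put shows up in the compiled graph.
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device, dtype=torch.float32) / dim)
    )
    angles = torch.outer(pos.to(torch.float32), freqs)  # shape: [len(pos), dim // 2]
    return angles.cos(), angles.sin()

Alternatively, image_rotary_emb could be computed once outside the compiled transformer forward and passed in as a plain tensor input, so the embedding math never enters the compiled graph at all.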
Code
import torch

torch.set_float32_matmul_precision("high")
# These flags live on torch._inductor.config, not on the torch._inductor module itself;
# assigning them on the module is a silent no-op.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
import diffusers
from platform import python_version
from diffusers import DiffusionPipeline
print(diffusers.__version__)
print(torch.__version__)
print(python_version())
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
for _ in range(5):
    image = pipe(
        "Happy bear",
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    ).images[0]
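For context, mode="max-autotune" is what opts the compiled graphs into CUDA graphs in inductor; if I read the config surface right, the explicit knob it flips is:

# Assumed inductor config knob; "max-autotune" sets this implicitly.
torch._inductor.config.triton.cudagraphs = True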
If we could make full use of CUDA graphs here, torch.compile() would be faster.
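One way to check whether a refactor actually removes the skip is to surface inductor's performance hints, which is where the "skipping cudagraphs due to ..." messages are emitted (assuming the perf_hints logging artifact, which recent PyTorch builds expose):

import torch

# Print inductor perf hints, including cudagraph skip reasons.
torch._logging.set_logs(perf_hints=True)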