diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 5e9863ab0d0a..debeedf6626e 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -550,8 +550,11 @@ def get_1d_rotary_pos_embed(
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
 
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    freqs = freqs.to(pos.device)
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index fd0881a14880..8a1e2ffce671 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -38,6 +38,7 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+from torch.profiler import record_function
 
 
 @maybe_allow_in_graph
@@ -439,109 +440,114 @@ def forward(
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
-        hidden_states = self.x_embedder(hidden_states)
+        with record_function(" x_embedder"):
+            hidden_states = self.x_embedder(hidden_states)
 
-        timestep = timestep.to(hidden_states.dtype) * 1000
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-        else:
-            guidance = None
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
-        )
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        with record_function(" time_text_embed"):
+            timestep = timestep.to(hidden_states.dtype) * 1000
+            if guidance is not None:
+                guidance = guidance.to(hidden_states.dtype) * 1000
+            else:
+                guidance = None
+            temb = (
+                self.time_text_embed(timestep, pooled_projections)
+                if guidance is None
+                else self.time_text_embed(timestep, guidance, pooled_projections)
             )
-            img_ids = img_ids[0]
-        ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
-
-        for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        with record_function(" pos_embeds (rotary)"):
+            if txt_ids.ndim == 3:
+                logger.warning(
+                    "Passing `txt_ids` 3d torch.Tensor is deprecated. "
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" ) + img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + with record_function(" blocks"): + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) + with record_function(" single blocks"): + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
 
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+
+                else:
+                    hidden_states = block(
+                        hidden_states=hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                    )
+
+                # controlnet residual
+                if controlnet_single_block_samples is not None:
+                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                    interval_control = int(np.ceil(interval_control))
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                        + controlnet_single_block_samples[index_block // interval_control]
+                    )
 
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index bb214885da1c..f77a49b8d896 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -36,6 +36,8 @@
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import FluxPipelineOutput
 
+from torch.profiler import record_function
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -716,21 +718,24 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with record_function(f"transformer iter_{i}"):
+
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                with record_function(f"scheduler.step (iter_{i})"):
+                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
@@ -757,10 +762,11 @@ def __call__(
             image = latents
 
         else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents, return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
+            with record_function("decode latent"):
+                latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+                latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+                image = self.vae.decode(latents, return_dict=False)[0]
+                image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 937cae2e47f5..48742e843829 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -22,7 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
 from .scheduling_utils import SchedulerMixin
-
+from torch.profiler import record_function
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -284,20 +284,22 @@ def step(
                     " one of the `scheduler.timesteps` as a timestep."
                 ),
             )
-
-        if self.step_index is None:
-            self._init_step_index(timestep)
+        with record_function(" _init_step_index"):
+            if self.step_index is None:
+                self._init_step_index(timestep)
 
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)
+
+        with record_function(" get sigma and sigma_next"):
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
 
-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-
-        prev_sample = sample + (sigma_next - sigma) * model_output
+        with record_function(" get prev_sample"):
+            prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype
-        prev_sample = prev_sample.to(model_output.dtype)
+        prev_sample = prev_sample.to(model_output.dtype)
 
         # upon completion increase step index by one
         self._step_index += 1
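
Usage note (not part of the patch): the sketch below is one way to drive the instrumented pipeline and collect the record_function ranges added above. It assumes a CUDA machine and the public "black-forest-labs/FLUX.1-dev" checkpoint; the file name profile_flux.py, the prompt, and the step counts are illustrative placeholders, so adjust them to your setup.

# profile_flux.py -- minimal sketch, assuming a local diffusers checkout with this patch applied.
import torch
from torch.profiler import ProfilerActivity, profile

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Warm-up run so CUDA context creation and allocator behavior stay out of the trace.
pipe("a photo of a cat", num_inference_steps=2)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    pipe("a photo of a cat", num_inference_steps=4)

# Every label passed to record_function in the patch (" x_embedder", " blocks",
# "transformer iter_0", "scheduler.step (iter_0)", "decode latent", ...) shows up
# as a named span in the aggregated table and in the exported Chrome trace.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
prof.export_chrome_trace("flux_trace.json")  # open in chrome://tracing or Perfetto

Because record_function only annotates host-side ranges, the CUDA times attributed to each span come from the profiler correlating kernels launched inside it; for per-iteration timing of the denoising loop, the "transformer iter_{i}" and "scheduler.step (iter_{i})" labels added in pipeline_flux.py are the ones to sort by.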