From 1ff6c42ab584d3acef4b043dcd641e04b084ec6b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 28 Aug 2024 23:48:19 +0200 Subject: [PATCH] change get_1d_rotary to accept pos as torch tensors --- src/diffusers/models/embeddings.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d1366654c448..088ecf78efff 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -561,11 +561,14 @@ def get_1d_rotary_pos_embed( assert dim % 2 == 0 if isinstance(pos, int): - pos = np.arange(pos) + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] + freqs = freqs.to(pos.device) + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] @@ -638,7 +641,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.squeeze().float().cpu().numpy() + pos = ids.squeeze().float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes):