Skip to content

small update on rotary embedding #9354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 3, 2024
Merged

small update on rotary embedding #9354

merged 6 commits into from
Sep 3, 2024

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Sep 3, 2024

create the freqs tensor on device to avoid potential sync

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
pos = torch.from_numpy(pos) # type: ignore # [S]

theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
freqs = freqs.to(pos.device)
freqs = (
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul
this is the only code change from #9321

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): We could add a comment on emphasizing the impact of creating torch.arange() on device so as to prevent syncs. This will be helpful references for us going forward. What do you think?

We could do similar for the changes introduced in https://github.com/huggingface/diffusers/pull/9307/files as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this did not cause the sync, though

@yiyixuxu yiyixuxu requested review from DN6 and sayakpaul September 3, 2024 09:31
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Just a single minor comment.

@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
pos = torch.from_numpy(pos) # type: ignore # [S]

theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
freqs = freqs.to(pos.device)
freqs = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): We could add a comment on emphasizing the impact of creating torch.arange() on device so as to prevent syncs. This will be helpful references for us going forward. What do you think?

We could do similar for the changes introduced in https://github.com/huggingface/diffusers/pull/9307/files as well.

@yiyixuxu yiyixuxu merged commit dcf320f into main Sep 3, 2024
18 checks passed
@yiyixuxu yiyixuxu deleted the small-update-rotary branch September 3, 2024 17:18
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* update

* fix

---------

Co-authored-by: Sayak Paul <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants