small update on rotary embedding #9354
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
         pos = torch.from_numpy(pos)  # type: ignore  # [S]

     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    freqs = freqs.to(pos.device)
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
```
@sayakpaul
this is the only code change from #9321
(nit): We could add a comment emphasizing the impact of creating `torch.arange()` on device to prevent syncs. This will be a helpful reference for us going forward. What do you think?
We could do the same for the changes introduced in https://github.com/huggingface/diffusers/pull/9307/files as well.
this did not cause the sync, though
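For illustration, here is a minimal sketch (hypothetical, not from the PR; `10000.0` stands in for `theta` and `linear_factor` is omitted) contrasting the two allocation patterns: moving a CPU tensor afterwards versus allocating it directly on the target device.

```python
import torch

dim = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Before: the tensor is materialized on CPU, then copied to the target device,
# which incurs an extra allocation and a host-to-device copy.
freqs_cpu = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))
freqs_cpu = freqs_cpu.to(device)

# After: pass device= so the tensor is allocated on the target device directly,
# skipping the intermediate CPU tensor and the transfer.
freqs_dev = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))

assert torch.allclose(freqs_cpu, freqs_dev.cpu() if device.type == "cuda" else freqs_dev)
```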
Thanks. Just a single minor comment.
* update
* fix

---------

Co-authored-by: Sayak Paul <[email protected]>
create the `freqs` tensor on device to avoid potential sync
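As a side note, a hypothetical way (not part of this PR) to check whether an op actually triggers a host-device synchronization is PyTorch's sync debug mode:

```python
import torch

if torch.cuda.is_available():
    # Warn on every call that forces a host-device synchronization.
    torch.cuda.set_sync_debug_mode("warn")

    # Allocating directly on device: no sync warning expected here.
    freqs = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32, device="cuda") / 64))

    # Something like freqs.sum().item() would force a sync and emit a warning.

    torch.cuda.set_sync_debug_mode("default")
```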