
Commit afed0af

henrylhtsang authored and facebook-github-bot committed
add util function to compute rowwise adagrad updates (pytorch#2148)
Summary: Pull Request resolved: pytorch#2148. Added `compute_rowwise_adagrad_updates`, a util function that computes rowwise adagrad updates when we want to pass in just optim_state and grad, without the parameter. It can handle the case when grad is sparse. Differential Revision: D58270549
1 parent cf62594 commit afed0af
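The summary describes `compute_rowwise_adagrad_updates` as a utility that applies a rowwise Adagrad step given only the optimizer state and the gradient (dense or sparse), without the parameter itself. The utility's body is not part of the diff shown below, so the following is only a minimal sketch of the idea; the signature, default hyperparameters, and in-place state handling are assumptions, not the actual torchrec implementation.

import torch


def compute_rowwise_adagrad_updates(  # hypothetical signature, for illustration only
    optim_state: torch.Tensor,  # per-row accumulator of mean squared gradients, shape (num_rows,)
    grad: torch.Tensor,  # gradient of the embedding table, dense or row-sparse COO
    lr: float = 1e-2,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Return the parameter delta for one rowwise Adagrad step; updates optim_state in place."""
    if grad.is_sparse:
        # Assumes a row-sparse layout (one sparse dim over rows), as produced by sparse embedding grads.
        grad = grad.coalesce()
        rows = grad.indices()[0]  # rows that actually received gradients
        vals = grad.values()  # (nnz, embedding_dim)
        optim_state.index_add_(0, rows, vals.pow(2).mean(dim=1))  # only touched rows accumulate state
        std = optim_state[rows].sqrt().add_(eps)
        return torch.sparse_coo_tensor(grad.indices(), -lr * vals / std.unsqueeze(1), grad.shape)
    optim_state.add_(grad.pow(2).mean(dim=1))
    std = optim_state.sqrt().add_(eps)
    return -lr * grad / std.unsqueeze(1)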

File tree: 2 files changed (+17 −11 lines)


torchrec/optim/rowwise_adagrad.py

+7 −8

@@ -69,14 +69,6 @@ def __init__(
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
 
-        if weight_decay > 0:
-            logger.warning(
-                "Note that the weight decay mode of this optimizer may produce "
-                "different results compared to the one by FBGEMM TBE. This is "
-                "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
-                "update the optimizer states if that row has nonzero gradients."
-            )
-
         defaults = dict(
             lr=lr,
             lr_decay=lr_decay,
@@ -213,6 +205,13 @@ def _single_tensor_adagrad(
     eps: float,
     maximize: bool,
 ) -> None:
+    if weight_decay != 0 and len(state_steps) > 0 and state_steps[0].item() < 1.0:
+        logger.warning(
+            "Note that the weight decay mode of this optimizer may produce "
+            "different results compared to the one by FBGEMM TBE. This is "
+            "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
+            "update the optimizer states if that row has nonzero gradients."
+        )
 
     for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
         if grad.is_sparse:
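The relocated warning now fires inside `_single_tensor_adagrad` rather than in `__init__`, and the `state_steps[0].item() < 1.0` guard means it is logged only when weight decay is enabled and the step counter is still at its initial value, i.e. once, on the first step. The point of the warning is that FBGEMM TBE's rowwise Adagrad is sparse: rows that receive no gradient in a batch are not decayed. A toy illustration of that divergence (not library code, just the effect the warning describes):

import torch

lr, weight_decay = 0.1, 0.01
dense_weights = torch.ones(4, 2)        # dense mode: every row is decayed every step
sparse_weights = dense_weights.clone()  # TBE-style sparse mode: only rows seen in the batch are touched
touched_rows = torch.tensor([0, 2])     # rows with nonzero gradients this step

dense_weights -= lr * weight_decay * dense_weights
sparse_weights[touched_rows] -= lr * weight_decay * sparse_weights[touched_rows]

# Row 1 was not in the batch: it decayed in dense mode but not in sparse mode.
print(torch.allclose(dense_weights[1], sparse_weights[1]))  # False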

torchrec/optim/tests/test_rowwise_adagrad.py

+10 −3

@@ -16,17 +16,24 @@
 
 class RowWiseAdagradTest(unittest.TestCase):
     def test_optim(self) -> None:
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         opt = torchrec.optim.RowWiseAdagrad(embedding_bag.parameters())
         index, offsets = torch.tensor([0, 3]), torch.tensor([0, 1])
         embedding_bag_out = embedding_bag(index, offsets)
         opt.zero_grad()
         embedding_bag_out.sum().backward()
+        opt.step()
 
     def test_optim_equivalence(self) -> None:
         # If rows are initialized to be the same and uniform, then RowWiseAdagrad and canonical Adagrad are identical
-        rowwise_embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        rowwise_embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         state_dict = {
             "weight": torch.Tensor(
                 [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
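For reference, the updated test amounts to the following minimal usage sketch: a sum-pooled EmbeddingBag optimized with torchrec's RowWiseAdagrad, now including the opt.step() call the old test omitted.

import torch
import torchrec

embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4, mode="sum")
opt = torchrec.optim.RowWiseAdagrad(embedding_bag.parameters())

index, offsets = torch.tensor([0, 3]), torch.tensor([0, 1])
out = embedding_bag(index, offsets)

opt.zero_grad()
out.sum().backward()
opt.step()  # applies the rowwise Adagrad update to the embedding rows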
