
Commit a4b2f3e

mjacar authored and facebook-github-bot committed on Jul 2, 2019
Implement AdamW optimizer (pytorch#21250)
Summary:

# What is this?

This is an implementation of the AdamW optimizer as implemented in [the fastai library](https://github.com/fastai/fastai/blob/803894051bef32304ceea0c8ea5e04db64ff26b8/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: pytorch#17468, pytorch#10866, pytorch#3740, pytorch#4429. Hopefully this one goes through.

# Why is this important?

Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. The paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) showed that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer; L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties of adaptive optimization schemes.

# How was this tested?

Test cases were added to `test_optim.py`, and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.

Pull Request resolved: pytorch#21250
Differential Revision: D16060339
Pulled By: vincentqb
fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
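
To make the decoupling concrete: for plain SGD the two are equivalent, since w ← w − lr·(∇L + λ·w) = (1 − lr·λ)·w − lr·∇L, but with Adam the L2 term is divided by the per-coordinate adaptive denominator while true weight decay is not. The sketch below is illustrative only and not part of the commit; the function adam_like_step and all of its names are hypothetical. It shows a single Adam-style update on one tensor with the two variants side by side.

import math
import torch

# Illustrative sketch, not part of the commit: one Adam-style update on a
# single tensor, showing where the L2 penalty and decoupled weight decay
# enter the update.
def adam_like_step(w, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                   betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2,
                   decoupled=True):
    beta1, beta2 = betas
    if decoupled:
        # AdamW: shrink the weights directly; the decay never enters the
        # gradient statistics, so it is not rescaled by the denominator below.
        w = w * (1 - lr * weight_decay)
    else:
        # Adam with an L2 penalty (the "weight_decay" option of optim.Adam):
        # the penalty is folded into the gradient and rescaled per coordinate.
        grad = grad + weight_decay * w
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    return w - step_size * exp_avg / (exp_avg_sq.sqrt() + eps)

# After a single step the two variants already differ on nonzero weights.
w0, g = torch.ones(4), torch.full((4,), 0.5)
w_adamw = adam_like_step(w0, g, torch.zeros(4), torch.zeros(4), step=1, decoupled=True)
w_adam_l2 = adam_like_step(w0, g, torch.zeros(4), torch.zeros(4), step=1, decoupled=False)
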
1 parent c9a8413 commit a4b2f3e

File tree

6 files changed: +140 -0 lines changed

 

docs/source/optim.rst

+2 lines changed

@@ -111,6 +111,8 @@ Algorithms
     :members:
 .. autoclass:: Adam
     :members:
+.. autoclass:: AdamW
+    :members:
 .. autoclass:: SparseAdam
     :members:
 .. autoclass:: Adamax

test/optim/test.py

+1 line changed

@@ -16,6 +16,7 @@ def drosenbrock(tensor):
     'adadelta': optim.adadelta,
     'adagrad': optim.adagrad,
     'adam': optim.adam,
+    'adamw': optim.adamw,
     'adamax': optim.adamax,
     'asgd': optim.asgd,
     'cg': optim.cg,

test/optim/tests.json

+11 lines changed

@@ -25,6 +25,17 @@
             {"learningRate": 1e-4, "weightDecay": 0.1}
         ]
     },
+    {
+        "algorithm": "adamw",
+        "config": [
+            {},
+            {"learningRate": 1e-4},
+            {"learningRate": 1e-4, "beta1": 0.92},
+            {"learningRate": 1e-4, "beta1": 0.92, "beta2": 0.96},
+            {"learningRate": 1e-4, "beta1": 0.92, "beta2": 0.96, "epsilon": 1e-3},
+            {"learningRate": 1e-4, "weightDecay": 0.1}
+        ]
+    },
     {
         "algorithm": "adamax",
         "config": [

test/test_optim.py

+10 lines changed

@@ -326,6 +326,16 @@ def test_adam(self):
         with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
             optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))

+    def test_adamw(self):
+        self._test_basic_cases(
+            lambda weight, bias: optim.AdamW([weight, bias], lr=1e-3)
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.AdamW(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3)
+        )
+
     def test_sparse_adam(self):
         self._test_rosenbrock_sparse(
             lambda params: optim.SparseAdam(params, lr=4e-2),

torch/optim/__init__.py

+2 lines changed

@@ -8,6 +8,7 @@
 from .adadelta import Adadelta # noqa: F401
 from .adagrad import Adagrad # noqa: F401
 from .adam import Adam # noqa: F401
+from .adamw import AdamW # noqa: F401
 from .sparse_adam import SparseAdam # noqa: F401
 from .adamax import Adamax # noqa: F401
 from .asgd import ASGD # noqa: F401
@@ -21,6 +22,7 @@
 del adadelta
 del adagrad
 del adam
+del adamw
 del sparse_adam
 del adamax
 del asgd

torch/optim/adamw.py

+114 lines changed (new file)

@@ -0,0 +1,114 @@
+import math
+import torch
+from .optimizer import Optimizer
+
+
+class AdamW(Optimizer):
+    r"""Implements AdamW algorithm.
+
+    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=1e-2, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad)
+        super(AdamW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                # Perform stepweight decay
+                p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+                # Perform optimization step
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+
+        return loss
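
For reference, the following is a minimal usage sketch of the new optimizer (not part of the commit); the model, data, and hyperparameters are placeholders chosen only for illustration.

import torch
import torch.nn as nn
import torch.optim as optim

# Usage sketch, not part of the commit: a toy linear regression trained with
# the newly added optim.AdamW class. Model, data, and hyperparameters are
# placeholders.
model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()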
