import math
import torch
from .optimizer import Optimizer


class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

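    Example (a minimal usage sketch; ``model``, ``input``, ``target``, and
    ``loss_fn`` are placeholders)::

        >>> optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
        ...                               weight_decay=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
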
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
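        # The closure, if given, should re-evaluate the model and return the
        # loss, zeroing gradients and calling backward() itself, e.g.:
        #   def closure():
        #       optimizer.zero_grad()
        #       loss = loss_fn(model(input), target)  # placeholder names
        #       loss.backward()
        #       return loss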
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

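                # The defining difference from Adam: the weight decay is
                # decoupled from the gradient-based update and applied
                # directly to the weights, instead of adding an L2 penalty
                # (weight_decay * p) to the gradient (Loshchilov & Hutter).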
                # Perform decoupled weight decay: p <- p * (1 - lr * weight_decay)
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

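                # Update the biased first and second moment estimates as
                # exponential moving averages:
                #   exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad
                #   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad**2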
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

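                # Folding the bias corrections into step_size makes the update
                # below equivalent to
                #   p <- p - lr * m_hat / (sqrt(v_hat) + eps / sqrt(bias_correction2))
                # with m_hat = exp_avg / bias_correction1 and
                # v_hat = exp_avg_sq / bias_correction2 (the bias-corrected
                # moment estimates; max_exp_avg_sq replaces exp_avg_sq under
                # AMSGrad).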
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss