
Commit 05258c9

add option for decoupled weight decay
1 parent 50edc8a commit 05258c9

File tree

4 files changed: +35 -8 lines changed


README.md (+7)

@@ -98,3 +98,10 @@ opt = Lion(
     year = {2019}
 }
 ```
+
+```bibtex
+@misc{Schaipp2024,
+    author = {Fabian Schaipp},
+    url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
+}
+```
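
For reference, a minimal usage sketch of the new option, assuming the usual `from lion_pytorch import Lion` import; the toy model and single training step below are purely illustrative and not from the repository, while the constructor arguments follow the signatures in the diffs that follow:

```python
import torch
from lion_pytorch import Lion

# toy model, purely for illustration
model = torch.nn.Linear(10, 2)

# with decoupled_weight_decay = True the optimizer rescales weight_decay by the
# initial lr, so the applied decay factor becomes 1 - (lr / init_lr) * weight_decay
opt = Lion(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 1e-2,
    decoupled_weight_decay = True
)

# one illustrative training step
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```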

lion_pytorch/foreach.py (+13 -2)

@@ -17,12 +17,16 @@ def __init__(
         params,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0
+        weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert all([hasattr(torch, attr) for attr in ('_foreach_mul_', '_foreach_add_', '_foreach_sign_', '_foreach_lerp_')]), 'this version of torch does not have the prerequisite foreach functions'
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -44,7 +48,14 @@ def step(
 
         for group in self.param_groups:
 
-            lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
+            lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr
+
+            # maybe decoupled weight decay
+
+            if decoupled_wd:
+                wd /= init_lr
+
+            # accumulate List[Tensor] for foreach inplace updates
 
             params = []
             grads = []
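
To make the effect of `wd /= init_lr` concrete, here is a small arithmetic sketch with illustrative numbers (not from the repository): the weights are multiplied by `1 - lr * wd` each step, so dividing `wd` by the initial learning rate makes the decay depend on the schedule ratio `lr / init_lr` rather than on the absolute value of `lr`.

```python
init_lr = 1e-4        # lr passed to the constructor
weight_decay = 1e-2   # weight_decay passed to the constructor

for lr in (1e-4, 5e-5, 1e-5):  # e.g. a decaying lr schedule
    coupled   = 1. - lr * weight_decay                # default behaviour
    decoupled = 1. - (lr / init_lr) * weight_decay    # decoupled_weight_decay = True
    print(f"lr={lr:.0e}  coupled={coupled:.6f}  decoupled={decoupled:.6f}")
```

At the initial learning rate the decoupled variant applies the full `weight_decay`, and as the schedule shrinks `lr` the decay strength shrinks proportionally to `lr / init_lr`, which is exactly the rescaling performed in `step` above.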

lion_pytorch/lion_pytorch.py (+14 -5)

@@ -14,16 +14,16 @@ def exists(val):
 def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
     # stepweight decay
 
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1. - lr * wd)
 
     # weight update
 
-    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
+    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
     p.add_(update, alpha = -lr)
 
     # decay the momentum running average coefficient
 
-    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
+    exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
 
 # class
 
@@ -34,11 +34,15 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False
+        use_triton: bool = False,
+        decoupled_weight_decay: bool = False,
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -67,7 +71,12 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
+                grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
+
+                # maybe decoupled weight decay
+
+                if decoupled_wd:
+                    wd /= init_lr
 
                 # init state - exponential moving average of gradient values
 
setup.py (+1 -1)

@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',