optimizers.py
# Adapted from:
# https://github.com/microsoft/Swin-Transformer/blob/f92123a0035930d89cf53fcb8257199481c4428d/optimizer.py
import torch
from timm.optim import AdamP, AdamW
from torch import nn

from utils import read_yaml


def make_my_optimizer(opt_name: str, model_params, cfg: dict):
    """Build an optimizer by name, forwarding cfg as keyword arguments."""
    opt_name = opt_name.lower()
    if opt_name == 'sgd':
        optimizer = torch.optim.SGD(model_params, **cfg)
    elif opt_name == 'adam':
        # Note: Adam folds weight decay into the gradient (L2 regularization),
        # whereas AdamW decouples the decay from the adaptive update; see:
        # https://stackoverflow.com/questions/64621585/adamw-and-adam-with-weight-decay
        # https://www.fast.ai/posts/2018-07-02-adam-weight-decay.html
        optimizer = torch.optim.Adam(model_params, **cfg)
    elif opt_name == 'adamw':
        optimizer = AdamW(model_params, **cfg)
    elif opt_name == 'adamp':
        optimizer = AdamP(model_params, **cfg)
    else:
        raise NotImplementedError(f'Not implemented optimizer: {opt_name}')
    return optimizer


if __name__ == '__main__':
    conf = read_yaml('configs/cifar/optimizer/adamw.yaml')
    model = nn.Linear(3, 4)
    optimizer = make_my_optimizer('adamw', model.parameters(), conf['params'])
    print(optimizer.state_dict())
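
# The YAML file consumed above is not included in this snippet. A minimal
# sketch of what configs/cifar/optimizer/adamw.yaml is assumed to look like;
# the key names and values below are illustrative only, chosen to match
# standard AdamW keyword arguments (lr, weight_decay, betas):
#
#   params:
#     lr: 0.001
#     weight_decay: 0.05
#     betas: [0.9, 0.999]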