Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lion 8 bit #188

Merged
merged 21 commits into from
Apr 11, 2023
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
7247cb4
initial commit, slowly work from interface into the kernel
lucidrains Mar 9, 2023
d43ea97
make sure interface is correct
lucidrains Mar 9, 2023
cb4c3c8
do a bunch of typical bookkeeping before getting to main lion logic
lucidrains Mar 9, 2023
8de29fc
forget about tests for now, will test live on local enwik8 training
lucidrains Mar 9, 2023
64bb1ae
add a sign function, for lion
lucidrains Mar 9, 2023
c83888a
use epsilon as beta2 for lion, complete most of the logic in kernel.c…
lucidrains Mar 9, 2023
ead570a
remove something rmsprop specific
lucidrains Mar 9, 2023
af03430
fix weight decay for lion to be decoupled, using a switch
lucidrains Mar 9, 2023
c558272
missed adagrad
lucidrains Mar 9, 2023
8618bed
swap the order in which momentum and parameters are updated in ops.cu
lucidrains Mar 10, 2023
c99b44f
do the epsilon beta2 switcharoo within the cuda code, and not within …
lucidrains Mar 10, 2023
19b9ef3
whoops
lucidrains Mar 10, 2023
abbe65a
beta2 is actually accessible in kOptimizerStatic8bit1StateBlockwise
lucidrains Mar 10, 2023
6c377b3
always pass beta2 into all the 1state functions
lucidrains Mar 10, 2023
369a51c
switch all eps to beta2
lucidrains Mar 10, 2023
9b656f4
follow advice of Tim to fix update of momentum vs parameters in block…
lucidrains Mar 22, 2023
a43cd20
add some code in test_optim.py, although it seems to be failing
lucidrains Mar 22, 2023
aa9b939
add some comments, and fix use of g_val
lucidrains Mar 22, 2023
916000c
fix consistent tabs / spaces
lucidrains Mar 22, 2023
978ba2d
another tab/spaces fix
lucidrains Mar 22, 2023
2a6828e
fix comment
lucidrains Mar 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms
12 changes: 12 additions & 0 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,10 @@ def prod(iterable):
lib.crmsprop32bit_g32,
lib.crmsprop32bit_g16,
)
str2optimizer32bit["lion"] = (
lib.clion32bit_g32,
lib.clion32bit_g16,
)
str2optimizer32bit["adagrad"] = (
lib.cadagrad32bit_g32,
lib.cadagrad32bit_g16,
@@ -58,6 +62,10 @@ def prod(iterable):
lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16,
)
str2optimizer8bit["lion"] = (
lib.clion_static_8bit_g32,
lib.clion_static_8bit_g16,
)
str2optimizer8bit["lamb"] = (
lib.cadam_static_8bit_g32,
lib.cadam_static_8bit_g16,
@@ -80,6 +88,10 @@ def prod(iterable):
lib.crmsprop_8bit_blockwise_fp32,
lib.crmsprop_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["lion"] = (
lib.clion_8bit_blockwise_fp32,
lib.clion_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_fp32,
lib.cadagrad_8bit_blockwise_fp16,
1 change: 1 addition & 0 deletions bitsandbytes/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -12,4 +12,5 @@
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit
from .sgd import SGD, SGD8bit, SGD32bit
87 changes: 87 additions & 0 deletions bitsandbytes/optim/lion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Lion(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)


class Lion8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)


class Lion32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
87 changes: 75 additions & 12 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return __int_as_float(old);
}

// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA

template <typename T>
__device__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}

template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
@@ -743,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{

@@ -790,6 +798,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
@@ -821,7 +832,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{

@@ -890,6 +901,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,

p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
@@ -1158,7 +1173,7 @@ __global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -1219,6 +1234,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@@ -1244,7 +1262,7 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -1307,8 +1325,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;

if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}

s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];

switch(OPTIMIZER)
@@ -1321,6 +1350,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,

p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
@@ -1649,10 +1682,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}

s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];

@@ -1664,6 +1707,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case LION:
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@@ -1701,6 +1749,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
@@ -2692,24 +2743,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)

@@ -2731,6 +2786,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@@ -2742,11 +2798,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(LION, half)
MAKE_PreconditionStatic8bit1State(LION, float)

#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@@ -2758,6 +2817,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(LION, half)
MAKE_optimizerStatic8bit1State(LION, float)

#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@@ -2849,5 +2910,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
8 changes: 4 additions & 4 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);

template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
38 changes: 33 additions & 5 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
@@ -120,17 +120,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case MOMENTUM:
case RMSPROP:
case ADAGRAD:

if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());

if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
}
}

@@ -164,12 +175,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());

CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default:
break;
}
@@ -198,6 +219,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
case LION:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@@ -707,6 +729,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)

@@ -726,6 +750,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)

#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
@@ -738,6 +764,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

1 change: 1 addition & 0 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
@@ -70,6 +70,7 @@ typedef enum Optimizer_t
RMSPROP = 2,
LARS = 3,
ADAGRAD = 4,
LION = 5,
} Optimizer_t;

typedef enum Transform_t
12 changes: 12 additions & 0 deletions csrc/pythonInterface.c
Original file line number Diff line number Diff line change
@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32(adam, ADAM, half, 16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(lion, LION, float, 32)
MAKE_FUNC32(lion, LION, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)

@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)

#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_BLOCKWISE8(lion, LION, half, 16)
MAKE_BLOCKWISE8(lion, LION, float, 32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)

@@ -161,6 +167,8 @@ extern "C"
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, 32)
MAKE_CFUNC32(lion, half, 16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)

@@ -183,6 +191,8 @@ extern "C"
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)

#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@@ -196,6 +206,8 @@ extern "C"
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_CBLOCKWISE8(lion, LION, half, 16)
MAKE_CBLOCKWISE8(lion, LION, float, 32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
lion-pytorch
pytest
23 changes: 22 additions & 1 deletion tests/test_optim.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@
from os.path import join

import pytest
from lion_pytorch import Lion

import torch

import bitsandbytes as bnb
@@ -31,13 +33,15 @@ def rm_path(path):
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -54,6 +58,10 @@ def rm_path(path):
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["lion8bit"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -71,6 +79,10 @@ def rm_path(path):
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -82,6 +94,7 @@ def rm_path(path):

str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@@ -90,6 +103,9 @@ def rm_path(path):
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lion8bit"] = [
("exp_avg", "state1", "qmap1", "max1")
]
str2statenames["lamb8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
@@ -98,6 +114,9 @@ def rm_path(path):
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["lion8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1")
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
@@ -113,7 +132,7 @@ def rm_path(path):
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@@ -241,9 +260,11 @@ def test_global_config(dim1, dim2, gtype):
gtype = [torch.float32, torch.float16]
optimizer_names = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",