
Commit da43e52

Authored Apr 1, 2023
Support BelleGroup/BELLE-LLAMA-7B-2M-gptq (#12)
* Update chatglm.py
* Update llama.py
* Update app.py
* gptq
* gptq
* Create llama_gptq.py
* Update app.py
* Update llama_gptq.py
1 parent c407047 commit da43e52

14 files changed: +2005 −10 lines
 

‎app.py

+5 −1
@@ -10,11 +10,15 @@
 # Load the model
 # model_name = 'THUDM/chatglm-6b'
 # model_name = 'BelleGroup/BELLE-LLAMA-7B-2M'
-model_name = 'silver/chatglm-6b-int4-slim'
+# model_name = 'silver/chatglm-6b-int4-slim'
+model_name = 'BelleGroup/BELLE-LLAMA-7B-2M-gptq'

 if 'chatglm' in model_name.lower():
     from predictors.chatglm import ChatGLM
     predictor = ChatGLM(model_name)
+elif 'gptq' in model_name.lower():
+    from predictors.llama_gptq import LLaMaGPTQ
+    predictor = LLaMaGPTQ(model_name)
 elif 'llama' in model_name.lower():
     from predictors.llama import LLaMa
     predictor = LLaMa(model_name)
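Note that the new `gptq` branch is checked before the generic `llama` branch, which matters because a name like `BelleGroup/BELLE-LLAMA-7B-2M-gptq` contains both substrings. A minimal standalone sketch of that dispatch order (the `pick_backend` helper and its return values are illustrative only, not the repository's actual predictor classes):

```python
# Illustrative sketch of the substring-based dispatch used in app.py.
# pick_backend() is a hypothetical helper, not part of the repository.
def pick_backend(model_name: str) -> str:
    name = model_name.lower()
    if 'chatglm' in name:
        return 'chatglm'
    elif 'gptq' in name:   # must precede 'llama': GPTQ LLaMa names also contain 'llama'
        return 'llama_gptq'
    elif 'llama' in name:
        return 'llama'
    raise ValueError(f'unsupported model: {model_name}')

assert pick_backend('BelleGroup/BELLE-LLAMA-7B-2M-gptq') == 'llama_gptq'
assert pick_backend('BelleGroup/BELLE-LLAMA-7B-2M') == 'llama'
```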

‎gptq/README.md

+70 (new file)
# GPTQ-for-Bloom & LLaMa

8-bit quantization of [Bloom](https://arxiv.org/pdf/2211.05100.pdf) using [GPTQ](https://arxiv.org/abs/2210.17323).

GPTQ is a SOTA one-shot weight quantization method.

**This code is based on [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa).**

## [Huggingface models](https://huggingface.co/BelleGroup/BELLE-7B-gptq)

| model name                 | file size | GPU memory usage |
| -------------------------- | --------- | ---------------- |
| base                       | 27G       | ~28.2G           |
| bloom7b-2m-8bit-128g.pt    | 9.7G      | ~11.4G           |
| bloom7b-2m-4bit-128g.pt    | 6.9G      | ~8.4G            |
| bloom7b-0.2m-8bit-128g.pt  | 9.7G      | ~11.4G           |
| bloom7b-0.2m-4bit-128g.pt  | 6.9G      | ~8.4G            |

All experiments were run on a single NVIDIA A100.

## Installation

If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first.

```
conda create --name gptq python=3.9 -y
conda activate gptq
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
# Or, if you're having trouble with conda, use pip with python3.9:
# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

pip install -r requirements.txt
python setup_cuda.py install

# Benchmark performance for the FC2 layer of LLaMa-7B
CUDA_VISIBLE_DEVICES=0 python test_kernel.py
```

## Dependencies

* `torch`: tested on v2.0.0+cu117
* `transformers`: tested on v4.28.0.dev0
* `datasets`: tested on v2.10.1
* `safetensors`: tested on v0.3.0
* To run the 4-bit kernels: a setup for compiling PyTorch CUDA extensions is required (see https://pytorch.org/tutorials/advanced/cpp_extension.html); tested on CUDA 11.7.

## Model inference with the saved model

```
# BELLE-7B-gptq: local path of the model saved from Huggingface
git lfs install
git clone https://huggingface.co/BelleGroup/BELLE-7B-gptq

# model inference with the saved model
CUDA_VISIBLE_DEVICES=0 python bloom_inference.py BELLE-7B-gptq --wbits 8 --groupsize 128 --load BELLE-7B-gptq/bloom7b-2m-8bit-128g.pt --text "hello"
```

## Model quantization

```
# BELLE-7B-gptq: local path for the saved (compressed) model
# Save the compressed model
CUDA_VISIBLE_DEVICES=0 python bloom.py BelleGroup/BELLE-7B-2M wikitext2 --wbits 8 --groupsize 128 --save BELLE-7B-gptq/bloom7b-2m-8bit-128g.pt
```

The CUDA kernels support 2, 3, 4, and 8 bits.

In general, 8-bit quantization with a groupsize of 128 is recommended.

## Acknowledgements

This code is based on [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa).

Thanks to [Bloom](https://arxiv.org/pdf/2211.05100.pdf), a powerful LLM.
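For context on what the quantization code in the next file computes: GPTQ quantizes one weight matrix at a time, minimizing the layer-wise reconstruction error on a set of calibration inputs. A compact sketch of that standard objective (from the GPTQ paper, not specific to this repository), with X the matrix of calibration activations:

```
\hat{W} = \underset{\hat{W}}{\arg\min}\; \lVert W X - \hat{W} X \rVert_2^2,
\qquad
H = 2\, X X^{\top}
```

H is the Hessian of this quadratic objective. `GPTQ.add_batch` in `gptq/gptq.py` below accumulates a running estimate of it from the calibration batches, and `fasterquant` uses a Cholesky factorization of its inverse to propagate each column's quantization error onto the not-yet-quantized columns.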

‎gptq/gptq.py

+163 (new file)
import math
import time

import torch
import torch.nn as nn
import transformers

from gptq.quant import *


DEBUG = False

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GPTQ:

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = quantize(
                    w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if DEBUG:
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if DEBUG:
            print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
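A minimal sketch of how this class is typically driven in GPTQ-for-LLaMa-style quantization scripts. The `Quantizer` and `quantize` helpers are assumed to come from `gptq.quant` (the module star-imported above, as in the upstream project); the layer and calibration batches below are placeholders, and a CUDA device is required.

```python
import torch
import torch.nn as nn

from gptq.gptq import GPTQ
from gptq.quant import Quantizer  # assumed helper from the upstream GPTQ-for-LLaMa quant module

layer = nn.Linear(4096, 4096).cuda()                              # placeholder: one layer of the model
calib_batches = [torch.randn(8, 4096).cuda() for _ in range(4)]   # placeholder calibration data

gptq = GPTQ(layer)
gptq.quantizer = Quantizer()
gptq.quantizer.configure(8, perchannel=True, sym=False, mse=False)  # 8-bit weight quantization

# Collect second-order statistics (the Hessian H) from calibration activations
# by hooking the layer's forward pass.
def add_batch(module, inp, out):
    gptq.add_batch(inp[0].data, out.data)

handle = layer.register_forward_hook(add_batch)
for batch in calib_batches:
    layer(batch)
handle.remove()

# Quantize the layer's weights in place; returns per-group scales and zero points.
scale, zero = gptq.fasterquant(percdamp=0.01, groupsize=128)
gptq.free()
```

In the repository, the `bloom.py` command shown in the README drives this same flow over every layer of the model, with `--wbits` and `--groupsize` corresponding to the quantizer bit-width and the `groupsize` argument of `fasterquant` above.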
