|
| 1 | +import math |
| 2 | +import time |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import transformers |
| 7 | + |
| 8 | +from gptq.quant import * |
| 9 | + |
| 10 | + |
| 11 | +DEBUG = False |
| 12 | + |
| 13 | +torch.backends.cuda.matmul.allow_tf32 = False |
| 14 | +torch.backends.cudnn.allow_tf32 = False |
| 15 | + |
| 16 | + |
| 17 | +class GPTQ: |
| 18 | + def __init__(self, layer): |
| 19 | + self.layer = layer |
| 20 | + self.dev = self.layer.weight.device |
| 21 | + W = layer.weight.data.clone() |
| 22 | + if isinstance(self.layer, nn.Conv2d): |
| 23 | + W = W.flatten(1) |
| 24 | + if isinstance(self.layer, transformers.Conv1D): |
| 25 | + W = W.t() |
| 26 | + self.rows = W.shape[0] |
| 27 | + self.columns = W.shape[1] |
| 28 | + self.H = torch.zeros((self.columns, self.columns), device=self.dev) |
| 29 | + self.nsamples = 0 |
| 30 | + |
| 31 | + def add_batch(self, inp, out): |
| 32 | + if DEBUG: |
| 33 | + self.inp1 = inp |
| 34 | + self.out1 = out |
| 35 | + if len(inp.shape) == 2: |
| 36 | + inp = inp.unsqueeze(0) |
| 37 | + tmp = inp.shape[0] |
| 38 | + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): |
| 39 | + if len(inp.shape) == 3: |
| 40 | + inp = inp.reshape((-1, inp.shape[-1])) |
| 41 | + inp = inp.t() |
| 42 | + if isinstance(self.layer, nn.Conv2d): |
| 43 | + unfold = nn.Unfold( |
| 44 | + self.layer.kernel_size, |
| 45 | + dilation=self.layer.dilation, |
| 46 | + padding=self.layer.padding, |
| 47 | + stride=self.layer.stride |
| 48 | + ) |
| 49 | + inp = unfold(inp) |
| 50 | + inp = inp.permute([1, 0, 2]) |
| 51 | + inp = inp.flatten(1) |
| 52 | + self.H *= self.nsamples / (self.nsamples + tmp) |
| 53 | + self.nsamples += tmp |
| 54 | + # inp = inp.float() |
| 55 | + inp = math.sqrt(2 / self.nsamples) * inp.float() |
| 56 | + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) |
| 57 | + self.H += inp.matmul(inp.t()) |
| 58 | + |
| 59 | + def fasterquant( |
| 60 | + self, blocksize=128, percdamp=.01, groupsize=-1 |
| 61 | + ): |
| 62 | + W = self.layer.weight.data.clone() |
| 63 | + if isinstance(self.layer, nn.Conv2d): |
| 64 | + W = W.flatten(1) |
| 65 | + if isinstance(self.layer, transformers.Conv1D): |
| 66 | + W = W.t() |
| 67 | + W = W.float() |
| 68 | + |
| 69 | + tick = time.time() |
| 70 | + |
| 71 | + if not self.quantizer.ready(): |
| 72 | + self.quantizer.find_params(W, weight=True) |
| 73 | + |
| 74 | + H = self.H |
| 75 | + del self.H |
| 76 | + dead = torch.diag(H) == 0 |
| 77 | + H[dead, dead] = 1 |
| 78 | + W[:, dead] = 0 |
| 79 | + |
| 80 | + Losses = torch.zeros_like(W) |
| 81 | + Q = torch.zeros_like(W) |
| 82 | + |
| 83 | + damp = percdamp * torch.mean(torch.diag(H)) |
| 84 | + diag = torch.arange(self.columns, device=self.dev) |
| 85 | + H[diag, diag] += damp |
| 86 | + H = torch.linalg.cholesky(H) |
| 87 | + H = torch.cholesky_inverse(H) |
| 88 | + H = torch.linalg.cholesky(H, upper=True) |
| 89 | + Hinv = H |
| 90 | + |
| 91 | + scale = [] |
| 92 | + zero = [] |
| 93 | + now_idx = 1 |
| 94 | + |
| 95 | + for i1 in range(0, self.columns, blocksize): |
| 96 | + i2 = min(i1 + blocksize, self.columns) |
| 97 | + count = i2 - i1 |
| 98 | + |
| 99 | + W1 = W[:, i1:i2].clone() |
| 100 | + Q1 = torch.zeros_like(W1) |
| 101 | + Err1 = torch.zeros_like(W1) |
| 102 | + Losses1 = torch.zeros_like(W1) |
| 103 | + Hinv1 = Hinv[i1:i2, i1:i2] |
| 104 | + |
| 105 | + for i in range(count): |
| 106 | + w = W1[:, i] |
| 107 | + d = Hinv1[i, i] |
| 108 | + |
| 109 | + if groupsize != -1: |
| 110 | + if (i1 + i) % groupsize == 0: |
| 111 | + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) |
| 112 | + |
| 113 | + if ((i1 + i) // groupsize) - now_idx == -1: |
| 114 | + scale.append(self.quantizer.scale) |
| 115 | + zero.append(self.quantizer.zero) |
| 116 | + now_idx += 1 |
| 117 | + |
| 118 | + q = quantize( |
| 119 | + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq |
| 120 | + ).flatten() |
| 121 | + Q1[:, i] = q |
| 122 | + Losses1[:, i] = (w - q) ** 2 / d ** 2 |
| 123 | + |
| 124 | + err1 = (w - q) / d |
| 125 | + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) |
| 126 | + Err1[:, i] = err1 |
| 127 | + |
| 128 | + Q[:, i1:i2] = Q1 |
| 129 | + Losses[:, i1:i2] = Losses1 / 2 |
| 130 | + |
| 131 | + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) |
| 132 | + |
| 133 | + if DEBUG: |
| 134 | + self.layer.weight.data[:, :i2] = Q[:, :i2] |
| 135 | + self.layer.weight.data[:, i2:] = W[:, i2:] |
| 136 | + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) |
| 137 | + print(torch.sum(Losses)) |
| 138 | + |
| 139 | + torch.cuda.synchronize() |
| 140 | + print('time %.2f' % (time.time() - tick)) |
| 141 | + print('error', torch.sum(Losses).item()) |
| 142 | + |
| 143 | + if isinstance(self.layer, transformers.Conv1D): |
| 144 | + Q = Q.t() |
| 145 | + self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) |
| 146 | + if DEBUG: |
| 147 | + print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) |
| 148 | + |
| 149 | + if scale == []: |
| 150 | + scale.append(self.quantizer.scale) |
| 151 | + zero.append(self.quantizer.zero) |
| 152 | + scale = torch.cat(scale,dim=1) |
| 153 | + zero = torch.cat(zero,dim=1) |
| 154 | + return scale,zero |
| 155 | + |
| 156 | + def free(self): |
| 157 | + if DEBUG: |
| 158 | + self.inp1 = None |
| 159 | + self.out1 = None |
| 160 | + self.H = None |
| 161 | + self.Losses = None |
| 162 | + self.Trace = None |
| 163 | + torch.cuda.empty_cache() |
0 commit comments