
Commit 33dfb04

shaonianyr and charrli authored
add starcoder2 support (#406)
Co-authored-by: charrli <[email protected]>
1 parent eb85f67 commit 33dfb04

File tree

4 files changed: +144, -0 lines changed
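
With these four changes in place, StarCoder2 checkpoints can go through the usual AutoAWQ quantization flow. A minimal sketch of that flow, assuming the standard AutoAWQ API; the model path, output directory, and quant_config values are illustrative and not part of this commit:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "bigcode/starcoder2-7b"   # illustrative StarCoder2 checkpoint
    quant_path = "starcoder2-7b-awq"       # illustrative output directory
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # Load the FP16 model and tokenizer, quantize with AWQ, then save the result.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)
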

awq/models/__init__.py (+1)

@@ -15,3 +15,4 @@
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .starcoder2 import Starcoder2AWQForCausalLM

awq/models/auto.py (+1)

@@ -24,6 +24,7 @@
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "starcoder2": Starcoder2AWQForCausalLM,
 }
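
For context, the entry added to auto.py above is assumed to back a simple model_type lookup: AutoAWQ reads config.model_type from the checkpoint and dispatches to the matching AWQ class. A sketch of that dispatch; the helper and the stand-in MODEL_MAP name are illustrative, only the "starcoder2" entry itself comes from this diff:

    from transformers import AutoConfig
    from awq.models.starcoder2 import Starcoder2AWQForCausalLM

    # Stand-in for the mapping edited above; the real dict lives in awq/models/auto.py.
    MODEL_MAP = {"starcoder2": Starcoder2AWQForCausalLM}

    def resolve_awq_class(model_path: str):
        config = AutoConfig.from_pretrained(model_path)
        return MODEL_MAP[config.model_type]  # Starcoder2Config.model_type == "starcoder2"
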

awq/models/base.py (+1)

@@ -68,6 +68,7 @@
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "starcoder2": "AutoModelForCausalLM",
 }

awq/models/starcoder2.py (new file, +141)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.starcoder2.modeling_starcoder2 import (
    Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
    Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Starcoder2DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStarcoder2ForCausalLM):
        fuser = Starcoder2Fuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldStarcoder2ForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.c_fc.out_features,
        )
        # return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldStarcoder2ForCausalLM, device):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.c_fc],
                inp=input_feat["mlp.c_fc"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.act,
                layers=[module.mlp.c_proj],
                inp=input_feat["mlp.c_proj"],
            )
        )

        return layers


class Starcoder2Fuser:
    def __init__(self, model: OldStarcoder2ForCausalLM):
        self.model = model

        self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldStarcoder2DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.eps
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.eps,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
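
At inference time, Starcoder2AWQForCausalLM.fuse_layers (via the Starcoder2Fuser above) replaces the decoder layers with fused LlamaLikeBlock modules when layer fusion is requested on load. A usage sketch, assuming the standard AutoAWQ from_quantized API and the illustrative quantized path from the earlier example:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "starcoder2-7b-awq"  # illustrative path from the quantization sketch above

    # fuse_layers=True is what ends up calling Starcoder2AWQForCausalLM.fuse_layers().
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path)

    tokens = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.cuda()
    output = model.generate(tokens, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
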
