import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.starcoder2.modeling_starcoder2 import (
    Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
    Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Starcoder2DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStarcoder2ForCausalLM):
        # Replace the eager decoder layers with fused modules at load time.
        fuser = Starcoder2Fuser(model)
        fuser.fuse_transformer()

23
+ @staticmethod
24
+ def get_model_layers (model : OldStarcoder2ForCausalLM ):
25
+ return model .model .layers
26
+
27
+ @staticmethod
28
+ def get_act_for_scaling (module : OldStarcoder2DecoderLayer ):
29
+ return dict (
30
+ is_scalable = True ,
31
+ scale_name = "mlp.act" ,
32
+ scale_layer = module .mlp .act ,
33
+ scale_shape = module .mlp .c_fc .out_features ,
34
+ )
35
+ # return dict(is_scalable=False)
36
+
37
+ @staticmethod
38
+ def move_embed (model : OldStarcoder2ForCausalLM , device ):
39
+ model .model .embed_tokens = model .model .embed_tokens .to (device )
40
+
41
+ @staticmethod
42
+ def get_layers_for_scaling (module : OldStarcoder2DecoderLayer , input_feat , module_kwargs ):
43
+ layers = []
44
+
45
+ # attention input
46
+ layers .append (
47
+ dict (
48
+ prev_op = module .input_layernorm ,
49
+ layers = [
50
+ module .self_attn .q_proj ,
51
+ module .self_attn .k_proj ,
52
+ module .self_attn .v_proj ,
53
+ ],
54
+ inp = input_feat ["self_attn.q_proj" ],
55
+ module2inspect = module .self_attn ,
56
+ kwargs = module_kwargs ,
57
+ )
58
+ )
59
+
60
+ # attention out
61
+ if module .self_attn .v_proj .weight .shape == module .self_attn .o_proj .weight .shape :
62
+ layers .append (
63
+ dict (
64
+ prev_op = module .self_attn .v_proj ,
65
+ layers = [module .self_attn .o_proj ],
66
+ inp = input_feat ["self_attn.o_proj" ],
67
+ )
68
+ )
69
+
70
+ # linear 1
71
+ layers .append (
72
+ dict (
73
+ prev_op = module .post_attention_layernorm ,
74
+ layers = [module .mlp .c_fc ],
75
+ inp = input_feat ["mlp.c_fc" ],
76
+ module2inspect = module .mlp ,
77
+ )
78
+ )
79
+
80
+ # linear 2
81
+ layers .append (
82
+ dict (
83
+ prev_op = module .mlp .act ,
84
+ layers = [module .mlp .c_proj ],
85
+ inp = input_feat ["mlp.c_proj" ],
86
+ )
87
+ )
88
+
89
+ return layers
90
+
91
+ class Starcoder2Fuser :
92
+ def __init__ (self , model : OldStarcoder2ForCausalLM ):
93
+ self .model = model
94
+
95
+ self .starcoder2_blocks : List [Tuple [str , OldStarcoder2DecoderLayer ]] = [
96
+ (name , module )
97
+ for name , module in self .model .named_modules ()
98
+ if "Starcoder2DecoderLayer" .lower () in module .__class__ .__name__ .lower ()
99
+ ]
100
+
    def fuse_transformer(self):
        blocks = []

        module: OldStarcoder2DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # Fuse the separate q/k/v projections into a single qkv layer.
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.eps
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.eps,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        # Swap the original transformer for the fused LlamaLike implementation.
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
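

# Usage sketch (not part of this module): assuming the standard AutoAWQ entry
# points (AutoAWQForCausalLM.from_pretrained / .quantize / .save_quantized /
# .from_quantized) and a hypothetical output directory, a Starcoder2 checkpoint
# could be quantized and reloaded with the fused modules above roughly like this:
#
#   from awq import AutoAWQForCausalLM
#   from transformers import AutoTokenizer
#
#   model_path = "bigcode/starcoder2-7b"  # example checkpoint
#   quant_path = "starcoder2-7b-awq"      # hypothetical output directory
#   quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
#
#   model = AutoAWQForCausalLM.from_pretrained(model_path)
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   # quantize() consumes get_layers_for_scaling()/get_act_for_scaling() defined above
#   model.quantize(tokenizer, quant_config=quant_config)
#   model.save_quantized(quant_path)
#
#   # fuse_layers=True triggers Starcoder2Fuser.fuse_transformer() on load
#   model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)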