From aa017cc7af827ded0f1723cec065bfe20834ae34 Mon Sep 17 00:00:00 2001
From: charrli
Date: Fri, 22 Mar 2024 17:07:30 +0800
Subject: [PATCH 1/2] add starcoder2

---
Notes (ignored by `git am`; not part of the commit message):
- Adds awq/models/starcoder2.py (Starcoder2AWQForCausalLM plus a
  Starcoder2Fuser helper) and registers the "starcoder2" model type in
  awq/models/__init__.py, awq/models/auto.py and awq/models/base.py.
- Starcoder2Fuser.__init__ stores its layer list as `qwen2_blocks`; nothing
  in this patch reads that attribute. Patch 2/2 renames it.
- NOTE(review): the fused path rebuilds both layer norms as
  FasterTransformerRMSNorm and reads `.variance_epsilon`; patch 2/2 switches
  to `.eps`. Confirm against transformers' Starcoder2DecoderLayer that an
  RMSNorm kernel is numerically valid for these norm modules.
- awq/models/__init__.py and awq/models/starcoder2.py both end without a
  trailing newline after this patch.

 awq/models/__init__.py   |   1 +
 awq/models/auto.py       |   1 +
 awq/models/base.py       |   1 +
 awq/models/starcoder2.py | 140 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 143 insertions(+)
 create mode 100644 awq/models/starcoder2.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 75542fe4..b2496170 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -15,3 +15,4 @@
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .starcoder2 import Starcoder2AWQForCausalLM
\ No newline at end of file
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 1ac6342a..d1f7ba66 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -24,6 +24,7 @@
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "starcoder2": Starcoder2AWQForCausalLM,
 }
 
 
diff --git a/awq/models/base.py b/awq/models/base.py
index e5691ae0..06d0fece 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -68,6 +68,7 @@
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "starcoder2": "AutoModelForCausalLM",
 }
 
 
diff --git a/awq/models/starcoder2.py b/awq/models/starcoder2.py
new file mode 100644
index 00000000..55e102ef
--- /dev/null
+++ b/awq/models/starcoder2.py
@@ -0,0 +1,140 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.starcoder2.modeling_starcoder2 import (
+    Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
+    Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "Starcoder2DecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldStarcoder2ForCausalLM):
+        fuser = Starcoder2Fuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldStarcoder2ForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
+        return dict(
+            is_scalable=True,
+            scale_name="mlp.act",
+            scale_layer=module.mlp.act,
+            scale_shape=module.mlp.c_fc.out_features,
+        )
+
+    @staticmethod
+    def move_embed(model: OldStarcoder2ForCausalLM, device):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.c_fc],
+                inp=input_feat["mlp.c_fc"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.act,
+                layers=[module.mlp.c_proj],
+                inp=input_feat["mlp.c_proj"],
+            )
+        )
+
+        return layers
+
+class Starcoder2Fuser:
+    def __init__(self, model: OldStarcoder2ForCausalLM):
+        self.model = model
+
+        self.qwen2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldStarcoder2DecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.variance_epsilon,
+            )
+            blocks.append(
+                LlamaLikeBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                )
+            )
+
+        self.model.model = LlamaLikeModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file

From 71047cd35c59c8a3b31f56af3cf7c5ff6f3b6a76 Mon Sep 17 00:00:00 2001
From: charrli
Date: Sat, 23 Mar 2024 01:11:55 +0800
Subject: [PATCH 2/2] add starcoder2 support

---
Notes (ignored by `git am`; not part of the commit message):
- Renames Starcoder2Fuser.qwen2_blocks to starcoder2_blocks.
- Reads the layer-norm epsilon from `.eps` instead of `.variance_epsilon`
  when building norm_1 / norm_2 in fuse_transformer.
- Adds a commented-out `return dict(is_scalable=False)` after the live
  return in get_act_for_scaling; it is a comment only and changes no
  behavior.

 awq/models/starcoder2.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/awq/models/starcoder2.py b/awq/models/starcoder2.py
index 55e102ef..2e493514 100644
--- a/awq/models/starcoder2.py
+++ b/awq/models/starcoder2.py
@@ -32,6 +32,7 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
             scale_layer=module.mlp.act,
             scale_shape=module.mlp.c_fc.out_features,
         )
+        # return dict(is_scalable=False)
 
     @staticmethod
     def move_embed(model: OldStarcoder2ForCausalLM, device):
@@ -91,7 +92,7 @@ class Starcoder2Fuser:
     def __init__(self, model: OldStarcoder2ForCausalLM):
         self.model = model
 
-        self.qwen2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
+        self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
             (name, module)
             for name, module in self.model.named_modules()
             if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
@@ -110,11 +111,11 @@ def fuse_transformer(self):
                 module.self_attn.v_proj,
             )
             norm_1 = FasterTransformerRMSNorm(
-                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
+                module.input_layernorm.weight, module.input_layernorm.eps
             )
             norm_2 = FasterTransformerRMSNorm(
                 module.post_attention_layernorm.weight,
-                module.post_attention_layernorm.variance_epsilon,
+                module.post_attention_layernorm.eps,
             )
             blocks.append(
                 LlamaLikeBlock(