Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixtral: Mixture of Experts quantization #251

Merged
merged 21 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"gptj": GPTJAWQForCausalLM,
"gpt_bigcode": GptBigCodeAWQForCausalLM,
"mistral": MistralAWQForCausalLM,
"mixtral": MixtralAWQForCausalLM,
"gpt_neox": GPTNeoXAWQForCausalLM,
"aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM,
Expand Down
12 changes: 9 additions & 3 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
from transformers import (
AutoModelForCausalLM,
AutoConfig,
Expand All @@ -24,7 +28,6 @@
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from accelerate.utils import get_balanced_memory

class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config):
Expand Down Expand Up @@ -176,7 +179,7 @@ def _load_config(self, model_path, model_filename, safetensors=True,
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
if safetensors:
ignore_patterns.extend(["*.pt*", "*.bin*"])
ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
else:
ignore_patterns.append("*.safetensors*")

Expand Down Expand Up @@ -215,6 +218,9 @@ def _load_quantized_modules(self, model, quant_config, version):
# Get every linear layer in a block
named_linears = get_named_linears(layer)

# Filter out the linear layers we don't want to exclude
named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)

# Replace activation functions
self._scale_activations(self, layer)

Expand Down
137 changes: 137 additions & 0 deletions awq/models/mixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
# TODO: Fix perplexity on fusing Mixtral
#fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)

@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
layers = []

# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))

# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))

# linear in
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))

# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
))

return layers


class MixtralFuser:
def __init__(self, model: OldMixtralForCausalLM):
self.model = model

self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
]

def fuse_transformer(self):
blocks = []

module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
)
# Adapt to mixture of experts
for i in range(len(module.block_sparse_moe.experts)):
mlp = QuantFusedMLP(
gate_proj=module.block_sparse_moe.experts[i].w1,
down_proj=module.block_sparse_moe.experts[i].w2,
up_proj=module.block_sparse_moe.experts[i].w3
)
module.block_sparse_moe.experts[i] = mlp
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
)
blocks.append(MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=module.block_sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))

self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)

34 changes: 34 additions & 0 deletions awq/modules/fused/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,40 @@
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MixtralBlock(nn.Module):
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
moe, norm_1, norm_2, dev, max_seq_len
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev

def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
)

h = hidden_states.to(attn_output.device) + attn_output
out, _ = self.moe.forward(self.norm_2(h))
out = h + out

return out, None, past_key_value

class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
Expand Down
5 changes: 4 additions & 1 deletion awq/modules/fused/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(

self.activation = activation

def forward(self, x):
def forward(self, x, routing_weights=None):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = self.linear(
Expand All @@ -57,6 +57,9 @@ def forward(self, x):
x = x.reshape(out_shape)
x = self.down_proj(x)

if routing_weights is not None:
x = routing_weights * x

return x


Expand Down
59 changes: 57 additions & 2 deletions awq/modules/fused/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,63 @@
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock


class MixtralModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0

@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape

fused_utils.prepare_cache(self.blocks, seqlen)

h = self.embedding(input_ids)

mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)

for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)

h = self.norm(h)

return MoeModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
router_logits=(),
)


class LlamaLikeModel(nn.Module):
"""
Expand Down
22 changes: 13 additions & 9 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize
)


class AwqQuantizer:
Expand Down Expand Up @@ -69,13 +75,6 @@ def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: to

return w

def _exclude_layers_to_not_quantize(self, linear_layers):
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in self.modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers

def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
Expand All @@ -90,7 +89,7 @@ def quantize(self):
named_linears = get_named_linears(self.modules[i])

# Filter out the linear layers we don't want to exclude
named_linears = self._exclude_layers_to_not_quantize(named_linears)
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)

input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
Expand Down Expand Up @@ -384,6 +383,11 @@ def cache_input_hook(m, x, y, name, feat_dict):

input_feat = defaultdict(list)
handles = []

# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}

for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
Expand Down
Loading