Fix for llama multi node (#1136)
* Fix llama

* Add for the other models
EricLBuehler authored Feb 13, 2025
1 parent 323e7cd commit c9ac321
Showing 14 changed files with 215 additions and 59 deletions.
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/gemma.rs
@@ -214,20 +214,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             bias,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             bias,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -249,7 +256,11 @@ impl Attention {
             rotary_emb,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/gemma2.rs
@@ -211,20 +211,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             bias,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             bias,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -254,7 +261,11 @@ impl Attention {
             sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                 softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
38 changes: 13 additions & 25 deletions mistralrs-core/src/models/llama.rs
@@ -3,7 +3,7 @@
 use candle_core::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module};
 use mistralrs_quant::{
-    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer, Shard,
+    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
     ShardedVarBuilder,
 };
 use serde::{Deserialize, Serialize};
@@ -193,30 +193,18 @@ impl CausalSelfAttention {
             comm,
             vb.pp("q_proj"),
         )?;
-
-        // We may need to replicate the kv heads
-        let kv_replicate = if comm.world_size() > cfg.num_key_value_heads {
-            comm.world_size() / cfg.num_key_value_heads
-        } else {
-            1
-        };
-
-        let kv_shard_id = comm.rank() / kv_replicate;
-        // let kv_block_size = size_kv / comm.world_size();
-        let kv_block_size = cfg.hidden_size / cfg.num_attention_heads;
-        let shard = Shard::Offset {
-            dim: 0,
-            offset: kv_shard_id * kv_block_size,
-            len: kv_block_size,
-        };
-
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
         let k_proj = ColumnParallelLayer::new_with_shard(
             size_in,
             size_kv,
             &cfg.quantization_config,
             false,
             comm,
-            shard,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
         let v_proj = ColumnParallelLayer::new_with_shard(
@@ -225,7 +213,7 @@ impl CausalSelfAttention {
             &cfg.quantization_config,
             false,
             comm,
-            shard,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -248,11 +236,11 @@ impl CausalSelfAttention {
             max_seq_len: cfg.max_position_embeddings,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: if kv_replicate != 1 {
-                    (cfg.num_attention_heads / cfg.num_key_value_heads) / kv_replicate
-                } else {
-                    cfg.num_attention_heads / cfg.num_key_value_heads
-                },
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
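
For reference, the sharding logic that `mistralrs_quant::compute_kv_shard` and `mistralrs_quant::compute_n_kv_groups` now centralize is the inline code removed from llama.rs in the hunks above. Below is a minimal standalone sketch of that removed logic, with `world_size` and `rank` standing in for the `comm` handle; the actual helper implementations and signatures in `mistralrs_quant` may differ.

```rust
// Sketch only: mirrors the inline code removed from llama.rs above,
// not the mistralrs_quant implementation itself.

/// How many ranks share (replicate) each KV head when the tensor-parallel
/// world size exceeds the number of KV heads.
fn kv_replicate(world_size: usize, num_kv_heads: usize) -> usize {
    if world_size > num_kv_heads {
        world_size / num_kv_heads
    } else {
        1
    }
}

/// The (offset, len) slice of the K/V projection weight along dim 0 that a
/// given rank loads; `kv_block_size` is `hidden_size / num_attention_heads`.
fn kv_shard_offset(
    rank: usize,
    world_size: usize,
    num_kv_heads: usize,
    kv_block_size: usize,
) -> (usize, usize) {
    let kv_shard_id = rank / kv_replicate(world_size, num_kv_heads);
    (kv_shard_id * kv_block_size, kv_block_size)
}

/// Query heads per local KV head (the GQA group size) after replication.
fn n_kv_groups(world_size: usize, num_attention_heads: usize, num_kv_heads: usize) -> usize {
    let replicate = kv_replicate(world_size, num_kv_heads);
    if replicate != 1 {
        (num_attention_heads / num_kv_heads) / replicate
    } else {
        num_attention_heads / num_kv_heads
    }
}
```

For example, with a world size of 8, 32 attention heads, and 8 KV heads there is no replication and each rank keeps 32 / 8 = 4 KV groups; with only 4 KV heads, each KV head is shared by 2 ranks (ranks 0 and 1 both load the slice for KV head 0) and the per-rank group count drops from 8 to 4.
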
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/mistral.rs
@@ -196,20 +196,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             false,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             false,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -232,7 +239,11 @@ impl Attention {
             sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/mixtral.rs
@@ -85,20 +85,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             false,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             false,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -121,7 +128,11 @@ impl Attention {
             sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/phi2.rs
@@ -192,20 +192,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_attention_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             cfg.hidden_size,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             true,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             cfg.hidden_size,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             true,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let dense = RowParallelLayer::new(
@@ -236,7 +243,11 @@ impl Attention {
             head_dim,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_attention_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/phi3_5_moe.rs
@@ -110,20 +110,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             cfg.hidden_size,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             cfg.attention_bias,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             cfg.hidden_size,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             cfg.attention_bias,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -147,7 +154,11 @@ impl Attention {
             sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
17 changes: 14 additions & 3 deletions mistralrs-core/src/models/qwen2.rs
@@ -191,20 +191,27 @@ impl Attention {
             comm,
             vb.pp("q_proj"),
         )?;
-        let k_proj = ColumnParallelLayer::new(
+        let kv_shard = mistralrs_quant::compute_kv_shard(
+            cfg.num_key_value_heads,
+            cfg.hidden_size / cfg.num_attention_heads,
+            comm,
+        );
+        let k_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             true,
             comm,
+            kv_shard,
             vb.pp("k_proj"),
         )?;
-        let v_proj = ColumnParallelLayer::new(
+        let v_proj = ColumnParallelLayer::new_with_shard(
             hidden_sz,
             num_kv_heads * head_dim,
             &cfg.quantization_config,
             true,
             comm,
+            kv_shard,
             vb.pp("v_proj"),
         )?;
         let o_proj = RowParallelLayer::new(
@@ -226,7 +233,11 @@ impl Attention {
             rotary_emb,
             paged_attn,
             sdpa_params: SdpaParams {
-                n_kv_groups: num_heads / num_kv_heads,
+                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
+                    cfg.num_key_value_heads,
+                    cfg.num_attention_heads,
+                    comm,
+                ),
                 use_flash_attn: cfg.use_flash_attn,
                 softcap: None,
                 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
(Diffs for the remaining 6 changed files are not shown.)