GPU-accelerated sampling (+5% decode perf) (#1094)
* GPU-accelerates sampling

* Clippy
EricLBuehler authored Jan 24, 2025
1 parent e020295, commit 5d0d281
Showing 1 changed file with 31 additions and 36 deletions.
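
The heart of the change sits in the two filtered samplers: previously each built a Vec<usize> of token indices and sorted it on the host with sort_unstable_by; they now ask candle to argsort the tensor on its own device (arg_sort_last_dim(false) gives descending order) and copy back only the u32 indices, which is presumably where the quoted +5% decode perf comes from. Below is a minimal sketch of the two strategies, assuming candle_core as used by mistral.rs; the free-function names are illustrative and not code from this commit.

// Sketch only: host-side vs device-side argsort. Helper names are made up;
// only arg_sort_last_dim and to_vec1 appear in the diff itself.
use candle_core::{Device, Result, Tensor};

// Old approach: copy the probabilities to the host and sort an index vector on the CPU.
fn argsort_descending_cpu(probs: &[f32]) -> Vec<u32> {
    let mut indices: Vec<u32> = (0..probs.len() as u32).collect();
    indices.sort_unstable_by(|&i, &j| {
        probs[j as usize]
            .partial_cmp(&probs[i as usize])
            .expect("No ordering.")
    });
    indices
}

// New approach: sort the last dimension on the tensor's device, then copy back only u32 indices.
fn argsort_descending_gpu(logits: &Tensor) -> Result<Vec<u32>> {
    logits.arg_sort_last_dim(false)?.to_vec1::<u32>()
}

fn main() -> Result<()> {
    // A CPU device keeps the demo self-contained; on CUDA/Metal the sort stays on the GPU.
    let logits = Tensor::new(&[0.1f32, 2.0, -1.0, 0.7], &Device::Cpu)?;
    let probs: Vec<f32> = logits.to_vec1()?;
    assert_eq!(argsort_descending_cpu(&probs), argsort_descending_gpu(&logits)?);
    Ok(())
}
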
mistralrs-core/src/sampler.rs: 31 additions & 36 deletions
@@ -256,21 +256,20 @@ impl Sampler {
         })
     }
 
-    fn get_top_logprobs(
-        &self,
-        probs: &[f32],
-        argsort_indices: &[usize],
-    ) -> Result<Vec<TopLogprob>> {
+    fn get_top_logprobs(&self, probs: &[f32], argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
         let mut argsort_indices_sorted = argsort_indices.to_vec();
         // Sort by descending prob
-        argsort_indices_sorted
-            .sort_by(|a, b| probs[*b].partial_cmp(&probs[*a]).expect("No ordering."));
+        argsort_indices_sorted.sort_by(|a, b| {
+            probs[*b as usize]
+                .partial_cmp(&probs[*a as usize])
+                .expect("No ordering.")
+        });
         // These are where the top n are
         let top_n_toks_range = 0..self.top_n_logprobs;
         // The top n's values
         let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()]
             .iter()
-            .map(|x| probs[*x].log(10.0))
+            .map(|x| probs[*x as usize].log(10.0))
             .collect::<Vec<_>>();
         // Find where they actually are in the logits
         let mut top_n_toks = Vec::new();
@@ -283,22 +282,22 @@ impl Sampler {
             for tok in &top_n_toks {
                 bytes.push(
                     tokenizer
-                        .decode(&[*tok as u32], false)
+                        .decode(&[{ *tok }], false)
                         .map_err(|x| Error::Msg(x.to_string()))?,
                 );
             }
 
             Ok(zip(bytes, zip(top_n_toks, top_n_logprobs))
                 .map(|(bytes, (token, logprob))| TopLogprob {
-                    token: token as u32,
+                    token,
                     logprob,
                     bytes: Some(bytes),
                 })
                 .collect::<Vec<_>>())
         } else {
             Ok(zip(top_n_toks, top_n_logprobs)
                 .map(|(token, logprob)| TopLogprob {
-                    token: token as u32,
+                    token,
                     logprob,
                     bytes: None,
                 })
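
A side effect visible in this hunk: because the argsort indices are now u32 end to end, the token values dropped into TopLogprob no longer need an as-u32 cast, and neither does the decode call. A tiny sketch, with the struct shape assumed from the fields visible in this diff rather than copied from mistralrs-core:

// Assumed shape, for illustration only; see mistralrs-core for the real TopLogprob.
struct TopLogprob {
    token: u32,
    logprob: f32,
    bytes: Option<String>,
}

fn make_top_logprob(token: u32, logprob: f32, bytes: Option<String>) -> TopLogprob {
    // token already has the field's type, so field-init shorthand works with no cast.
    TopLogprob { token, logprob, bytes }
}
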
@@ -311,7 +310,7 @@ impl Sampler {
 
         let probs: Vec<f32> = logits.to_vec1()?;
 
-        let argsort_indices = (0..probs.len()).collect::<Vec<_>>();
+        let argsort_indices = (0..probs.len() as u32).collect::<Vec<_>>();
         let logprob = probs[next_token as usize].log(10.0);
 
         let top_logprobs = if return_logprobs {
@@ -347,17 +346,13 @@ impl Sampler {
         min_p: f32,
     ) -> Result<Logprobs> {
         let mut probs: Vec<f32> = logits.to_vec1()?;
-        let mut argsort_indices = (0..probs.len()).collect::<Vec<_>>();
-
-        // Sort by descending probability.
-        argsort_indices
-            .sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).expect("No ordering."));
+        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
 
         if top_k > 0 {
             // Clamp smaller probabilities to zero.
             for (index, val) in argsort_indices.iter().enumerate() {
                 if index >= top_k as usize {
-                    probs[*val] = 0.0;
+                    probs[*val as usize] = 0.0;
                 }
             }
         }
@@ -372,13 +367,13 @@ impl Sampler {
         let mut cumsum = 0.;
         for index in &argsort_indices {
             if cumsum >= top_p {
-                probs[*index] = 0.0;
+                probs[*index as usize] = 0.0;
             } else {
-                cumsum += probs[*index];
+                cumsum += probs[*index as usize];
             }
         }
 
-        let max_p = probs[argsort_indices[0]];
+        let max_p = probs[argsort_indices[0] as usize];
 
         // MIN P
 
@@ -387,8 +382,8 @@ impl Sampler {
 
         // Clamp smaller probabilities to zero.
         for index in &argsort_indices {
-            if max_p * min_p >= probs[*index] {
-                probs[*index] = 0.0;
+            if max_p * min_p >= probs[*index as usize] {
+                probs[*index as usize] = 0.0;
             }
         }
 
@@ -425,7 +420,7 @@ impl Sampler {
     fn sample_multinomial(
         &self,
         probs: &mut Vec<f32>,
-        argsort_indices: Vec<usize>,
+        argsort_indices: Vec<u32>,
         return_logprobs: bool,
         rng: Arc<Mutex<Isaac64Rng>>,
     ) -> Result<Logprobs> {
@@ -459,25 +454,24 @@ impl Sampler {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
     fn sample_top_kp_min_p(
         &self,
         probs: &mut Vec<f32>,
+        logits: &Tensor,
         top_k: i64,
         top_p: f32,
         min_p: f32,
         return_logprobs: bool,
         rng: Arc<Mutex<Isaac64Rng>>,
     ) -> Result<Logprobs> {
-        let mut argsort_indices = (0..probs.len()).collect::<Vec<_>>();
-        // Sort by descending probability.
-        argsort_indices
-            .sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).expect("No ordering."));
+        let argsort_indices: Vec<u32> = logits.arg_sort_last_dim(false)?.to_vec1()?;
 
         if top_k > 0 {
             // Clamp smaller probabilities to zero.
             for (index, val) in argsort_indices.iter().enumerate() {
                 if index >= top_k as usize {
-                    probs[*val] = 0.0;
+                    probs[*val as usize] = 0.0;
                 }
             }
         }
@@ -496,17 +490,17 @@ impl Sampler {
         let mut cumsum = 0.;
         for index in &argsort_indices {
             if cumsum >= top_p {
-                probs[*index] = 0.0;
+                probs[*index as usize] = 0.0;
             } else {
-                cumsum += probs[*index];
+                cumsum += probs[*index as usize];
             }
         }
 
         if min_p <= 0.0 || min_p >= 1.0 {
             return self.sample_multinomial(probs, argsort_indices, return_logprobs, rng);
         }
 
-        let max_p = probs[argsort_indices[0]];
+        let max_p = probs[argsort_indices[0] as usize];
 
         // MIN P
 
@@ -515,8 +509,8 @@ impl Sampler {
 
         // Clamp smaller probabilities to zero.
         for index in &argsort_indices {
-            if max_p * min_p >= probs[*index] {
-                probs[*index] = 0.0;
+            if max_p * min_p >= probs[*index as usize] {
+                probs[*index as usize] = 0.0;
             }
         }
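
Both filtered samplers walk this same descending index list: ranks at or past top_k are zeroed, the cumulative pass zeroes everything after the nucleus reaches top_p, and min-p drops tokens whose probability falls below min_p times the top probability. A condensed sketch of that flow, written as a free function for illustration (the real logic lives in the Sampler methods shown in this diff):

// Sketch only: top-k / top-p / min-p filtering over descending argsort indices.
// indices[0] is the most probable token id.
fn apply_top_k_top_p_min_p(probs: &mut [f32], indices: &[u32], top_k: i64, top_p: f32, min_p: f32) {
    // TOP K: zero every token ranked at or beyond top_k.
    if top_k > 0 {
        for (rank, &tok) in indices.iter().enumerate() {
            if rank >= top_k as usize {
                probs[tok as usize] = 0.0;
            }
        }
    }
    // TOP P: keep the smallest prefix whose cumulative probability reaches top_p.
    let mut cumsum = 0.0;
    for &tok in indices {
        if cumsum >= top_p {
            probs[tok as usize] = 0.0;
        } else {
            cumsum += probs[tok as usize];
        }
    }
    // MIN P: drop tokens whose probability is below min_p times the maximum.
    let max_p = probs[indices[0] as usize];
    for &tok in indices {
        if max_p * min_p >= probs[tok as usize] {
            probs[tok as usize] = 0.0;
        }
    }
}

fn main() {
    // Toy example with four tokens; indices are sorted by descending probability.
    let mut probs = vec![0.05, 0.6, 0.25, 0.1];
    let indices = vec![1u32, 2, 3, 0];
    apply_top_k_top_p_min_p(&mut probs, &indices, 3, 0.8, 0.1);
    assert_eq!(probs, vec![0.0, 0.6, 0.25, 0.0]);
}
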

@@ -677,11 +671,12 @@ impl Sampler {
             None => self.sample_argmax(logits, return_logprobs)?,
             Some(temperature) => {
                 let logits = (&logits / temperature)?;
-                let probs = candle_nn::ops::softmax_last_dim(&logits)?;
-                let mut probs: Vec<f32> = probs.to_vec1()?;
+                let logits = candle_nn::ops::softmax_last_dim(&logits)?;
+                let mut probs: Vec<f32> = logits.to_vec1()?;
 
                 self.sample_top_kp_min_p(
                     &mut probs,
+                    &logits,
                     self.top_k,
                     self.top_p as f32,
                     self.min_p as f32,
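
The call-site hunk above ties it together: the tensor produced by softmax_last_dim is kept alive and handed to sample_top_kp_min_p alongside the host-side Vec<f32>, so the argsort can run on the same device as the logits while the filtering and the final multinomial draw still happen on the host. A condensed, hedged view of that preparation step (free-function form and name are illustrative, not the method from this file):

use candle_core::{Result, Tensor};

// Sketch only: temperature scaling and softmax on the device, returning both the
// host-side probabilities and the device-computed descending argsort indices.
fn prepare_for_sampling(logits: &Tensor, temperature: f64) -> Result<(Vec<f32>, Vec<u32>)> {
    let scaled = (logits / temperature)?;
    let probs_t = candle_nn::ops::softmax_last_dim(&scaled)?;
    let probs: Vec<f32> = probs_t.to_vec1()?;
    let argsort_indices: Vec<u32> = probs_t.arg_sort_last_dim(false)?.to_vec1()?;
    Ok((probs, argsort_indices))
}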