Commit

Fix chat sampling response (#1154)
EricLBuehler authored Feb 19, 2025
1 parent 71650a4 commit 8d89c14
Showing 1 changed file with 114 additions and 114 deletions.
228 changes: 114 additions & 114 deletions mistralrs-core/src/pipeline/sampling.rs
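
As shown in the hunk below, the change rewrites the finish path of finish_or_add_toks_to_seq. When a sequence reports a stop reason it is marked Done; per-token logprobs are decoded back to text (requiring the pipeline to expose a tokenizer); the completion text is recovered according to the stop reason; a chat Choice (with tool calls parsed from the text) or a plain CompletionChoice is added to the sequence group; the prefix cacher is updated; and the finished chat or text-completion response is sent to the requester.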
@@ -125,132 +125,132 @@ pub(crate) async fn finish_or_add_toks_to_seq(
                this.reset_non_granular_state();
            }
        }
    } else if let Some(reason) = is_done {
        /*
        ***********************
        Finish the sequence now
        ***********************
        */
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage => {
                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                }
            };

            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls) = parse_text_tools(text.as_str(), seq.tools.clone())
                    .map_err(candle_core::Error::msg)?;
                let choice = crate::Choice {
                    finish_reason: reason.to_string(),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls,
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: reason.to_string(),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_to_cpu()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
            this.reset_non_granular_state();
        }
    }

    Ok(())
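
To illustrate the stop-string handling above: a minimal, self-contained sketch (not part of the commit) of how completion_bytes_pos truncates the completion buffer before the matched stop string. The buffer contents, offset, and helper name are hypothetical.

    // Sketch of the truncation in the StopReason::StopString arm above.
    // `completion_bytes_pos` is the byte offset where the stop string began.
    fn finished_text(completion_bytes: &[u8], completion_bytes_pos: Option<usize>) -> String {
        let txt = String::from_utf8_lossy(completion_bytes);
        match completion_bytes_pos {
            // Stop string matched: keep everything before it.
            Some(pos) => txt[..pos].trim_start().to_string(),
            // Length / EOS / stop-token / canceled: keep the whole buffer.
            None => txt.trim_start().to_string(),
        }
    }

    fn main() {
        let bytes = b" The answer is 4.</s>";
        // Pretend the stop string "</s>" was matched at byte offset 17.
        assert_eq!(finished_text(bytes, Some(17)), "The answer is 4.");
        assert_eq!(finished_text(bytes, None), "The answer is 4.</s>");
    }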

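The tokenizer lookup inside the logprob loop uses a standard Option-to-Result conversion: ok_or turns the optional tokenizer into a Result, and ? propagates the "no tokenizer" error before decode runs. A minimal sketch of that pattern, assuming the tokenizers and candle-core crates; the helper name and error message are hypothetical.

    use candle_core::Result;

    // Hypothetical helper mirroring the diff's error handling for decoding.
    fn decode_token(tokenizer: Option<&tokenizers::Tokenizer>, token: u32) -> Result<String> {
        tokenizer
            .ok_or(candle_core::Error::Msg(
                "this pipeline requires a tokenizer".to_string(),
            ))?
            // `false`: keep special tokens, matching the decode call in the diff.
            .decode(&[token], false)
            .map_err(|e| candle_core::Error::Msg(e.to_string()))
    }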

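The logprob collection loop could equally be written with iterator adapters. A simplified, hypothetical mirror under stand-in types (TokLogprob and RespLogprob are not the crate's real structs): collecting Result items into Result<Vec<_>> stops at the first decode error, whereas the diff instead routes decode errors to the responder via handle_seq_error_ok!.

    use candle_core::Result;

    // Hypothetical stand-ins for the crate's internal logprob types.
    struct TokLogprob {
        token: u32,
        logprob: f32,
    }

    struct RespLogprob {
        token: String,
        logprob: f32,
    }

    fn collect_logprobs(
        tokenizer: &tokenizers::Tokenizer,
        logprobs: &[TokLogprob],
    ) -> Result<Vec<RespLogprob>> {
        logprobs
            .iter()
            .map(|lp| {
                Ok(RespLogprob {
                    // Decode each sampled token id back to text.
                    token: tokenizer
                        .decode(&[lp.token], false)
                        .map_err(|e| candle_core::Error::Msg(e.to_string()))?,
                    logprob: lp.logprob,
                })
            })
            .collect()
    }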