diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs
index 5d910740d..0879273f3 100644
--- a/mistralrs-core/src/pipeline/sampling.rs
+++ b/mistralrs-core/src/pipeline/sampling.rs
@@ -125,132 +125,132 @@ pub(crate) async fn finish_or_add_toks_to_seq(
                     this.reset_non_granular_state();
                 }
             }
-        } else if let Some(reason) = is_done {
-            /*
-            ***********************
-            Finish the sequence now
-            ***********************
-            */
-            {
-                seq.set_state(crate::sequence::SequenceState::Done(reason));
-                let (tokenizer, pipeline_name) = {
-                    let pipeline_name = this.name();
-                    let tokenizer = this.tokenizer();
-                    (tokenizer, pipeline_name)
-                };
+        }
+    } else if let Some(reason) = is_done {
+        /*
+        ***********************
+        Finish the sequence now
+        ***********************
+        */
+        {
+            seq.set_state(crate::sequence::SequenceState::Done(reason));
+            let (tokenizer, pipeline_name) = {
+                let pipeline_name = this.name();
+                let tokenizer = this.tokenizer();
+                (tokenizer, pipeline_name)
+            };

-                let logprobs = if seq.return_logprobs() {
-                    let mut logprobs = Vec::new();
-                    for logprob in seq.logprobs() {
-                        let resp_logprob = crate::ResponseLogprob {
+            let logprobs = if seq.return_logprobs() {
+                let mut logprobs = Vec::new();
+                for logprob in seq.logprobs() {
+                    let resp_logprob = crate::ResponseLogprob {
                         token: crate::handle_seq_error_ok!(
-                                tokenizer
-                                    .as_ref()
-                                    .ok_or(candle_core::Error::Msg(
-                                        "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
-                                            .to_string(),
-                                    ))?.decode(&[logprob.token], false),
-                                seq.responder()
-                            ),
+                            tokenizer
+                                .as_ref()
+                                .ok_or(candle_core::Error::Msg(
+                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
+                                        .to_string(),
+                                ))?.decode(&[logprob.token], false),
+                            seq.responder()
+                        ),
                         bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                         logprob: logprob.logprob,
                         top_logprobs: logprob.top_logprobs.clone().unwrap(),
                     };
-                        logprobs.push(resp_logprob);
-                    }
-                    Some(logprobs)
-                } else {
-                    None
-                };
-
-                let text = match reason {
-                    crate::sequence::StopReason::Length(_)
-                    | crate::sequence::StopReason::ModelLength(_)
-                    | crate::sequence::StopReason::Eos
-                    | crate::sequence::StopReason::StopTok(_)
-                    | crate::sequence::StopReason::Canceled => {
-                        String::from_utf8_lossy(seq.completion_bytes())
-                            .trim_start()
-                            .to_string()
-                    }
-                    crate::sequence::StopReason::StopString {
-                        completion_bytes_pos,
-                        ..
-                    } => {
-                        let txt = String::from_utf8_lossy(seq.completion_bytes());
-                        txt[..completion_bytes_pos].trim_start().to_string()
-                    }
-                    crate::sequence::StopReason::GeneratedImage => {
-                        candle_core::bail!("Stop reason was `GeneratedImage`.")
-                    }
-                };
-
-                if seq.get_mut_group().is_chat {
-                    let (text_new, tool_calls) = parse_text_tools(text.as_str(), seq.tools.clone())
-                        .map_err(candle_core::Error::msg)?;
-                    let choice = crate::Choice {
-                        finish_reason: reason.to_string(),
-                        index: seq.get_response_index(),
-                        message: crate::ResponseMessage {
-                            content: text_new.map(ToString::to_string),
-                            role: "assistant".to_string(),
-                            tool_calls,
-                        },
-                        logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
-                    };
-                    seq.add_choice_to_group(choice);
-                } else {
-                    let choice = crate::CompletionChoice {
-                        finish_reason: reason.to_string(),
-                        index: seq.get_response_index(),
-                        text,
-                        logprobs: None,
-                    };
-                    seq.add_completion_choice_to_group(choice);
+                    logprobs.push(resp_logprob);
                 }
+                Some(logprobs)
+            } else {
+                None
+            };

-                if use_prefix_cacher {
-                    prefix_cacher.add_sequence(seq);
-                    prefix_cacher.evict_to_cpu()?;
+            let text = match reason {
+                crate::sequence::StopReason::Length(_)
+                | crate::sequence::StopReason::ModelLength(_)
+                | crate::sequence::StopReason::Eos
+                | crate::sequence::StopReason::StopTok(_)
+                | crate::sequence::StopReason::Canceled => {
+                    String::from_utf8_lossy(seq.completion_bytes())
+                        .trim_start()
+                        .to_string()
                 }
-
-                let group = seq.get_mut_group();
-                if group.is_chat {
-                    group
-                        .maybe_send_chat_done_response(
-                            crate::ChatCompletionResponse {
-                                id: seq.id().to_string(),
-                                choices: group.get_choices().to_vec(),
-                                created: seq.creation_time(),
-                                model: pipeline_name,
-                                system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
-                                object: "chat.completion".to_string(),
-                                usage: group.get_usage(),
-                            },
-                            seq.responder(),
-                        )
-                        .await
-                        .map_err(candle_core::Error::msg)?;
-                } else {
-                    group
-                        .maybe_send_completion_done_response(
-                            crate::CompletionResponse {
-                                id: seq.id().to_string(),
-                                choices: group.get_completion_choices().to_vec(),
-                                created: seq.creation_time(),
-                                model: pipeline_name,
-                                system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
-                                object: "text_completion".to_string(),
-                                usage: group.get_usage(),
-                            },
-                            seq.responder(),
-                        )
-                        .await
-                        .map_err(candle_core::Error::msg)?;
+                crate::sequence::StopReason::StopString {
+                    completion_bytes_pos,
+                    ..
+                } => {
+                    let txt = String::from_utf8_lossy(seq.completion_bytes());
+                    txt[..completion_bytes_pos].trim_start().to_string()
+                }
+                crate::sequence::StopReason::GeneratedImage => {
+                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                 }
+            };
+
+            if seq.get_mut_group().is_chat {
+                let (text_new, tool_calls) = parse_text_tools(text.as_str(), seq.tools.clone())
+                    .map_err(candle_core::Error::msg)?;
+                let choice = crate::Choice {
+                    finish_reason: reason.to_string(),
+                    index: seq.get_response_index(),
+                    message: crate::ResponseMessage {
+                        content: text_new.map(ToString::to_string),
+                        role: "assistant".to_string(),
+                        tool_calls,
+                    },
+                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
+                };
+                seq.add_choice_to_group(choice);
+            } else {
+                let choice = crate::CompletionChoice {
+                    finish_reason: reason.to_string(),
+                    index: seq.get_response_index(),
+                    text,
+                    logprobs: None,
+                };
+                seq.add_completion_choice_to_group(choice);
+            }
+
+            if use_prefix_cacher {
+                prefix_cacher.add_sequence(seq);
+                prefix_cacher.evict_to_cpu()?;
+            }
+
+            let group = seq.get_mut_group();
+            if group.is_chat {
+                group
+                    .maybe_send_chat_done_response(
+                        crate::ChatCompletionResponse {
+                            id: seq.id().to_string(),
+                            choices: group.get_choices().to_vec(),
+                            created: seq.creation_time(),
+                            model: pipeline_name,
+                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
+                            object: "chat.completion".to_string(),
+                            usage: group.get_usage(),
+                        },
+                        seq.responder(),
+                    )
+                    .await
+                    .map_err(candle_core::Error::msg)?;
+            } else {
+                group
+                    .maybe_send_completion_done_response(
+                        crate::CompletionResponse {
+                            id: seq.id().to_string(),
+                            choices: group.get_completion_choices().to_vec(),
+                            created: seq.creation_time(),
+                            model: pipeline_name,
+                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
+                            object: "text_completion".to_string(),
+                            usage: group.get_usage(),
+                        },
+                        seq.responder(),
+                    )
+                    .await
+                    .map_err(candle_core::Error::msg)?;
             }
-                this.reset_non_granular_state();
         }
+        this.reset_non_granular_state();
     }
     Ok(())
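
Note on the hunk above: the only semantic change is re-parenting the finish branch; the removed and added bodies are identical apart from indentation. Below is a minimal, self-contained sketch of the control-flow fix, using hypothetical stand-in types and a `finish_or_add` helper invented for illustration (the real `Pipeline`, `Sequence`, `StopReason`, and tokenizer plumbing live in `mistralrs-core`). Before the patch, the `else if let Some(reason) = is_done` arm was attached to an inner `if` one level deeper, inside the streaming branch; the patch closes that inner block and attaches the arm to the outer streaming check.

```rust
// Stand-in types for illustration only; not the mistralrs-core definitions.
enum StopReason {
    Eos,
}

struct Tokenizer;

impl Tokenizer {
    fn decode(&self, _tokens: &[u32]) -> String {
        String::new()
    }
}

fn finish_or_add(
    is_streaming: bool,
    is_done: Option<StopReason>,
    tokenizer: Option<&Tokenizer>,
) -> Result<(), String> {
    if is_streaming {
        // Streaming: emit the freshly sampled token as a delta.
        if let Some(_reason) = is_done {
            // End-of-stream bookkeeping for a streaming request.
        }
    } else if let Some(_reason) = is_done {
        // Non-streaming finish path: decode the completion and send one
        // final response. Pre-patch, this arm sat one level deeper, attached
        // to an inner `if` of the streaming branch, so a non-streaming
        // request could never reach it.
        let tok = tokenizer.ok_or_else(|| {
            "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer".to_string()
        })?;
        let _text = tok.decode(&[0]);
    }
    Ok(())
}

fn main() {
    // A finished, non-streaming request takes the `else if` arm, which is
    // where a missing tokenizer is surfaced as an error.
    assert!(finish_or_add(false, Some(StopReason::Eos), None).is_err());
    assert!(finish_or_add(false, Some(StopReason::Eos), Some(&Tokenizer)).is_ok());
    // A streaming request never enters the finish arm.
    assert!(finish_or_add(true, Some(StopReason::Eos), None).is_ok());
}
```

The `ok_or_else` line mirrors the diff's `tokenizer.as_ref().ok_or(candle_core::Error::Msg(...))?` pattern: the pipeline's tokenizer is an `Option`, and the finish path is the point where its absence becomes a hard error rather than being silently skipped.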