
fix: remove accidentally included guideline from rebase
drbh committed Nov 22, 2024
1 parent 33284f8 · commit c17635f
Showing 3 changed files with 7 additions and 27 deletions.
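
The net effect across the three files is that the `guideline` parameter is dropped from the whole chat-template call chain (`ChatRequest` → `Infer::apply_chat_template` → `ChatTemplate::apply`). As a reading aid (not part of the commit itself), here is a minimal before/after sketch of one call site inside `impl ChatRequest` in router/src/lib.rs, with the surrounding setup of `infer` and `messages` elided:

    // Before this commit, the guideline had to be threaded through every call:
    //     let inputs = infer.apply_chat_template(guideline, messages, None)?;
    // After this commit, only messages and the optional tools/prompt pair remain:
    let inputs = infer.apply_chat_template(messages, None)?;
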
17 changes: 3 additions & 14 deletions router/src/infer/chat_template.rs
@@ -2,7 +2,6 @@ use crate::infer::InferError;
 use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
-use std::collections::HashSet;
 
 /// Raise a exception (custom function) used in the chat templates
 pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
@@ -15,7 +14,6 @@ pub(crate) struct ChatTemplate {
     bos_token: Option<String>,
     eos_token: Option<String>,
     use_default_tool_template: bool,
-    variables: HashSet<String>,
 }
 
 impl ChatTemplate {
@@ -47,21 +45,14 @@ impl ChatTemplate {
             bos_token: bos_token.map(|token| token.as_str().to_string()),
             eos_token: eos_token.map(|token| token.as_str().to_string()),
             use_default_tool_template,
-            variables,
         }
     }
 
     pub(crate) fn apply(
         &self,
-        guideline: Option<&str>,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
-        // check if guideline is expected but not provided
-        if self.variables.contains("guideline") && guideline.is_none() {
-            return Err(InferError::MissingTemplateVariable("guideline".to_string()));
-        }
-
         let tools = match tools_and_prompt {
             Some((tools, tool_prompt)) => {
                 // check if the `tools` variable is used in the template
@@ -88,7 +79,6 @@ impl ChatTemplate {
         let mut rendered_template = self
             .template
             .render(ChatTemplateInputs {
-                guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
@@ -782,7 +772,6 @@ mod tests {
                 add_generation_prompt: false,
                 bos_token: Some("<s>"),
                 eos_token: Some("</s>"),
-                guideline: Some("Do not use offensive language."),
                 ..Default::default()
             },
target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
@@ -843,7 +832,7 @@ mod tests {
             },
         ];
 
-        let result = ct.apply(None, msgs, None);
+        let result = ct.apply(msgs, None);
 
         match result {
             Ok(_) => panic!("Should have failed since no guideline is provided"),
@@ -885,7 +874,7 @@ mod tests {
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
@@ -919,7 +908,7 @@ mod tests {
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
         assert_eq!(result.unwrap(), expected);
     }
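
The `variables: HashSet<String>` field deleted above existed only to back the `self.variables.contains("guideline")` check, which returned `InferError::MissingTemplateVariable` when a template expected a guideline that the request did not supply. For context, here is a standalone minijinja sketch of the situation that check used to catch; it is not from the repository, and it assumes minijinja's default lenient handling of undefined variables:

    use minijinja::{context, Environment};

    fn main() {
        let mut env = Environment::new();
        // A template that still references `guideline`, loosely modeled on the
        // safety-policy template exercised in the tests above.
        env.add_template("guard", "Our safety principle is: {{ guideline }}")
            .unwrap();
        let tmpl = env.get_template("guard").unwrap();
        // With the explicit MissingTemplateVariable check gone, an absent guideline
        // would simply render as an empty string under lenient undefined behavior.
        let out = tmpl.render(context! {}).unwrap();
        assert_eq!(out, "Our safety principle is: ");
    }
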
3 changes: 1 addition & 2 deletions router/src/infer/mod.rs
@@ -159,14 +159,13 @@ impl Infer {
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(
         &self,
-        guideline: Option<String>,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, tools_and_prompt)
+            .apply(messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
14 changes: 3 additions & 11 deletions router/src/lib.rs
@@ -909,11 +909,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, default = "null", example = "null")]
     pub response_format: Option<GrammarType>,
 
-    /// A guideline to be used in the chat_template
-    #[serde(default)]
-    #[schema(nullable = true, default = "null", example = "null")]
-    pub guideline: Option<String>,
-
     /// Options for streaming response. Only set this when you set stream: true.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -934,7 +929,6 @@ impl ChatRequest {
             tool_prompt,
             temperature,
             response_format,
-            guideline,
             presence_penalty,
             frequency_penalty,
             top_p,
@@ -962,7 +956,7 @@
 
         let (inputs, grammar, using_tools) = match response_format {
             Some(format) => {
-                let inputs = infer.apply_chat_template(guideline, messages, None)?;
+                let inputs = infer.apply_chat_template(messages, None)?;
                 (inputs, Some(format), false)
             }
             None => {
@@ -971,21 +965,20 @@
                         Some((updated_tools, tool_schema)) => {
                             let grammar = GrammarType::Json(serde_json::json!(tool_schema));
                             let inputs: String = infer.apply_chat_template(
-                                guideline,
                                 messages,
                                 Some((updated_tools, tool_prompt)),
                             )?;
                             (inputs, Some(grammar), true)
                         }
                         None => {
                             // same as if no response_format or tools are set
-                            let inputs = infer.apply_chat_template(guideline, messages, None)?;
+                            let inputs = infer.apply_chat_template(messages, None)?;
                             (inputs, None, false)
                         }
                     }
                 } else {
                     // if no response_format or tools are set simply apply the chat template to generate inputs
-                    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+                    let inputs = infer.apply_chat_template(messages, None)?;
                     (inputs, None, false)
                 }
             }
@@ -1163,7 +1156,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
     tools: Option<Vec<Tool>>,
-    guideline: Option<&'a str>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
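
On the request side, `ChatRequest` no longer has a `guideline` field. A hypothetical stand-in struct (not the real `ChatRequest`) illustrates what that means for clients that still send the key, assuming the real struct keeps serde's default of ignoring unknown fields rather than using `deny_unknown_fields`:

    use serde::Deserialize;

    // Reduced stand-in for ChatRequest, kept to a single field for illustration.
    #[derive(Deserialize, Debug)]
    struct ChatRequestSubset {
        messages: Vec<serde_json::Value>,
    }

    fn main() {
        // A client payload that still carries "guideline" after this change.
        let body = r#"{"messages": [], "guideline": "Do not use offensive language."}"#;
        // Under serde's default behavior the unknown key is silently ignored,
        // so it no longer reaches the chat template at all.
        let req: ChatRequestSubset = serde_json::from_str(body).unwrap();
        assert!(req.messages.is_empty());
    }
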
