Update PythonListCustomToolGenerator to support overriding system prompt #271

Merged — 1 commit merged on Feb 6, 2025
35 changes: 27 additions & 8 deletions models/llama3/prompt_templates/system_prompts.py
@@ -7,7 +7,7 @@

import textwrap
from datetime import datetime
from typing import Any, List
from typing import Any, List, Optional

from llama_models.llama3.api.datatypes import (
BuiltinTool,
@@ -215,14 +215,33 @@ def data_examples(self) -> List[List[ToolDefinition]]:


class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
DEFAULT_PROMPT = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

{{ function_description }}
""".strip(
"\n"
)
)

def gen(
self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)

def _gen_function_description(
self, custom_tools: List[ToolDefinition]
) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

@@ -263,7 +282,7 @@ def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
return PromptTemplate(
template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]},
)
).render()

def data_examples(self) -> List[List[ToolDefinition]]:
return [
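
With this change, callers can pass their own system prompt template containing a {{ function_description }} placeholder; if no template is given, DEFAULT_PROMPT is used. A minimal usage sketch (the import path is assumed from the repository layout shown in this diff; data_examples() supplies the same get_weather tool definition the test below relies on):

import textwrap

# Assumed import path, mirroring the package layout of this diff.
from llama_models.llama3.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

generator = PythonListCustomToolGenerator()
tools = generator.data_examples()[0]  # example tool list (get_weather)

# Default behavior: without a system_prompt argument, DEFAULT_PROMPT is used.
default_text = generator.gen(tools).render()

# Override: any template containing a {{ function_description }} placeholder.
custom_template = textwrap.dedent(
    """
    Overriding message.

    {{ function_description }}
    """
)
custom_text = generator.gen(tools, custom_template).render()

# The function-call format instructions come from _gen_function_description,
# so they are still present even when the leading system prompt is overridden.
assert "You SHOULD NOT include any other text in the response." in custom_text
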
46 changes: 46 additions & 0 deletions models/llama3/tests/prompt_templates/test_system_prompts.py
@@ -145,3 +145,49 @@ def test_llama_3_2_system_zero_shot(self):
"""
)
self.check_generator_output(generator, expected_text.strip("\n"))

def test_llama_3_2_provided_system_prompt(self):
generator = PythonListCustomToolGenerator()
expected_text = textwrap.dedent(
"""
Overriding message.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": ["city"],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]"""
)
user_system_prompt = textwrap.dedent(
"""
Overriding message.

{{ function_description }}
"""
)
example = generator.data_examples()[0]

pt = generator.gen(example, user_system_prompt)
text = pt.render()
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
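
To exercise the new test locally, a single invocation like the following should work, assuming pytest is used as the test runner for this repository:

pytest models/llama3/tests/prompt_templates/test_system_prompts.py -k provided_system_prompt
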