Update PythonListCustomToolGenerator to support overriding system prompt
Summary:

Support a user-supplied prompt template in PythonListCustomToolGenerator. This allows users to provide their own system prompt without having to format the function descriptions themselves.

Test Plan:
python -m unittest llama_models.llama3.tests.prompt_templates.test_system_prompts
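
For illustration, a minimal usage sketch of the new override (the module path in the import is assumed from the file layout in this commit; gen(), data_examples(), and PromptTemplate.render() are used as in the test added below):

import textwrap

from llama_models.llama3.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

generator = PythonListCustomToolGenerator()
tools = generator.data_examples()[0]  # sample ToolDefinitions bundled with the generator

# Default behavior: DEFAULT_PROMPT supplies the system prompt template.
default_prompt = generator.gen(tools).render()

# Override: any template can be passed, as long as it references
# {{ function_description }}, which is filled with the formatted tool list.
custom_template = textwrap.dedent(
    """
    Overriding message.
    {{ function_description }}
    """
)
custom_prompt = generator.gen(tools, custom_template).render()
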
ehhuang committed Feb 6, 2025
1 parent ecf2f12 commit 83e4fa3
Showing 2 changed files with 73 additions and 8 deletions.
35 changes: 27 additions & 8 deletions models/llama3/prompt_templates/system_prompts.py
@@ -7,7 +7,7 @@

 import textwrap
 from datetime import datetime
-from typing import Any, List
+from typing import Any, List, Optional

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -215,14 +215,33 @@ def data_examples(self) -> List[List[ToolDefinition]]:


 class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    DEFAULT_PROMPT = textwrap.dedent(
+        """
+        You are an expert in composing functions. You are given a question and a set of possible functions.
+        Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+        If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+        also point it out. You should only return the function call in tools call sections.
+
+        {{ function_description }}
+        """.strip(
+            "\n"
+        )
+    )
+
+    def gen(
+        self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
+    ) -> PromptTemplate:
+        system_prompt = system_prompt or self.DEFAULT_PROMPT
+        return PromptTemplate(
+            system_prompt,
+            {"function_description": self._gen_function_description(custom_tools)},
+        )
+
+    def _gen_function_description(
+        self, custom_tools: List[ToolDefinition]
+    ) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
-            You are an expert in composing functions. You are given a question and a set of possible functions.
-            Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
-            If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
-            also point it out. You should only return the function call in tools call sections.
-
             If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
             You SHOULD NOT include any other text in the response.

@@ -263,7 +282,7 @@ def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
         return PromptTemplate(
             template_str.strip("\n"),
             {"tools": [t.model_dump() for t in custom_tools]},
-        )
+        ).render()

     def data_examples(self) -> List[List[ToolDefinition]]:
         return [
46 changes: 46 additions & 0 deletions models/llama3/tests/prompt_templates/test_system_prompts.py
@@ -145,3 +145,49 @@ def test_llama_3_2_system_zero_shot(self):
             """
         )
         self.check_generator_output(generator, expected_text.strip("\n"))
+
+    def test_llama_3_2_provided_system_prompt(self):
+        generator = PythonListCustomToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            Overriding message.
+            If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+            You SHOULD NOT include any other text in the response.
+
+            Here is a list of functions in JSON format that you can invoke.
+
+            [
+                {
+                    "name": "get_weather",
+                    "description": "Get weather info for places",
+                    "parameters": {
+                        "type": "dict",
+                        "required": ["city"],
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The name of the city to get the weather for"
+                            },
+                            "metric": {
+                                "type": "string",
+                                "description": "The metric for weather. Options are: celsius, fahrenheit",
+                                "default": "celsius"
+                            }
+                        }
+                    }
+                }
+            ]"""
+        )
+        user_system_prompt = textwrap.dedent(
+            """
+            Overriding message.
+            {{ function_description }}
+            """
+        )
+        example = generator.data_examples()[0]
+
+        pt = generator.gen(example, user_system_prompt)
+        text = pt.render()
+        assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
