Skip to content

Commit

Permalink
feat: Support tool calling for SGLang and Groq (#1512)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wendong-Fan authored Jan 26, 2025
1 parent 4745e99 commit 2790d24
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 13 deletions.
2 changes: 1 addition & 1 deletion camel/configs/gemini_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class GeminiConfig(BaseConfig):
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
max_tokens: Union[int, NotGiven] = NOT_GIVEN
response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
tool_choice: Optional[Union[dict[str, str], str]] = None
tool_choice: Optional[Union[dict[str, str], str, NotGiven]] = NOT_GIVEN

def as_dict(self) -> dict[str, Any]:
r"""Convert the current configuration to a dictionary.
Expand Down
4 changes: 4 additions & 0 deletions camel/configs/sglang_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class SGLangConfig(BaseConfig):
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
tools (list[FunctionTool], optional): A list of tools the model may
call. Currently, only functions are supported as a tool. Use this
to provide a list of functions the model may generate JSON inputs
for. A max of 128 functions are supported.
"""

stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
Expand Down
2 changes: 0 additions & 2 deletions camel/models/groq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def token_counter(self) -> BaseTokenCounter:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
# Make sure you have the access to these open-source model in
# HuggingFace
if not self._token_counter:
self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
return self._token_counter
Expand Down
45 changes: 36 additions & 9 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,11 @@ class ModelType(UnifiedModelType, Enum):

# Groq platform models
GROQ_LLAMA_3_1_8B = "llama-3.1-8b-instant"
GROQ_LLAMA_3_1_70B = "llama-3.1-70b-versatile"
GROQ_LLAMA_3_1_405B = "llama-3.1-405b-reasoning"
GROQ_LLAMA_3_3_70B = "llama-3.3-70b-versatile"
GROQ_LLAMA_3_3_70B_PREVIEW = "llama-3.3-70b-specdec"
GROQ_LLAMA_3_8B = "llama3-8b-8192"
GROQ_LLAMA_3_70B = "llama3-70b-8192"
GROQ_MIXTRAL_8_7B = "mixtral-8x7b-32768"
GROQ_GEMMA_7B_IT = "gemma-7b-it"
GROQ_GEMMA_2_9B_IT = "gemma2-9b-it"

# TogetherAI platform models support tool calling
Expand All @@ -67,6 +64,17 @@ class ModelType(UnifiedModelType, Enum):
SAMBA_LLAMA_3_1_70B = "Meta-Llama-3.1-70B-Instruct"
SAMBA_LLAMA_3_1_405B = "Meta-Llama-3.1-405B-Instruct"

# SGLang models support tool calling
SGLANG_LLAMA_3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"
SGLANG_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
SGLANG_LLAMA_3_1_405B = "meta-llama/Meta-Llama-3.1-405B-Instruct"
SGLANG_LLAMA_3_2_1B = "meta-llama/Llama-3.2-1B-Instruct"
SGLANG_MIXTRAL_NEMO = "mistralai/Mistral-Nemo-Instruct-2407"
SGLANG_MISTRAL_7B = "mistralai/Mistral-7B-Instruct-v0.3"
SGLANG_QWEN_2_5_7B = "Qwen/Qwen2.5-7B-Instruct"
SGLANG_QWEN_2_5_32B = "Qwen/Qwen2.5-32B-Instruct"
SGLANG_QWEN_2_5_72B = "Qwen/Qwen2.5-72B-Instruct"

STUB = "stub"

# Legacy anthropic models
Expand Down Expand Up @@ -190,6 +198,8 @@ def support_native_tool_calling(self) -> bool:
self.is_internlm,
self.is_together,
self.is_sambanova,
self.is_groq,
self.is_sglang,
]
)

Expand Down Expand Up @@ -252,14 +262,11 @@ def is_groq(self) -> bool:
r"""Returns whether this type of models is served by Groq."""
return self in {
ModelType.GROQ_LLAMA_3_1_8B,
ModelType.GROQ_LLAMA_3_1_70B,
ModelType.GROQ_LLAMA_3_1_405B,
ModelType.GROQ_LLAMA_3_3_70B,
ModelType.GROQ_LLAMA_3_3_70B_PREVIEW,
ModelType.GROQ_LLAMA_3_8B,
ModelType.GROQ_LLAMA_3_70B,
ModelType.GROQ_MIXTRAL_8_7B,
ModelType.GROQ_GEMMA_7B_IT,
ModelType.GROQ_GEMMA_2_9B_IT,
}

Expand Down Expand Up @@ -413,6 +420,20 @@ def is_internlm(self) -> bool:
ModelType.INTERNLM2_PRO_CHAT,
}

@property
def is_sglang(self) -> bool:
r"""Returns whether this type of models is served by SGLang."""
return self in {
ModelType.SGLANG_LLAMA_3_1_8B,
ModelType.SGLANG_LLAMA_3_1_70B,
ModelType.SGLANG_LLAMA_3_1_405B,
ModelType.SGLANG_LLAMA_3_2_1B,
ModelType.SGLANG_MIXTRAL_NEMO,
ModelType.SGLANG_MISTRAL_7B,
ModelType.SGLANG_QWEN_2_5_7B,
ModelType.SGLANG_QWEN_2_5_32B,
ModelType.SGLANG_QWEN_2_5_72B,
}

@property
def token_limit(self) -> int:
r"""Returns the maximum token limit for a given model.
Expand Down Expand Up @@ -440,7 +461,6 @@ def token_limit(self) -> int:
ModelType.GROQ_LLAMA_3_8B,
ModelType.GROQ_LLAMA_3_70B,
ModelType.GROQ_LLAMA_3_3_70B_PREVIEW,
ModelType.GROQ_GEMMA_7B_IT,
ModelType.GROQ_GEMMA_2_9B_IT,
ModelType.GLM_3_TURBO,
ModelType.GLM_4,
Expand Down Expand Up @@ -479,6 +499,7 @@ def token_limit(self) -> int:
ModelType.INTERNLM2_5_LATEST,
ModelType.INTERNLM2_PRO_CHAT,
ModelType.TOGETHER_MIXTRAL_8_7B,
ModelType.SGLANG_MISTRAL_7B,
}:
return 32_768
elif self in {
Expand Down Expand Up @@ -518,19 +539,25 @@ def token_limit(self) -> int:
ModelType.NVIDIA_LLAMA3_3_70B_INSTRUCT,
ModelType.GROQ_LLAMA_3_3_70B,
ModelType.SAMBA_LLAMA_3_1_70B,
ModelType.SGLANG_LLAMA_3_1_8B,
ModelType.SGLANG_LLAMA_3_1_70B,
ModelType.SGLANG_LLAMA_3_1_405B,
ModelType.SGLANG_LLAMA_3_2_1B,
ModelType.SGLANG_MIXTRAL_NEMO,
}:
return 128_000
elif self in {
ModelType.GROQ_LLAMA_3_1_8B,
ModelType.GROQ_LLAMA_3_1_70B,
ModelType.GROQ_LLAMA_3_1_405B,
ModelType.QWEN_PLUS,
ModelType.QWEN_TURBO,
ModelType.QWEN_CODER_TURBO,
ModelType.TOGETHER_LLAMA_3_1_8B,
ModelType.TOGETHER_LLAMA_3_1_70B,
ModelType.TOGETHER_LLAMA_3_1_405B,
ModelType.TOGETHER_LLAMA_3_3_70B,
ModelType.SGLANG_QWEN_2_5_7B,
ModelType.SGLANG_QWEN_2_5_32B,
ModelType.SGLANG_QWEN_2_5_72B,
}:
return 131_072
elif self in {
Expand Down
2 changes: 1 addition & 1 deletion examples/models/role_playing_with_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ def main(model_type=None) -> None:


if __name__ == "__main__":
main(model_type=ModelType.GROQ_LLAMA_3_1_70B)
main(model_type=ModelType.GROQ_LLAMA_3_1_8B)

0 comments on commit 2790d24

Please sign in to comment.