Skip to content

Commit

Permalink
Update for tools use
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jun 6, 2024
1 parent 3f3fa2b commit 90fa2e0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import json
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
Expand All @@ -7,13 +8,14 @@

from anthropic import Anthropic, Stream
from anthropic.types import (
ContentBlock,
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStreamEvent,
TextBlock,
TextDelta,
ToolUseBlock,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -61,6 +63,8 @@ class AnthropicChatGenerator:
# The parameters that can be passed to the Anthropic API https://docs.anthropic.com/claude/reference/messages_post
ALLOWED_PARAMS: ClassVar[List[str]] = [
"system",
"tools",
"tool_choice",
"max_tokens",
"metadata",
"stop_sequences",
Expand All @@ -75,6 +79,7 @@ def __init__(
model: str = "claude-3-sonnet-20240229",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
ignore_tools_thinking_messages: bool = True,
):
"""
Creates an instance of AnthropicChatGenerator.
Expand All @@ -95,13 +100,15 @@ def __init__(
- `temperature`: The temperature to use for sampling.
- `top_p`: The top_p value to use for nucleus sampling.
- `top_k`: The top_k value to use for top-k sampling.
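:param ignore_tools_thinking_messages: If True, and the response contains at least one tool-use block, the generator
keeps only the tool-use blocks and drops the other content blocks (the so-called thinking messages). See the Anthropic [tool use documentation](https://docs.anthropic.com/en/docs/tool-use) for more details.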
"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.client = Anthropic(api_key=self.api_key.resolve_value())
self.ignore_tools_thinking_messages = ignore_tools_thinking_messages

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -203,18 +210,24 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
completions = [self._connect_chunks(chunks, start_event, delta)]
# if streaming is disabled, the response is an Anthropic Message
elif isinstance(response, Message):
has_tools_msgs = any(isinstance(content_block, ToolUseBlock) for content_block in response.content)
if has_tools_msgs and self.ignore_tools_thinking_messages:
response.content = [block for block in response.content if isinstance(block, ToolUseBlock)]
completions = [self._build_message(content_block, response) for content_block in response.content]

return {"replies": completions}

def _build_message(self, content_block: ContentBlock, message: Message) -> ChatMessage:
def _build_message(self, content_block: Union[TextBlock, ToolUseBlock], message: Message) -> ChatMessage:
"""
Converts the non-streaming Anthropic Message to a ChatMessage.
:param content_block: The content block of the message.
:param message: The non-streaming Anthropic Message.
:returns: The ChatMessage.
"""
chat_message = ChatMessage.from_assistant(content_block.text)
if isinstance(content_block, TextBlock):
chat_message = ChatMessage.from_assistant(content_block.text)
else:
chat_message = ChatMessage.from_assistant(json.dumps(content_block.model_dump(mode="json")))
chat_message.meta.update(
{
"model": message.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,12 @@ def _build_message(self, cohere_response):
:param cohere_response: The completion returned by the Cohere API.
:returns: The ChatMessage.
"""
content = cohere_response.text
message = ChatMessage.from_assistant(content=content)

message = None
if cohere_response.tool_calls:
# TODO revisit to see if we need to handle multiple tool calls
message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
elif cohere_response.text:
message = ChatMessage.from_assistant(content=cohere_response.text)
total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens
message.meta.update(
{
Expand Down

0 comments on commit 90fa2e0

Please sign in to comment.