From b0fe522a3ddee4649a98ee54d30457b532de9641 Mon Sep 17 00:00:00 2001 From: Will <651833+WillBeebe@users.noreply.github.com> Date: Mon, 15 Jul 2024 15:57:08 -0700 Subject: [PATCH] expose experimental streaming methods for anthropic, openai, and groq --- src/abcs/anthropic.py | 61 ++++++++++++++++++++++++++++++++++++++++- src/abcs/llm.py | 11 +++++++- src/abcs/models.py | 11 +++++++- src/abcs/openai.py | 63 +++++++++++++++++++++++++++++++++++++++++-- src/agents/agent.py | 32 +++++++++++++++++++++- 5 files changed, 172 insertions(+), 6 deletions(-) diff --git a/src/abcs/anthropic.py b/src/abcs/anthropic.py index 847f3a3..32536d7 100644 --- a/src/abcs/anthropic.py +++ b/src/abcs/anthropic.py @@ -1,10 +1,11 @@ +import asyncio import logging import os from typing import Any, Dict, List, Optional import anthropic from abcs.llm import LLM -from abcs.models import PromptResponse, UsageStats +from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats from tools.tool_manager import ToolManager logging.basicConfig(level=logging.INFO) @@ -142,3 +143,61 @@ def _translate_response(self, response) -> PromptResponse: except Exception as e: logger.exception(f"error: {e}\nresponse: {response}") raise e + + async def generate_text_stream( + self, + prompt: str, + past_messages: List[Dict[str, str]], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> StreamingPromptResponse: + combined_history = past_messages + [{"role": "user", "content": prompt}] + + try: + stream = self.client.messages.create( + model=self.model, + max_tokens=4096, + messages=combined_history, + system=self.system_prompt, + stream=True, + ) + + async def content_generator(): + for event in stream: + if isinstance(event, anthropic.types.MessageStartEvent): + # Message start event, we can ignore this + pass + elif isinstance(event, anthropic.types.ContentBlockStartEvent): + # Content block start event, we can ignore this + pass + elif isinstance(event, anthropic.types.ContentBlockDeltaEvent): + # This is the event that contains the actual text + if event.delta.text: + yield event.delta.text + elif isinstance(event, anthropic.types.ContentBlockStopEvent): + # Content block stop event, we can ignore this + pass + elif isinstance(event, anthropic.types.MessageStopEvent): + # Message stop event, we can ignore this + pass + # Small delay to allow for cooperative multitasking + await asyncio.sleep(0) + + return StreamingPromptResponse( + content=content_generator(), + raw_response=stream, + error={}, + usage=UsageStats( + input_tokens=0, # These will need to be updated after streaming + output_tokens=0, + extra={}, + ), + ) + except Exception as e: + logger.exception(f"An error occurred while streaming from Claude: {e}") + raise e + + async def handle_tool_call(self, tool_calls, combined_history, tools): + # This is a placeholder for handling tool calls in streaming context + # You'll need to implement the logic to execute the tool call and generate a response + pass diff --git a/src/abcs/llm.py b/src/abcs/llm.py index 62b0f5c..c7866c6 100644 --- a/src/abcs/llm.py +++ b/src/abcs/llm.py @@ -4,7 +4,7 @@ from importlib import resources import yaml -from abcs.models import PromptResponse +from abcs.models import PromptResponse, StreamingPromptResponse from abcs.tools import gen_anthropic, gen_cohere, gen_google, gen_openai # Add the project root to the Python path @@ -33,6 +33,15 @@ def generate_text(self, """Generates text based on the given prompt and additional arguments.""" pass + @abstractmethod + async def 
generate_text_stream(self, + prompt: str, + past_messages, + tools, + **kwargs) -> StreamingPromptResponse: + """Generates streaming text based on the given prompt and additional arguments.""" + pass + @abstractmethod def call_tool(self, past_messages, tool_msg) -> str: """Calls a specific tool with the given arguments and returns the response.""" diff --git a/src/abcs/models.py b/src/abcs/models.py index 930f8c3..59390a4 100644 --- a/src/abcs/models.py +++ b/src/abcs/models.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, AsyncIterator from pydantic import BaseModel @@ -29,3 +29,12 @@ class OllamaResponse(BaseModel): prompt_eval_duration: int eval_count: int eval_duration: int + +class StreamingPromptResponse(BaseModel): + content: AsyncIterator[str] + raw_response: Any + error: object + usage: UsageStats + + class Config: + arbitrary_types_allowed = True diff --git a/src/abcs/openai.py b/src/abcs/openai.py index c521918..14a0b9e 100644 --- a/src/abcs/openai.py +++ b/src/abcs/openai.py @@ -1,11 +1,14 @@ +import asyncio import json import logging import os from typing import Any, Dict, List, Optional -import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401 +# todo: need to support this for multi tool use, maybe upstream package has it fixed now. +# commented out because it's not working with streams +# import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401 from abcs.llm import LLM -from abcs.models import PromptResponse, UsageStats +from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats from openai import OpenAI from tools.tool_manager import ToolManager @@ -188,3 +191,59 @@ def _translate_response(self, response) -> PromptResponse: # logger.error("An error occurred while translating OpenAI response: %s", e, exc_info=True) logger.exception(f"error: {e}\nresponse: {response}") raise e + + # https://cookbook.openai.com/examples/how_to_stream_completions + async def generate_text_stream( + self, + prompt: str, + past_messages: List[Dict[str, str]], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> StreamingPromptResponse: + system_message = [{"role": "system", "content": self.system_prompt}] if self.system_prompt else [] + combined_history = system_message + past_messages + [{"role": "user", "content": prompt}] + + try: + stream = self.client.chat.completions.create( + model=self.model, + messages=combined_history, + tools=tools, + stream=True, + ) + + async def content_generator(): + for event in stream: + # print("HERE\n"*30) + # print(event) + if event.choices[0].delta.content is not None: + yield event.choices[0].delta.content + # Small delay to allow for cooperative multitasking + await asyncio.sleep(0) + + # # After the stream is complete, you might want to handle tool calls here + # # This is a simplification and may need to be adjusted based on your needs + # if tools and collected_content.strip().startswith('{"function":'): + # # Handle tool calls (simplified example) + # tool_response = await self.handle_tool_call(collected_content, combined_history, tools) + # yield tool_response + + return StreamingPromptResponse( + content=content_generator(), + raw_response=stream, + error={}, + usage=UsageStats( + input_tokens=0, # These will need to be updated after streaming + output_tokens=0, + extra={}, + ), + ) + except Exception as e: + logger.error("Error generating text stream: %s", e, exc_info=True) + raise e + + async def handle_tool_call(self, collected_content, combined_history, tools): + # 
This is a placeholder for handling tool calls in streaming context + # You'll need to implement the logic to parse the tool call, execute it, + # and generate a response based on the tool's output + # This might involve breaking the streaming and making a new API call + pass diff --git a/src/agents/agent.py b/src/agents/agent.py index d73c8e3..56435c8 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -1,7 +1,7 @@ import logging from abcs.llm import LLM -from abcs.models import PromptResponse +from abcs.models import PromptResponse, StreamingPromptResponse # from metrics.main import call_tool_counter, generate_text_counter from storage.storage_manager import StorageManager @@ -13,6 +13,8 @@ class Agent(LLM): def __init__(self, client, tool_manager: ToolManager, system_prompt: str = "", tools=[], storage_manager: StorageManager = None): + if len(tools) == 0 and (client.provider == "openai" or client.provider == "groq"): + tools = None self.tools = tools logger.debug("Initializing Agent with tools: %s and system prompt: '%s'", tools, system_prompt) super().__init__( @@ -90,3 +92,31 @@ def _translate_response(self, response) -> PromptResponse: # except Exception as e: # logger.error("Error translating response: %s", e, exc_info=True) # raise e + + async def generate_text_stream(self, + prompt: str, + **kwargs) -> StreamingPromptResponse: + """Generates streaming text based on the given prompt and additional arguments.""" + past_messages = [] + if self.storage_manager is not None: + past_messages = self.storage_manager.get_past_messages() + logger.debug("Fetched %d past messages", len(past_messages)) + if self.storage_manager is not None: + self.storage_manager.store_message("user", prompt) + try: + response = await self.client.generate_text_stream(prompt, past_messages, self.tools) + except Exception as err: + if self.storage_manager is not None: + self.storage_manager.remove_last() + raise err + + # TODO: can't do this with streaming. have to handle this in the API + # if self.storage_manager is not None: + # try: + # # translated = self._translate_response(response) + # self.storage_manager.store_message("assistant", response.content) + # except Exception as e: + # logger.error("Error storing messages: %s", e, exc_info=True) + # raise e + + return response
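
For reviewers, a minimal sketch of how a caller might consume the new streaming surface added by this patch. The helper name and prompt below are illustrative only; constructing the Agent (provider client, ToolManager, optional StorageManager) is assumed to have happened elsewhere and is not shown.

    from agents.agent import Agent

    async def stream_reply(agent: Agent, prompt: str) -> str:
        """Print a streamed completion as it arrives and return the full text."""
        response = await agent.generate_text_stream(prompt)
        collected = []
        # response.content is the AsyncIterator[str] declared on StreamingPromptResponse
        async for chunk in response.content:
            print(chunk, end="", flush=True)
            collected.append(chunk)
        print()
        # response.usage token counts are zeroed placeholders until post-stream accounting is added
        return "".join(collected)

Run it with asyncio.run(stream_reply(agent, "...")) from synchronous code. Because Agent.generate_text_stream no longer persists the assistant message itself (see the TODO in agent.py), the caller can store the joined chunks with its own StorageManager once the stream finishes.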