expose experimental streaming methods for anthropic, openai, and groq
WillBeebe committed Jul 15, 2024
1 parent c1a8bef commit b0fe522
Showing 5 changed files with 172 additions and 6 deletions.
61 changes: 60 additions & 1 deletion src/abcs/anthropic.py
@@ -1,10 +1,11 @@
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional

import anthropic
from abcs.llm import LLM
from abcs.models import PromptResponse, UsageStats
from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
from tools.tool_manager import ToolManager

logging.basicConfig(level=logging.INFO)
@@ -142,3 +143,61 @@ def _translate_response(self, response) -> PromptResponse:
except Exception as e:
logger.exception(f"error: {e}\nresponse: {response}")
raise e

async def generate_text_stream(
self,
prompt: str,
past_messages: List[Dict[str, str]],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> StreamingPromptResponse:
combined_history = past_messages + [{"role": "user", "content": prompt}]

try:
stream = self.client.messages.create(
model=self.model,
max_tokens=4096,
messages=combined_history,
system=self.system_prompt,
stream=True,
)

async def content_generator():
for event in stream:
if isinstance(event, anthropic.types.MessageStartEvent):
# Message start event, we can ignore this
pass
elif isinstance(event, anthropic.types.ContentBlockStartEvent):
# Content block start event, we can ignore this
pass
elif isinstance(event, anthropic.types.ContentBlockDeltaEvent):
# This is the event that contains the actual text
if event.delta.text:
yield event.delta.text
elif isinstance(event, anthropic.types.ContentBlockStopEvent):
# Content block stop event, we can ignore this
pass
elif isinstance(event, anthropic.types.MessageStopEvent):
# Message stop event, we can ignore this
pass
# Small delay to allow for cooperative multitasking
await asyncio.sleep(0)

return StreamingPromptResponse(
content=content_generator(),
raw_response=stream,
error={},
usage=UsageStats(
input_tokens=0, # These will need to be updated after streaming
output_tokens=0,
extra={},
),
)
except Exception as e:
logger.exception(f"An error occurred while streaming from Claude: {e}")
raise e

async def handle_tool_call(self, tool_calls, combined_history, tools):
# This is a placeholder for handling tool calls in streaming context
# You'll need to implement the logic to execute the tool call and generate a response
pass
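
For illustration only, a minimal sketch of how a caller might consume this streaming method. The class name `AnthropicLLM`, its constructor arguments, and the model id are assumptions, not taken from this diff:

import asyncio

from abcs.anthropic import AnthropicLLM  # hypothetical class name for this module


async def main():
    # Hypothetical construction; the real constructor signature may differ.
    llm = AnthropicLLM(model="claude-3-5-sonnet-20240620", system_prompt="Be concise.")
    response = await llm.generate_text_stream(
        prompt="Explain token streaming in one sentence.",
        past_messages=[],
    )
    # response.content is the async generator built above; iterate it to drain the stream.
    async for chunk in response.content:
        print(chunk, end="", flush=True)


asyncio.run(main())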
11 changes: 10 additions & 1 deletion src/abcs/llm.py
@@ -4,7 +4,7 @@
from importlib import resources

import yaml
from abcs.models import PromptResponse
from abcs.models import PromptResponse, StreamingPromptResponse
from abcs.tools import gen_anthropic, gen_cohere, gen_google, gen_openai

# Add the project root to the Python path
@@ -33,6 +33,15 @@ def generate_text(self,
"""Generates text based on the given prompt and additional arguments."""
pass

@abstractmethod
async def generate_text_stream(self,
prompt: str,
past_messages,
tools,
**kwargs) -> StreamingPromptResponse:
"""Generates streaming text based on the given prompt and additional arguments."""
pass

@abstractmethod
def call_tool(self, past_messages, tool_msg) -> str:
"""Calls a specific tool with the given arguments and returns the response."""
11 changes: 10 additions & 1 deletion src/abcs/models.py
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, AsyncIterator

from pydantic import BaseModel

@@ -29,3 +29,12 @@ class OllamaResponse(BaseModel):
prompt_eval_duration: int
eval_count: int
eval_duration: int

class StreamingPromptResponse(BaseModel):
content: AsyncIterator[str]
raw_response: Any
error: object
usage: UsageStats

class Config:
arbitrary_types_allowed = True
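
For illustration only, a minimal sketch of what `arbitrary_types_allowed` buys here: `content` holds an async generator, which pydantic will not validate as a built-in type, so the model must be told to accept arbitrary types. This assumes the model works as committed under the project's pydantic version:

import asyncio

from abcs.models import StreamingPromptResponse, UsageStats


async def fake_stream():
    for piece in ("streamed ", "text"):
        yield piece


resp = StreamingPromptResponse(
    content=fake_stream(),  # AsyncIterator[str]; accepted only because arbitrary_types_allowed is set
    raw_response=None,
    error={},
    usage=UsageStats(input_tokens=0, output_tokens=0, extra={}),
)


async def drain():
    async for piece in resp.content:
        print(piece, end="")


asyncio.run(drain())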
63 changes: 61 additions & 2 deletions src/abcs/openai.py
@@ -1,11 +1,14 @@
import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional

import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401
# todo: need to support this for multi tool use, maybe upstream package has it fixed now.
# commented out because it's not working with streams
# import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401
from abcs.llm import LLM
from abcs.models import PromptResponse, UsageStats
from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
from openai import OpenAI
from tools.tool_manager import ToolManager

@@ -188,3 +191,59 @@ def _translate_response(self, response) -> PromptResponse:
# logger.error("An error occurred while translating OpenAI response: %s", e, exc_info=True)
logger.exception(f"error: {e}\nresponse: {response}")
raise e

# https://cookbook.openai.com/examples/how_to_stream_completions
async def generate_text_stream(
self,
prompt: str,
past_messages: List[Dict[str, str]],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> StreamingPromptResponse:
system_message = [{"role": "system", "content": self.system_prompt}] if self.system_prompt else []
combined_history = system_message + past_messages + [{"role": "user", "content": prompt}]

try:
stream = self.client.chat.completions.create(
model=self.model,
messages=combined_history,
tools=tools,
stream=True,
)

async def content_generator():
for event in stream:
# print("HERE\n"*30)
# print(event)
if event.choices[0].delta.content is not None:
yield event.choices[0].delta.content
# Small delay to allow for cooperative multitasking
await asyncio.sleep(0)

# # After the stream is complete, you might want to handle tool calls here
# # This is a simplification and may need to be adjusted based on your needs
# if tools and collected_content.strip().startswith('{"function":'):
# # Handle tool calls (simplified example)
# tool_response = await self.handle_tool_call(collected_content, combined_history, tools)
# yield tool_response

return StreamingPromptResponse(
content=content_generator(),
raw_response=stream,
error={},
usage=UsageStats(
input_tokens=0, # These will need to be updated after streaming
output_tokens=0,
extra={},
),
)
except Exception as e:
logger.error("Error generating text stream: %s", e, exc_info=True)
raise e

async def handle_tool_call(self, collected_content, combined_history, tools):
# This is a placeholder for handling tool calls in streaming context
# You'll need to implement the logic to parse the tool call, execute it,
# and generate a response based on the tool's output
# This might involve breaking the streaming and making a new API call
pass
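
For illustration only, one possible way the placeholder usage stats could be filled in later: recent openai packages accept a `stream_options` argument that makes the final chunk carry token usage. This is a sketch under that assumption, not part of the commit, and the model name is hypothetical:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
combined_history = [{"role": "user", "content": "Say hello."}]

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical model name
    messages=combined_history,
    stream=True,
    stream_options={"include_usage": True},  # may require a recent openai package version
)

input_tokens = output_tokens = 0
for event in stream:
    if event.choices and event.choices[0].delta.content is not None:
        print(event.choices[0].delta.content, end="", flush=True)
    if getattr(event, "usage", None) is not None:  # present only on the final chunk
        input_tokens = event.usage.prompt_tokens
        output_tokens = event.usage.completion_tokens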
32 changes: 31 additions & 1 deletion src/agents/agent.py
@@ -1,7 +1,7 @@
import logging

from abcs.llm import LLM
from abcs.models import PromptResponse
from abcs.models import PromptResponse, StreamingPromptResponse

# from metrics.main import call_tool_counter, generate_text_counter
from storage.storage_manager import StorageManager
@@ -13,6 +13,8 @@

class Agent(LLM):
def __init__(self, client, tool_manager: ToolManager, system_prompt: str = "", tools=[], storage_manager: StorageManager = None):
if len(tools) == 0 and (client.provider == "openai" or client.provider == "groq"):
tools = None
self.tools = tools
logger.debug("Initializing Agent with tools: %s and system prompt: '%s'", tools, system_prompt)
super().__init__(
@@ -90,3 +92,31 @@ def _translate_response(self, response) -> PromptResponse:
# except Exception as e:
# logger.error("Error translating response: %s", e, exc_info=True)
# raise e

async def generate_text_stream(self,
prompt: str,
**kwargs) -> StreamingPromptResponse:
"""Generates streaming text based on the given prompt and additional arguments."""
past_messages = []
if self.storage_manager is not None:
past_messages = self.storage_manager.get_past_messages()
logger.debug("Fetched %d past messages", len(past_messages))
if self.storage_manager is not None:
self.storage_manager.store_message("user", prompt)
try:
response = await self.client.generate_text_stream(prompt, past_messages, self.tools)
except Exception as err:
if self.storage_manager is not None:
self.storage_manager.remove_last()
raise err

# TODO: can't do this with streaming. have to handle this in the API
# if self.storage_manager is not None:
# try:
# # translated = self._translate_response(response)
# self.storage_manager.store_message("assistant", response.content)
# except Exception as e:
# logger.error("Error storing messages: %s", e, exc_info=True)
# raise e

return response
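
Taken together, a caller might drain the new Agent method like this. This is a sketch, not code from the commit; `agent` is assumed to be an already-constructed Agent wired to one of the streaming-capable clients:

import asyncio


async def stream_reply(agent, prompt: str) -> str:
    """Print an answer as it streams and return the full text once the stream ends."""
    response = await agent.generate_text_stream(prompt)
    collected = []
    async for chunk in response.content:
        collected.append(chunk)
        print(chunk, end="", flush=True)
    # Per the TODO above, the assistant message can't be stored mid-stream,
    # so the assembled text is what a caller would persist afterwards.
    return "".join(collected)

# Example (requires a concrete Agent instance):
# asyncio.run(stream_reply(agent, "Explain streaming responses briefly."))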
