Skip to content
This repository has been archived by the owner on Jan 12, 2025. It is now read-only.

Commit

Permalink
refactor: code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
7HR4IZ3 committed Jul 31, 2024
1 parent 40112d9 commit d55bd5d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 7 deletions.
64 changes: 63 additions & 1 deletion examples/anthrophic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import asyncio
from anthropic import Anthropic, AsyncAnthropic

from lunary.anthrophic import monitor
import lunary
from lunary.anthrophic import monitor, parse_message

def sync_non_streaming():
client = Anthropic()
Expand Down Expand Up @@ -228,6 +229,65 @@ async def async_tool_calls():
print(f"snapshot: {event.snapshot}")


def reconcilliation_tool_calls():
    """Demo: reconcile an Anthropic tool-use exchange with a lunary thread.

    Tracks a user message on a lunary thread, asks Claude a weather
    question with one tool available, then feeds a hard-coded tool result
    back for the final answer. Both completions run inside the tracked
    message's parent context so lunary can attribute them to the thread.

    NOTE(review): indentation was lost in extraction; this reconstruction
    assumes both API calls execute under ``lunary.parent`` — confirm
    against the upstream example.
    """
    from anthropic import Anthropic
    from anthropic.types import ToolParam, MessageParam

    thread = lunary.open_thread()
    monitored_client = monitor(Anthropic())

    # The single user turn that kicks off the tool-use exchange.
    question: MessageParam = {
        "role": "user",
        "content": "What is the weather in San Francisco, California?",
    }
    weather_tools: list[ToolParam] = [
        {
            "name": "get_weather",
            "description": "Get the weather for a specific location",
            "input_schema": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        }
    ]

    tracked_id = thread.track_message(question)

    with lunary.parent(tracked_id):
        first_reply = monitored_client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=1024,
            messages=[question],
            tools=weather_tools,
        )
        print(f"Initial response: {first_reply.model_dump_json(indent=2)}")

        # The demo only makes sense if Claude actually requested the tool.
        assert first_reply.stop_reason == "tool_use"

        # Pull out the tool invocation Claude asked for.
        tool_use = next(
            block for block in first_reply.content if block.type == "tool_use"
        )
        final_reply = monitored_client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=1024,
            messages=[
                question,
                {"role": first_reply.role, "content": first_reply.content},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": tool_use.id,
                            "content": [{"type": "text", "text": "The weather is 73f"}],
                        }
                    ],
                },
            ],
            tools=weather_tools,
        )
        print(f"\nFinal response: {final_reply.model_dump_json(indent=2)}")



# sync_non_streaming()
# asyncio.run(async_non_streaming())

Expand All @@ -243,3 +303,5 @@ async def async_tool_calls():

# tool_calls()
# asyncio.run(async_tool_calls())

reconcilliation_tool_calls()
8 changes: 2 additions & 6 deletions lunary/anthrophic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __params_parser(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
}


def __parse_message_content(message: MessageParam):
def parse_message(message: MessageParam):
role = message.get("role")
content = message.get("content")

Expand Down Expand Up @@ -120,7 +120,7 @@ def __input_parser(kwargs: t.Dict):
inputs.append({ "role": "system", "content": item.get("text") })

for message in kwargs.get("messages", []):
inputs.extend(__parse_message_content(message))
inputs.extend(parse_message(message))

return {"input": inputs, "name": kwargs.get("model")}

Expand Down Expand Up @@ -176,7 +176,6 @@ def __iter__(self):

def __iterator__(self):
for event in self.__stream.__stream__():
print("\n", event)
if event.type == "message_start":
self.__messages.append(
{
Expand Down Expand Up @@ -320,7 +319,6 @@ def __aiter__(self):

async def __iterator__(self):
async for event in self.__stream.__stream__():
print("\n", event)
if event.type == "message_start":
self.__messages.append(
{
Expand Down Expand Up @@ -578,7 +576,6 @@ def __wrap_sync(
is_openai=False,
)
except Exception as e:
raise e
return logging.exception(e)

if contextify_stream or kwargs.get("stream") == True:
Expand Down Expand Up @@ -672,7 +669,6 @@ async def __wrap_async(
is_openai=False,
)
except Exception as e:
raise e
return logging.exception(e)

if contextify_stream or kwargs.get("stream") == True:
Expand Down

0 comments on commit d55bd5d

Please sign in to comment.