From c1c8c5368ad18b53395e6e5a588d26380e928269 Mon Sep 17 00:00:00 2001 From: Thraize Date: Sun, 7 Jul 2024 21:19:46 +0000 Subject: [PATCH 1/8] feat: Support for streaming and not streaming, sync and async anthropic clients Signed-off-by: Thraize --- examples/__init__.py | 0 examples/anthrophic.py | 87 +++++++++++ lunary/__init__.py | 43 ++++++ lunary/anthrophic.py | 322 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 10 ++ 5 files changed, 462 insertions(+) create mode 100644 examples/__init__.py create mode 100644 examples/anthrophic.py create mode 100644 lunary/anthrophic.py create mode 100644 requirements.txt diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/anthrophic.py b/examples/anthrophic.py new file mode 100644 index 0000000..2bfbdf5 --- /dev/null +++ b/examples/anthrophic.py @@ -0,0 +1,87 @@ + +import os +import asyncio +from dotenv import load_dotenv +from anthropic import Anthropic, AsyncAnthropic + +from lunary.anthrophic import monitor + +load_dotenv() + +def sync_non_streaming(): + client = Anthropic() + monitor(client) + + message = client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-opus-20240229", + ) + print(message.content) + + +async def async_non_streaming(): + client = monitor(AsyncAnthropic()) + + message = await client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-opus-20240229", + ) + print(message.content) + + +def sync_streaming(): + client = monitor(Anthropic()) + + stream = client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-opus-20240229", + stream=True, + ) + for event in stream: + print(event) + + + + +async def async_streaming(): + client = monitor(AsyncAnthropic()) + + stream = await client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-opus-20240229", + stream=True, + ) + async for event in stream: + print(event) + + +# sync_non_streaming() +# asyncio.run(async_non_streaming()) + +# sync_streaming() +asyncio.run(async_streaming()) + diff --git a/lunary/__init__.py b/lunary/__init__.py index ef4f859..7329d3a 100644 --- a/lunary/__init__.py +++ b/lunary/__init__.py @@ -1455,6 +1455,49 @@ async def get_raw_template_async(slug: str, app_id: str | None = None, api_url: templateCache[slug] = {'timestamp': now, 'data': data} return data +def get_templates(live_only: Optional[bool] = False, app_id: Optional[str] = None, api_url: Optional[str] = None): + config = get_config() + token = app_id or config.app_id + api_url = api_url or config.api_url + + headers = { + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json' + } + + response = requests.get( + f"{api_url}/v1/templates{'?live=true' if live_only else ''}", + headers=headers, verify=config.ssl_verify + ) + if not response.ok: + logger.exception(f"Error fetching template: {response.status_code} - {response.text}") + + data = response.json() + +async def get_templates_async(live_only: Optional[bool]=False, app_id: Optional[str] = None, api_url: Optional[str] = None): + config = get_config() + token = app_id or config.app_id + api_url = api_url or config.api_url + + headers = { + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json' + } + + async with 
aiohttp.ClientSession() as session: + async with session.get( + f"{api_url}/v1/templates{'?live=true' if live_only else ''}", + headers=headers + ) as response: + if not response.ok: + raise Exception(f"Lunary: Error fetching template: {response.status} - {await response.text()}") + + data = await response.json() + +def get_langchain_templates(app_id: Optional[str] = None, api_url: Optional[str] = None): + config = get_config() + token = app_id or config.app_id + api_url = api_url or config.api_url def render_template(slug: str, data = {}): try: diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py new file mode 100644 index 0000000..1d17f6a --- /dev/null +++ b/lunary/anthrophic.py @@ -0,0 +1,322 @@ +import typing as t +from functools import partial +from . import track_event, run_context, run_manager, logging, logger, user_props_ctx, user_ctx, traceback, tags_ctx, filter_params + +try: + from anthropic import Anthropic, AsyncAnthropic + from anthropic.types import Message +except ImportError: + raise ImportError("Anthrophic SDK not installed!") from None + + +def __input_parser(kwargs: t.Dict): + return {"input": kwargs.get("messages"), "name": kwargs.get("model")} + + +def __output_parser(output: t.Union[Message], stream: bool = False): + if isinstance(output, Message): + return { + "name": + output.model, + "tokensUsage": + output.usage, + "output": [{ + "content": content.text, + "role": output.role + } for content in output.content], + } + else: + return { + "name": None, + "tokenUsage": None, + "output": getattr(output, "content", output) + } + + +def __stream_handler(method, run_id, name, type, *args, **kwargs): + messages = [] + stream = method(*args, **kwargs) + + for event in stream: + if event.type == "message_start": + # print(event.message.model) + messages.append({ + "role": event.message.role, + "model": event.message.model + }) + if event.type == "message_delta": + # print("*", event.usage.output_tokens) + if len(messages) >= 1: + message = messages[-1] + message["usage"] = {"tokens": event.usage.output_tokens} + + if event.type == "message_stop": pass + if event.type == "content_block_start": + # print("* START") + # print(event.content_block.text) + if len(messages) >= 1: + message = messages[-1] + message["output"] = event.content_block.text + + if event.type == "content_block_delta": + # print(event.delta.text, end="") + if len(messages) >= 1: + message = messages[-1] + message["output"] = message.get("output", + "") + event.delta.text + + if event.type == "content_block_stop": + # print("* END") + pass + + yield event + + track_event( + type, + "end", + run_id, + name=name, + output=[{ + "role": message["role"], + "content": message["output"] + } for message in messages], + token_usage=sum([message["usage"]["tokens"] for message in messages]), + ) + + +async def __async_stream_handler(method, run_id, name, type, *args, **kwargs): + messages = [] + stream = await method(*args, **kwargs) + + async for event in stream: + if event.type == "message_start": + # print(event.message.model) + messages.append({ + "role": event.message.role, + "model": event.message.model + }) + if event.type == "message_delta": + # print("*", event.usage.output_tokens) + if len(messages) >= 1: + message = messages[-1] + message["usage"] = {"tokens": event.usage.output_tokens} + + if event.type == "message_stop": pass + if event.type == "content_block_start": + # print("* START") + # print(event.content_block.text) + if len(messages) >= 1: + message = messages[-1] + message["output"] = 
event.content_block.text + + if event.type == "content_block_delta": + # print(event.delta.text, end="") + if len(messages) >= 1: + message = messages[-1] + message["output"] = message.get("output", + "") + event.delta.text + + if event.type == "content_block_stop": + # print("* END") + pass + + yield event + + track_event( + type, + "end", + run_id, + name=name, + output=[{ + "role": message["role"], + "content": message["output"] + } for message in messages], + token_usage=sum([message["usage"]["tokens"] for message in messages]), + ) + + +def __wrap_sync(method: t.Callable, + type: t.Optional[str] = None, + user_id: t.Optional[str] = None, + user_props: t.Optional[dict] = None, + tags: t.Optional[dict] = None, + name: t.Optional[str] = None, + run_id: t.Optional[str] = None, + input_parser=__input_parser, + output_parser=__output_parser, + stream_handler=__stream_handler, + *args, + **kwargs): + output = None + + parent_run_id = kwargs.pop("parent", None) + run = run_manager.start_run(run_id, parent_run_id) + + with run_context(run.id): + try: + try: + params = filter_params(kwargs) + metadata = kwargs.get("metadata") + parsed_input = input_parser(kwargs) + + track_event(type, + "start", + run_id=run.id, + parent_run_id=parent_run_id, + input=parsed_input["input"], + name=name or parsed_input["name"], + user_id=(kwargs.pop("user_id", None) + or user_ctx.get() or user_id), + user_props=(kwargs.pop("user_props", None) + or user_props or user_props_ctx.get()), + params=params, + metadata=metadata, + tags=(kwargs.pop("tags", None) or tags + or tags_ctx.get()), + template_id=(kwargs.get("extra_headers", {}).get( + "Template-Id", None)), + is_openai=False) + except Exception as e: + logging.exception(e) + + if kwargs.get("stream") == True: + return stream_handler(method, run.id, name + or parsed_input["name"], type, *args, + **kwargs) + + try: + output = method(*args, **kwargs) + except Exception as e: + track_event( + type, + "error", + run.id, + error={ + "message": str(e), + "stack": traceback.format_exc() + }, + ) + raise e from None + + try: + parsed_output = output_parser(output, + kwargs.get("stream", False)) + + track_event( + type, + "end", + run.id, + # In case need to compute tokens usage server side + name=name or parsed_input["name"], + output=parsed_output["output"], + token_usage=parsed_output["tokensUsage"], + ) + return output + except Exception as e: + logger.exception(e)(e) + finally: + return output + finally: + run_manager.end_run(run.id) + + +async def __wrap_async(method: t.Callable, + type: t.Optional[str] = None, + user_id: t.Optional[str] = None, + user_props: t.Optional[dict] = None, + tags: t.Optional[dict] = None, + name: t.Optional[str] = None, + run_id: t.Optional[str] = None, + input_parser=__input_parser, + output_parser=__output_parser, + stream_handler=__async_stream_handler, + *args, + **kwargs): + output = None + + parent_run_id = kwargs.pop("parent", None) + run = run_manager.start_run(run_id, parent_run_id) + + with run_context(run.id): + try: + try: + params = filter_params(kwargs) + metadata = kwargs.get("metadata") + parsed_input = input_parser(kwargs) + + track_event(type, + "start", + run_id=run.id, + parent_run_id=parent_run_id, + input=parsed_input["input"], + name=name or parsed_input["name"], + user_id=(kwargs.pop("user_id", None) + or user_ctx.get() or user_id), + user_props=(kwargs.pop("user_props", None) + or user_props or user_props_ctx.get()), + params=params, + metadata=metadata, + tags=(kwargs.pop("tags", None) or tags + or 
tags_ctx.get()), + template_id=(kwargs.get("extra_headers", {}).get( + "Template-Id", None)), + is_openai=True) + except Exception as e: + logging.exception(e) + + if kwargs.get("stream") == True: + return stream_handler(method, run.id, name + or parsed_input["name"], type, + *args, **kwargs) + + try: + output = await method(*args, **kwargs) + except Exception as e: + track_event( + type, + "error", + run.id, + error={ + "message": str(e), + "stack": traceback.format_exc() + }, + ) + raise e from None + + try: + parsed_output = output_parser(output, + kwargs.get("stream", False)) + + track_event( + type, + "end", + run.id, + # In case need to compute tokens usage server side + name=name or parsed_input["name"], + output=parsed_output["output"], + token_usage=parsed_output["tokensUsage"], + ) + return output + except Exception as e: + logger.exception(e)(e) + finally: + return output + finally: + run_manager.end_run(run.id) + + +if t.TYPE_CHECKING: + ClientType = t.TypeVar("ClientType") + + +def monitor(client: "ClientType") -> "ClientType": + if isinstance(client, Anthropic): + client.messages.create = partial(__wrap_sync, client.messages.create, + "llm") + elif isinstance(client, AsyncAnthropic): + client.messages.create = partial(__wrap_async, client.messages.create, + "llm") + else: + raise Exception( + "Invalid argument. Expected instance of Anthropic Client") + return client diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..876341d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +requests +setuptools +tenacity +packaging +chevron +pyhumps +aiohttp +jsonpickle +openai +langchain \ No newline at end of file From d9a81b994e59cdde554e22191077699272335df4 Mon Sep 17 00:00:00 2001 From: Thraize Date: Sun, 7 Jul 2024 21:38:56 +0000 Subject: [PATCH 2/8] fix: Fixed support for using metadatas with anthropic clients Signed-off-by: Thraize --- examples/anthrophic.py | 86 +++++++++++++++++++++++++----------------- lunary/anthrophic.py | 14 +++++++ 2 files changed, 65 insertions(+), 35 deletions(-) diff --git a/examples/anthrophic.py b/examples/anthrophic.py index 2bfbdf5..ad25e59 100644 --- a/examples/anthrophic.py +++ b/examples/anthrophic.py @@ -1,4 +1,3 @@ - import os import asyncio from dotenv import load_dotenv @@ -8,50 +7,45 @@ load_dotenv() -def sync_non_streaming(): + +def test_sync_non_streaming(): client = Anthropic() monitor(client) message = client.messages.create( max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], model="claude-3-opus-20240229", ) print(message.content) -async def async_non_streaming(): +async def test_async_non_streaming(): client = monitor(AsyncAnthropic()) message = await client.messages.create( max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], model="claude-3-opus-20240229", ) print(message.content) -def sync_streaming(): +def test_sync_streaming(): client = monitor(Anthropic()) stream = client.messages.create( max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], model="claude-3-opus-20240229", stream=True, ) @@ -59,19 +53,15 @@ def sync_streaming(): print(event) - - -async def async_streaming(): +async def test_async_streaming(): client = monitor(AsyncAnthropic()) stream = await client.messages.create( 
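         # With stream=True (set below) this call returns an async event
         # iterator; monitor() re-yields each event while recording token
         # usage for tracking.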
max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], model="claude-3-opus-20240229", stream=True, ) @@ -79,9 +69,35 @@ async def async_streaming(): print(event) -# sync_non_streaming() -# asyncio.run(async_non_streaming()) +def test_extra_arguments(): + client = Anthropic() + monitor(client) + + message = client.messages.create( + max_tokens=1024, + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], + model="claude-3-opus-20240229", + tags=["translate"], + user_id="user123", + user_props={ + "name": "John Doe", + }, + metadata={ + "test": "hello", + "isTest": True, + "testAmount": 123, + }, + ) + print(message.content) + + +# test_sync_non_streaming() +# test_asyncio.run(async_non_streaming()) -# sync_streaming() -asyncio.run(async_streaming()) +# test_sync_streaming() +# test_asyncio.run(async_streaming()) +test_extra_arguments() \ No newline at end of file diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py index 1d17f6a..5395df3 100644 --- a/lunary/anthrophic.py +++ b/lunary/anthrophic.py @@ -135,6 +135,12 @@ async def __async_stream_handler(method, run_id, name, type, *args, **kwargs): ) +def __metadata_parser(metadata): + return { + x: metadata[x] for x in metadata if x in ["user_id"] + } + + def __wrap_sync(method: t.Callable, type: t.Optional[str] = None, user_id: t.Optional[str] = None, @@ -145,6 +151,7 @@ def __wrap_sync(method: t.Callable, input_parser=__input_parser, output_parser=__output_parser, stream_handler=__stream_handler, + metadata_parser=__metadata_parser, *args, **kwargs): output = None @@ -159,6 +166,9 @@ def __wrap_sync(method: t.Callable, metadata = kwargs.get("metadata") parsed_input = input_parser(kwargs) + if metadata: + kwargs["metadata"] = metadata_parser(metadata) + track_event(type, "start", run_id=run.id, @@ -230,6 +240,7 @@ async def __wrap_async(method: t.Callable, input_parser=__input_parser, output_parser=__output_parser, stream_handler=__async_stream_handler, + metadata_parser=__metadata_parser, *args, **kwargs): output = None @@ -244,6 +255,9 @@ async def __wrap_async(method: t.Callable, metadata = kwargs.get("metadata") parsed_input = input_parser(kwargs) + if metadata: + kwargs["metadata"] = metadata_parser(metadata) + track_event(type, "start", run_id=run.id, From b056a39a3ade5e7832fe1901077ee7a8d7d47fdc Mon Sep 17 00:00:00 2001 From: Thraize Date: Tue, 9 Jul 2024 12:48:42 +0000 Subject: [PATCH 3/8] feat: Added support for stream helper function Signed-off-by: Thraize --- examples/anthrophic.py | 40 ++++++++++++++- lunary/anthrophic.py | 110 ++++++++++++++++++++++++++++++----------- 2 files changed, 120 insertions(+), 30 deletions(-) diff --git a/examples/anthrophic.py b/examples/anthrophic.py index ad25e59..1f07d5b 100644 --- a/examples/anthrophic.py +++ b/examples/anthrophic.py @@ -69,6 +69,41 @@ async def test_async_streaming(): print(event) +def test_sync_stream_helper(): + client = Anthropic() + monitor(client) + + with client.messages.stream( + max_tokens=1024, + messages=[{ + "role": "user", + "content": "Hello, Claude", + }], + model="claude-3-opus-20240229", + ) as stream: + for event in stream: + print(event) + +async def test_async_stream_helper(): + client = monitor(AsyncAnthropic()) + + async with client.messages.stream( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="claude-3-opus-20240229", + ) as stream: + async for event in stream: + print(event) + + 
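+        # get_final_message() waits for the stream to finish and returns
+        # the fully accumulated Message (standard SDK stream helper).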
message = await stream.get_final_message() + print(message.to_json()) + + def test_extra_arguments(): client = Anthropic() monitor(client) @@ -100,4 +135,7 @@ def test_extra_arguments(): # test_sync_streaming() # test_asyncio.run(async_streaming()) -test_extra_arguments() \ No newline at end of file +# test_extra_arguments() + +# test_sync_stream_helper() +asyncio.run(test_async_stream_helper()) \ No newline at end of file diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py index 5395df3..38dc9fd 100644 --- a/lunary/anthrophic.py +++ b/lunary/anthrophic.py @@ -1,14 +1,42 @@ import typing as t from functools import partial +from inspect import iscoroutine +from contextlib import AsyncContextDecorator, ContextDecorator + from . import track_event, run_context, run_manager, logging, logger, user_props_ctx, user_ctx, traceback, tags_ctx, filter_params try: from anthropic import Anthropic, AsyncAnthropic from anthropic.types import Message + from anthropic.lib.streaming import MessageStreamManager, AsyncMessageStreamManager except ImportError: raise ImportError("Anthrophic SDK not installed!") from None +class sync_context_wrapper(ContextDecorator): + + def __init__(self, stream): + self.__stream = stream + + def __enter__(self): + return self.__stream + + def __exit__(self, *_): + return + + +class async_context_wrapper(AsyncContextDecorator): + + def __init__(self, stream): + self.__stream = stream + + async def __aenter__(self): + return self.__stream + + async def __aexit__(self, *_): + return + + def __input_parser(kwargs: t.Dict): return {"input": kwargs.get("messages"), "name": kwargs.get("model")} @@ -35,42 +63,44 @@ def __output_parser(output: t.Union[Message], stream: bool = False): def __stream_handler(method, run_id, name, type, *args, **kwargs): messages = [] + original_stream = None stream = method(*args, **kwargs) + if isinstance(stream, MessageStreamManager): + original_stream = stream + stream = original_stream.__enter__() + for event in stream: if event.type == "message_start": - # print(event.message.model) messages.append({ "role": event.message.role, "model": event.message.model }) if event.type == "message_delta": - # print("*", event.usage.output_tokens) if len(messages) >= 1: message = messages[-1] message["usage"] = {"tokens": event.usage.output_tokens} if event.type == "message_stop": pass if event.type == "content_block_start": - # print("* START") - # print(event.content_block.text) if len(messages) >= 1: message = messages[-1] message["output"] = event.content_block.text if event.type == "content_block_delta": - # print(event.delta.text, end="") if len(messages) >= 1: message = messages[-1] message["output"] = message.get("output", "") + event.delta.text if event.type == "content_block_stop": - # print("* END") pass yield event + if original_stream: + original_stream.__exit__(None, None, None) + track_event( type, "end", @@ -86,42 +116,47 @@ def __stream_handler(method, run_id, name, type, *args, **kwargs): async def __async_stream_handler(method, run_id, name, type, *args, **kwargs): messages = [] - stream = await method(*args, **kwargs) + original_stream = None + stream = method(*args, **kwargs) + + if iscoroutine(stream): + stream = await stream + + if isinstance(stream, AsyncMessageStreamManager): + original_stream = stream + stream = await original_stream.__aenter__() async for event in stream: if event.type == "message_start": - # print(event.message.model) messages.append({ "role": event.message.role, "model": event.message.model }) if event.type == 
"message_delta": - # print("*", event.usage.output_tokens) if len(messages) >= 1: message = messages[-1] message["usage"] = {"tokens": event.usage.output_tokens} if event.type == "message_stop": pass if event.type == "content_block_start": - # print("* START") - # print(event.content_block.text) if len(messages) >= 1: message = messages[-1] message["output"] = event.content_block.text if event.type == "content_block_delta": - # print(event.delta.text, end="") if len(messages) >= 1: message = messages[-1] message["output"] = message.get("output", "") + event.delta.text if event.type == "content_block_stop": - # print("* END") pass yield event + if original_stream: + await original_stream.__aexit__(None, None, None) + track_event( type, "end", @@ -136,9 +171,7 @@ async def __async_stream_handler(method, run_id, name, type, *args, **kwargs): def __metadata_parser(metadata): - return { - x: metadata[x] for x in metadata if x in ["user_id"] - } + return {x: metadata[x] for x in metadata if x in ["user_id"]} def __wrap_sync(method: t.Callable, @@ -152,6 +185,7 @@ def __wrap_sync(method: t.Callable, output_parser=__output_parser, stream_handler=__stream_handler, metadata_parser=__metadata_parser, + contextify_stream: t.Optional[t.Callable] = None, *args, **kwargs): output = None @@ -189,10 +223,13 @@ def __wrap_sync(method: t.Callable, except Exception as e: logging.exception(e) - if kwargs.get("stream") == True: - return stream_handler(method, run.id, name - or parsed_input["name"], type, *args, - **kwargs) + if contextify_stream or kwargs.get("stream") == True: + generator = stream_handler(method, run.id, name + or parsed_input["name"], type, + *args, **kwargs) + if contextify_stream: + return contextify_stream(generator) + else: return generator try: output = method(*args, **kwargs) @@ -241,6 +278,7 @@ async def __wrap_async(method: t.Callable, output_parser=__output_parser, stream_handler=__async_stream_handler, metadata_parser=__metadata_parser, + contextify_stream: t.Optional[bool] = False, *args, **kwargs): output = None @@ -274,14 +312,17 @@ async def __wrap_async(method: t.Callable, or tags_ctx.get()), template_id=(kwargs.get("extra_headers", {}).get( "Template-Id", None)), - is_openai=True) + is_openai=False) except Exception as e: logging.exception(e) - if kwargs.get("stream") == True: - return stream_handler(method, run.id, name - or parsed_input["name"], type, - *args, **kwargs) + if contextify_stream or kwargs.get("stream") == True: + generator = stream_handler(method, run.id, name + or parsed_input["name"], type, + *args, **kwargs) + if contextify_stream: + return contextify_stream(generator) + else: return generator try: output = await method(*args, **kwargs) @@ -325,11 +366,22 @@ async def __wrap_async(method: t.Callable, def monitor(client: "ClientType") -> "ClientType": if isinstance(client, Anthropic): - client.messages.create = partial(__wrap_sync, client.messages.create, - "llm") + client.messages.create = partial(__wrap_sync, + client.messages.create, + type="llm") + client.messages.stream = partial(__wrap_sync, + client.messages.stream, + type="llm", + contextify_stream=sync_context_wrapper) elif isinstance(client, AsyncAnthropic): - client.messages.create = partial(__wrap_async, client.messages.create, - "llm") + client.messages.create = partial(__wrap_async, + client.messages.create, + type="llm") + client.messages.stream = partial(__wrap_sync, + client.messages.stream, + type="llm", + stream_handler=__async_stream_handler, + contextify_stream=async_context_wrapper) 
else: raise Exception( "Invalid argument. Expected instance of Anthropic Client") From 8dd3f9b8fb97bdfa6824fdf42317c5f36cf4770a Mon Sep 17 00:00:00 2001 From: Thraize Date: Wed, 31 Jul 2024 12:23:19 -0700 Subject: [PATCH 4/8] feat: Added tool_use support Signed-off-by: Thraize --- examples/anthrophic.py | 144 ++++++- lunary/__init__.py | 1 - lunary/anthrophic.py | 846 +++++++++++++++++++++++++++++------------ 3 files changed, 725 insertions(+), 266 deletions(-) diff --git a/examples/anthrophic.py b/examples/anthrophic.py index 1f07d5b..f40a65d 100644 --- a/examples/anthrophic.py +++ b/examples/anthrophic.py @@ -1,14 +1,10 @@ import os import asyncio -from dotenv import load_dotenv from anthropic import Anthropic, AsyncAnthropic from lunary.anthrophic import monitor -load_dotenv() - - -def test_sync_non_streaming(): +def sync_non_streaming(): client = Anthropic() monitor(client) @@ -23,7 +19,7 @@ def test_sync_non_streaming(): print(message.content) -async def test_async_non_streaming(): +async def async_non_streaming(): client = monitor(AsyncAnthropic()) message = await client.messages.create( @@ -37,7 +33,7 @@ async def test_async_non_streaming(): print(message.content) -def test_sync_streaming(): +def sync_streaming(): client = monitor(Anthropic()) stream = client.messages.create( @@ -53,7 +49,7 @@ def test_sync_streaming(): print(event) -async def test_async_streaming(): +async def async_streaming(): client = monitor(AsyncAnthropic()) stream = await client.messages.create( @@ -69,7 +65,7 @@ async def test_async_streaming(): print(event) -def test_sync_stream_helper(): +def sync_stream_helper(): client = Anthropic() monitor(client) @@ -84,7 +80,7 @@ def test_sync_stream_helper(): for event in stream: print(event) -async def test_async_stream_helper(): +async def async_stream_helper(): client = monitor(AsyncAnthropic()) async with client.messages.stream( @@ -97,14 +93,15 @@ async def test_async_stream_helper(): ], model="claude-3-opus-20240229", ) as stream: - async for event in stream: - print(event) + async for text in stream.text_stream: + print(text, end="", flush=True) + print() message = await stream.get_final_message() print(message.to_json()) -def test_extra_arguments(): +def extra_arguments(): client = Anthropic() monitor(client) @@ -129,13 +126,120 @@ def test_extra_arguments(): print(message.content) -# test_sync_non_streaming() -# test_asyncio.run(async_non_streaming()) +def anthrophic_bedrock(): + from anthropic import AnthropicBedrock + + client = AnthropicBedrock() + + message = client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello!", + } + ], + model="anthropic.claude-3-sonnet-20240229-v1:0", + ) + print(message) + +def tool_calls(): + from anthropic import Anthropic + from anthropic.types import ToolParam, MessageParam + + client = monitor(Anthropic()) + + user_message: MessageParam = { + "role": "user", + "content": "What is the weather in San Francisco, California?", + } + tools: list[ToolParam] = [ + { + "name": "get_weather", + "description": "Get the weather for a specific location", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + } + ] + + message = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + messages=[user_message], + tools=tools, + ) + print(f"Initial response: {message.model_dump_json(indent=2)}") + + assert message.stop_reason == "tool_use" + + tool = next(c for c in message.content if c.type == "tool_use") + response = 
client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + messages=[ + user_message, + {"role": message.role, "content": message.content}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool.id, + "content": [{"type": "text", "text": "The weather is 73f"}], + } + ], + }, + ], + tools=tools, + ) + print(f"\nFinal response: {response.model_dump_json(indent=2)}") + + +async def async_tool_calls(): + client = monitor(AsyncAnthropic()) + async with client.messages.stream( + max_tokens=1024, + model="claude-3-haiku-20240307", + tools=[ + { + "name": "get_weather", + "description": "Get the weather at a specific location", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Unit for the output", + }, + }, + "required": ["location"], + }, + } + ], + messages=[{"role": "user", "content": "What is the weather in SF?"}], + ) as stream: + async for event in stream: + if event.type == "input_json": + print(f"delta: {repr(event.partial_json)}") + print(f"snapshot: {event.snapshot}") + + +# sync_non_streaming() +# asyncio.run(async_non_streaming()) + +# sync_streaming() +# asyncio.run(async_streaming()) + +# extra_arguments() -# test_sync_streaming() -# test_asyncio.run(async_streaming()) +# sync_stream_helper() +# asyncio.run(async_stream_helper()) -# test_extra_arguments() +# # anthrophic_bedrock() -# test_sync_stream_helper() -asyncio.run(test_async_stream_helper()) \ No newline at end of file +# tool_calls() +# asyncio.run(async_tool_calls()) diff --git a/lunary/__init__.py b/lunary/__init__.py index 7329d3a..99f70b6 100644 --- a/lunary/__init__.py +++ b/lunary/__init__.py @@ -143,7 +143,6 @@ def track_event( "appId": app_id } - if callback_queue is not None: callback_queue.append(event) else: diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py index 38dc9fd..c46a684 100644 --- a/lunary/anthrophic.py +++ b/lunary/anthrophic.py @@ -3,191 +3,544 @@ from inspect import iscoroutine from contextlib import AsyncContextDecorator, ContextDecorator -from . import track_event, run_context, run_manager, logging, logger, user_props_ctx, user_ctx, traceback, tags_ctx, filter_params +from . 
import ( + track_event, + run_context, + run_manager, + logging, + logger, + user_props_ctx, + user_ctx, + traceback, + tags_ctx, +) try: from anthropic import Anthropic, AsyncAnthropic - from anthropic.types import Message - from anthropic.lib.streaming import MessageStreamManager, AsyncMessageStreamManager + from anthropic.types import Message, ToolParam, MessageParam + from anthropic.lib.streaming import ( + MessageStreamManager, + AsyncMessageStreamManager, + MessageStream, + AsyncMessageStream, + ) except ImportError: raise ImportError("Anthrophic SDK not installed!") from None +PARAMS_TO_CAPTURE = [ + "frequency_penalty", + "function_call", + "functions", + "logit_bias", + "logprobs", + "max_tokens", + "n", + "presence_penalty", + "response_format", + "seed", + "stop", + "temperature", + "tool_choice", + "tools", + "top_logprobs", + "top_p", + # Additional params + "extra_headers", + "extra_query", + "extra_body", + "timeout", +] + + +def __parse_tools(tools: list[ToolParam]): + return [ + { + "type": "function", + "function": { + "name": tool.get("name"), + "description": tool.get("description"), + "inputSchema": tool.get("input_schema"), + }, + } + for tool in tools + ] + + +def __params_parser(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + return { + key: __parse_tools(value) if key == "tools" else value + for key, value in params.items() + if key in PARAMS_TO_CAPTURE + } + + +def __parse_message_content(message: MessageParam): + role = message.get("role") + content = message.get("content") + + # print({"role": role, "content": content}) + + if isinstance(content, str): + yield {"role": role, "content": content} + elif isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + yield {"content": item.get("text"), "role": role} + elif item.get("type") == "tool_use": + yield { + "functionCall": { + "name": item.get("name"), + "arguments": item.get("input"), + }, + "toolCallId": item.get("id"), + } + elif item.get("type") == "tool_result": + yield { + "role": "tool", + "tool_call_id": item.get("tool_use_id"), + "content": item.get("content"), + } -class sync_context_wrapper(ContextDecorator): - - def __init__(self, stream): - self.__stream = stream - - def __enter__(self): - return self.__stream - - def __exit__(self, *_): - return + else: + error = f"Invalid 'content' type for message: {message}" + raise ValueError(error) -class async_context_wrapper(AsyncContextDecorator): +def __input_parser(kwargs: t.Dict): + inputs = [] - def __init__(self, stream): - self.__stream = stream + if kwargs.get("system"): + system = kwargs.get("system") + if isinstance(system, str): + inputs.append({ "role": "system", "content": kwargs["system"] }) + elif isinstance(system, list): + for item in kwargs["system"]: + if item.get("type") == "text": + inputs.append({ "role": "system", "content": item.get("text") }) - async def __aenter__(self): - return self.__stream + for message in kwargs.get("messages", []): + inputs.extend(__parse_message_content(message)) - async def __aexit__(self, *_): - return - - -def __input_parser(kwargs: t.Dict): - return {"input": kwargs.get("messages"), "name": kwargs.get("model")} + return {"input": inputs, "name": kwargs.get("model")} def __output_parser(output: t.Union[Message], stream: bool = False): if isinstance(output, Message): return { - "name": - output.model, - "tokensUsage": - output.usage, - "output": [{ - "content": content.text, - "role": output.role - } for content in output.content], + "name": 
output.model, + "tokensUsage": { + "completion": output.usage.output_tokens, + "prompt": output.usage.input_tokens, + }, + "output": [ + ( + {"content": content.text, "role": output.role} + if content.type == "text" + else { + "functionCall": { + "name": content.name, + "arguments": content.input, + }, + "toolCallId": content.id, + } + ) + for content in output.content + ], } else: return { "name": None, "tokenUsage": None, - "output": getattr(output, "content", output) + "output": getattr(output, "content", output), } -def __stream_handler(method, run_id, name, type, *args, **kwargs): - messages = [] - original_stream = None - stream = method(*args, **kwargs) - - if isinstance(stream, MessageStreamManager): - original_stream = stream - stream = original_stream.__enter__() - - for event in stream: - if event.type == "message_start": - messages.append({ - "role": event.message.role, - "model": event.message.model - }) - if event.type == "message_delta": - if len(messages) >= 1: - message = messages[-1] - message["usage"] = {"tokens": event.usage.output_tokens} - - if event.type == "message_stop": pass - if event.type == "content_block_start": - if len(messages) >= 1: - message = messages[-1] - message["output"] = event.content_block.text - - if event.type == "content_block_delta": - if len(messages) >= 1: - message = messages[-1] - message["output"] = message.get("output", - "") + event.delta.text - - if event.type == "content_block_stop": - pass - - yield event - - if original_stream: - original_stream.__exit__(None, None, None) - - track_event( - type, - "end", - run_id, - name=name, - output=[{ - "role": message["role"], - "content": message["output"] - } for message in messages], - token_usage=sum([message["usage"]["tokens"] for message in messages]), - ) +class Stream: + def __init__(self, stream: MessageStream, handler: "StreamHandler"): + self.__stream = stream + self.__handler = handler + + self.__messages = [] + + # Method wrappers + self._iterator = self.__iterator__() + self.text_stream = self.__stream_text__() + + def __getattr__(self, name): + return getattr(self.__stream, name) + + def __iter__(self): + return self.__iterator__() + + def __iterator__(self): + for event in self.__stream.__stream__(): + print("\n", event) + if event.type == "message_start": + self.__messages.append( + { + "role": event.message.role, + "model": event.message.model, + "usage": { + "input": event.message.usage.input_tokens, + "output": event.message.usage.output_tokens, + }, + "content": [], + } + ) + if event.type == "message_delta": + if len(self.__messages) >= 1: + message = self.__messages[-1] + message["usage"]["tokens"] = event.usage.output_tokens + + if event.type == "message_stop": + # print("\n\n ** ", list(__parse_message_content(event.message))) + pass + + if event.type == "content_block_start": + if len(self.__messages) >= 1: + message = self.__messages[-1] + + if event.content_block.type == "text": + message["content"].insert( + event.index, + { + "type": event.content_block.type, + "content": event.content_block.text, + }, + ) + else: + message["content"].insert( + event.index, + { + "functionCall": { + "name": event.content_block.name, + "arguments": event.content_block.input, + }, + "toolCallId": event.content_block.id, + }, + ) + + if event.type == "content_block_delta": + if len(self.__messages) >= 1: + message = self.__messages[-1] + event_content = message["content"][event.index] + + if event.delta.type == "text_delta": + event_content["content"] += event.delta.text + # else: + 
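+                            # (disabled: would rebuild streamed tool-call
+                            #  arguments from event.delta.partial_json; the
+                            #  final input arrives with content_block_stop)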
# functionCall = event_content.get("functionCall") + # event_content.update( + # { + # "functionCall": { + # "name": functionCall["name"], + # "arguments": ( + # functionCall["arguments"] + # + event.delta.partial_json + # ), + # } + # } + # ) + + if event.type == "content_block_stop": + if hasattr(event, "content_block") and len(self.__messages) >= 1: + message = self.__messages[-1] + event_content: dict = message["content"][event.index] + + if event.content_block.type == "text": + event_content["content"] = event.content_block.text + elif event.content_block.type == "tool_use": + event_content.update( + { + "functionCall": { + "name": event.content_block.name, + "arguments": event.content_block.input, + }, + "toolCallId": event.content_block.id, + } + ) + else: + raise Exception("Invalid `content_block` type") + + yield event + + output = [] + + for message in self.__messages: + for item in message.get("content"): + content = item.get("content") + if isinstance(content, str): + output.append({"role": message["role"], "content": content}) + elif isinstance(content, list): + for sub_item in content: + output.append(sub_item) + else: + output.append(item) + + track_event( + output=output, + event_name="end", + run_type=self.__handler.__type__, + run_id=self.__handler.__run_id__, + name=self.__handler.__name__, + token_usage={ + "completion": sum( + [message["usage"]["tokens"] for message in self.__messages] + ), + "prompt": sum( + [message["usage"]["input"] for message in self.__messages] + ), + }, + ) + + def __stream_text__(self) -> t.Iterator[str]: + for chunk in self.__iterator__(): + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + + +class AsyncStream: + + def __init__(self, stream: AsyncMessageStream, handler: "AsyncStreamHandler"): + self.__stream = stream + self.__handler = handler + + self.__messages = [] + + # Method wrappers + self._iterator = self.__iterator__() + self.text_stream = self.__stream_text__() + + def __getattr__(self, name): + return getattr(self.__stream, name) + + def __aiter__(self): + return self.__iterator__() + + async def __iterator__(self): + async for event in self.__stream.__stream__(): + print("\n", event) + if event.type == "message_start": + self.__messages.append( + { + "role": event.message.role, + "model": event.message.model, + "usage": { + "input": event.message.usage.input_tokens, + "output": event.message.usage.output_tokens, + }, + "content": [], + } + ) + if event.type == "message_delta": + if len(self.__messages) >= 1: + message = self.__messages[-1] + message["usage"]["tokens"] = event.usage.output_tokens + + if event.type == "message_stop": + # print("\n\n ** ", list(__parse_message_content(event.message))) + pass + + if event.type == "content_block_start": + if len(self.__messages) >= 1: + message = self.__messages[-1] + + if event.content_block.type == "text": + message["content"].insert( + event.index, + { + "type": event.content_block.type, + "content": event.content_block.text, + }, + ) + else: + message["content"].insert( + event.index, + { + "functionCall": { + "name": event.content_block.name, + "arguments": event.content_block.input, + }, + "toolCallId": event.content_block.id, + }, + ) + + if event.type == "content_block_delta": + if len(self.__messages) >= 1: + message = self.__messages[-1] + event_content = message["content"][event.index] + + if event.delta.type == "text_delta": + event_content["content"] += event.delta.text + # else: + # functionCall = 
event_content.get("functionCall") + # event_content.update( + # { + # "functionCall": { + # "name": functionCall["name"], + # "arguments": ( + # functionCall["arguments"] + # + event.delta.partial_json + # ), + # } + # } + # ) + + if event.type == "content_block_stop": + if hasattr(event, "content_block") and len(self.__messages) >= 1: + message = self.__messages[-1] + event_content: dict = message["content"][event.index] + + if event.content_block.type == "text": + event_content["content"] = event.content_block.text + elif event.content_block.type == "tool_use": + event_content.update( + { + "functionCall": { + "name": event.content_block.name, + "arguments": event.content_block.input, + }, + "toolCallId": event.content_block.id, + } + ) + else: + raise Exception("Invalid `content_block` type") + + yield event + + output = [] + + for message in self.__messages: + for item in message.get("content"): + content = item.get("content") + if isinstance(content, str): + output.append({"role": message["role"], "content": content}) + elif isinstance(content, list): + for sub_item in content: + output.append(sub_item) + else: + output.append(item) + + track_event( + output=output, + event_name="end", + run_type=self.__handler.__type__, + run_id=self.__handler.__run_id__, + name=self.__handler.__name__, + token_usage={ + "completion": sum( + [message["usage"]["tokens"] for message in self.__messages] + ), + "prompt": sum( + [message["usage"]["input"] for message in self.__messages] + ), + }, + ) + + async def __stream_text__(self) -> t.AsyncIterator[str]: + async for chunk in self.__iterator__(): + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + + +class StreamHandler: + + __stream_manager: MessageStreamManager + + def __init__( + self, method: t.Callable, run_id: str, name: str, type: str, *args, **kwargs + ): + self.__method = method + self.__args = args + self.__kwargs = kwargs + + self.__run_id__ = run_id + self.__name__ = name + self.__type__ = type + + self.__stream_manager = self.__method(*self.__args, **self.__kwargs) -async def __async_stream_handler(method, run_id, name, type, *args, **kwargs): - messages = [] - original_stream = None - stream = method(*args, **kwargs) - - if iscoroutine(stream): - stream = await stream - - if isinstance(stream, AsyncMessageStreamManager): - original_stream = stream - stream = await original_stream.__aenter__() - - async for event in stream: - if event.type == "message_start": - messages.append({ - "role": event.message.role, - "model": event.message.model - }) - if event.type == "message_delta": - if len(messages) >= 1: - message = messages[-1] - message["usage"] = {"tokens": event.usage.output_tokens} - - if event.type == "message_stop": pass - if event.type == "content_block_start": - if len(messages) >= 1: - message = messages[-1] - message["output"] = event.content_block.text - - if event.type == "content_block_delta": - if len(messages) >= 1: - message = messages[-1] - message["output"] = message.get("output", - "") + event.delta.text - - if event.type == "content_block_stop": - pass - - yield event - - if original_stream: - await original_stream.__aexit__(None, None, None) - - track_event( - type, - "end", - run_id, - name=name, - output=[{ - "role": message["role"], - "content": message["output"] - } for message in messages], - token_usage=sum([message["usage"]["tokens"] for message in messages]), - ) + def __enter__(self): + stream = self.__stream_manager.__enter__() + return Stream(stream, 
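+                      # Stream mirrors each event into lunary tracking
+                      # while re-yielding it to the caller,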
self) + + def __exit__(self, *_): + self.__stream_manager.__exit__(None, None, None) + + def __getattr__(self, name): + return getattr(self.__stream_manager, name) + + def __iter__(self): + stream = Stream(self.__stream_manager, self) + return stream.__iterator__() + + +class AsyncStreamHandler: + + __stream_manager: AsyncMessageStreamManager + + def __init__( + self, method: t.Callable, run_id: str, name: str, type: str, *args, **kwargs + ): + self.__method = method + self.__args = args + self.__kwargs = kwargs + + self.__run_id__ = run_id + self.__name__ = name + self.__type__ = type + + self.__stream_manager = self.__method(*self.__args, **self.__kwargs) + + def __await__(self): + + async def _(): + if iscoroutine(self.__stream_manager): + self.__stream_manager = await self.__stream_manager + return self + + return _().__await__() + + async def __aenter__(self): + if iscoroutine(self.__stream_manager): + self.__stream_manager = await self.__stream_manager + + stream = await self.__stream_manager.__aenter__() + return AsyncStream(stream, self) + + async def __aexit__(self, *_): + await self.__stream_manager.__aexit__(None, None, None) + + def __getattr__(self, name): + return getattr(self.__stream_manager, name) + + def __aiter__(self): + stream = AsyncStream(self.__stream_manager, self) + return stream.__iterator__() def __metadata_parser(metadata): return {x: metadata[x] for x in metadata if x in ["user_id"]} -def __wrap_sync(method: t.Callable, - type: t.Optional[str] = None, - user_id: t.Optional[str] = None, - user_props: t.Optional[dict] = None, - tags: t.Optional[dict] = None, - name: t.Optional[str] = None, - run_id: t.Optional[str] = None, - input_parser=__input_parser, - output_parser=__output_parser, - stream_handler=__stream_handler, - metadata_parser=__metadata_parser, - contextify_stream: t.Optional[t.Callable] = None, - *args, - **kwargs): +def __wrap_sync( + method: t.Callable, + type: t.Optional[str] = None, + user_id: t.Optional[str] = None, + user_props: t.Optional[dict] = None, + tags: t.Optional[dict] = None, + name: t.Optional[str] = None, + run_id: t.Optional[str] = None, + input_parser=__input_parser, + output_parser=__output_parser, + params_parser=__params_parser, + stream_handler=StreamHandler, + metadata_parser=__metadata_parser, + contextify_stream: t.Optional[t.Callable] = None, + *args, + **kwargs, +): output = None parent_run_id = kwargs.pop("parent", None) @@ -196,40 +549,42 @@ def __wrap_sync(method: t.Callable, with run_context(run.id): try: try: - params = filter_params(kwargs) + params = params_parser(kwargs) metadata = kwargs.get("metadata") parsed_input = input_parser(kwargs) if metadata: kwargs["metadata"] = metadata_parser(metadata) - track_event(type, - "start", - run_id=run.id, - parent_run_id=parent_run_id, - input=parsed_input["input"], - name=name or parsed_input["name"], - user_id=(kwargs.pop("user_id", None) - or user_ctx.get() or user_id), - user_props=(kwargs.pop("user_props", None) - or user_props or user_props_ctx.get()), - params=params, - metadata=metadata, - tags=(kwargs.pop("tags", None) or tags - or tags_ctx.get()), - template_id=(kwargs.get("extra_headers", {}).get( - "Template-Id", None)), - is_openai=False) + track_event( + type, + "start", + run_id=run.id, + parent_run_id=parent_run_id, + input=parsed_input["input"], + name=name or parsed_input["name"], + user_id=(kwargs.pop("user_id", None) or user_ctx.get() or user_id), + user_props=( + kwargs.pop("user_props", None) + or user_props + or user_props_ctx.get() + ), + 
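+                    # params holds only the whitelisted generation
+                    # settings listed in PARAMS_TO_CAPTURE,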
params=params, + metadata=metadata, + tags=(kwargs.pop("tags", None) or tags or tags_ctx.get()), + template_id=( + kwargs.get("extra_headers", {}).get("Template-Id", None) + ), + is_openai=False, + ) except Exception as e: - logging.exception(e) + raise e + return logging.exception(e) if contextify_stream or kwargs.get("stream") == True: - generator = stream_handler(method, run.id, name - or parsed_input["name"], type, - *args, **kwargs) - if contextify_stream: - return contextify_stream(generator) - else: return generator + return stream_handler( + method, run.id, name or parsed_input["name"], type, *args, **kwargs + ) try: output = method(*args, **kwargs) @@ -238,16 +593,12 @@ def __wrap_sync(method: t.Callable, type, "error", run.id, - error={ - "message": str(e), - "stack": traceback.format_exc() - }, + error={"message": str(e), "stack": traceback.format_exc()}, ) raise e from None try: - parsed_output = output_parser(output, - kwargs.get("stream", False)) + parsed_output = output_parser(output, kwargs.get("stream", False)) track_event( type, @@ -267,20 +618,23 @@ def __wrap_sync(method: t.Callable, run_manager.end_run(run.id) -async def __wrap_async(method: t.Callable, - type: t.Optional[str] = None, - user_id: t.Optional[str] = None, - user_props: t.Optional[dict] = None, - tags: t.Optional[dict] = None, - name: t.Optional[str] = None, - run_id: t.Optional[str] = None, - input_parser=__input_parser, - output_parser=__output_parser, - stream_handler=__async_stream_handler, - metadata_parser=__metadata_parser, - contextify_stream: t.Optional[bool] = False, - *args, - **kwargs): +async def __wrap_async( + method: t.Callable, + type: t.Optional[str] = None, + user_id: t.Optional[str] = None, + user_props: t.Optional[dict] = None, + tags: t.Optional[dict] = None, + name: t.Optional[str] = None, + run_id: t.Optional[str] = None, + input_parser=__input_parser, + output_parser=__output_parser, + params_parser=__params_parser, + stream_handler=AsyncStreamHandler, + metadata_parser=__metadata_parser, + contextify_stream: t.Optional[bool] = False, + *args, + **kwargs, +): output = None parent_run_id = kwargs.pop("parent", None) @@ -289,40 +643,42 @@ async def __wrap_async(method: t.Callable, with run_context(run.id): try: try: - params = filter_params(kwargs) + params = params_parser(kwargs) metadata = kwargs.get("metadata") parsed_input = input_parser(kwargs) if metadata: kwargs["metadata"] = metadata_parser(metadata) - track_event(type, - "start", - run_id=run.id, - parent_run_id=parent_run_id, - input=parsed_input["input"], - name=name or parsed_input["name"], - user_id=(kwargs.pop("user_id", None) - or user_ctx.get() or user_id), - user_props=(kwargs.pop("user_props", None) - or user_props or user_props_ctx.get()), - params=params, - metadata=metadata, - tags=(kwargs.pop("tags", None) or tags - or tags_ctx.get()), - template_id=(kwargs.get("extra_headers", {}).get( - "Template-Id", None)), - is_openai=False) + track_event( + type, + "start", + run_id=run.id, + parent_run_id=parent_run_id, + input=parsed_input["input"], + name=name or parsed_input["name"], + user_id=(kwargs.pop("user_id", None) or user_ctx.get() or user_id), + user_props=( + kwargs.pop("user_props", None) + or user_props + or user_props_ctx.get() + ), + params=params, + metadata=metadata, + tags=(kwargs.pop("tags", None) or tags or tags_ctx.get()), + template_id=( + kwargs.get("extra_headers", {}).get("Template-Id", None) + ), + is_openai=False, + ) except Exception as e: - logging.exception(e) + raise e + return 
logging.exception(e) if contextify_stream or kwargs.get("stream") == True: - generator = stream_handler(method, run.id, name - or parsed_input["name"], type, - *args, **kwargs) - if contextify_stream: - return contextify_stream(generator) - else: return generator + return await stream_handler( + method, run.id, name or parsed_input["name"], type, *args, **kwargs + ) try: output = await method(*args, **kwargs) @@ -331,16 +687,12 @@ async def __wrap_async(method: t.Callable, type, "error", run.id, - error={ - "message": str(e), - "stack": traceback.format_exc() - }, + error={"message": str(e), "stack": traceback.format_exc()}, ) raise e from None try: - parsed_output = output_parser(output, - kwargs.get("stream", False)) + parsed_output = output_parser(output, kwargs.get("stream", False)) track_event( type, @@ -366,23 +718,27 @@ async def __wrap_async(method: t.Callable, def monitor(client: "ClientType") -> "ClientType": if isinstance(client, Anthropic): - client.messages.create = partial(__wrap_sync, - client.messages.create, - type="llm") - client.messages.stream = partial(__wrap_sync, - client.messages.stream, - type="llm", - contextify_stream=sync_context_wrapper) + client.messages.create = partial( + __wrap_sync, client.messages.create, type="llm" + ) + client.messages.stream = partial( + __wrap_sync, + client.messages.stream, + type="llm", + stream_handler=StreamHandler, + contextify_stream=True, + ) elif isinstance(client, AsyncAnthropic): - client.messages.create = partial(__wrap_async, - client.messages.create, - type="llm") - client.messages.stream = partial(__wrap_sync, - client.messages.stream, - type="llm", - stream_handler=__async_stream_handler, - contextify_stream=async_context_wrapper) + client.messages.create = partial( + __wrap_async, client.messages.create, type="llm" + ) + client.messages.stream = partial( + __wrap_sync, + client.messages.stream, + type="llm", + stream_handler=AsyncStreamHandler, + contextify_stream=True, + ) else: - raise Exception( - "Invalid argument. Expected instance of Anthropic Client") + raise Exception("Invalid argument. 
Expected instance of Anthropic Client") return client From 90a713e7195b662a66917cd6f5ae6416c5310378 Mon Sep 17 00:00:00 2001 From: Thraize Date: Wed, 31 Jul 2024 13:07:51 -0700 Subject: [PATCH 5/8] refactor: code cleanup Signed-off-by: Thraize --- examples/anthrophic.py | 64 +++++++++++++++++++++++++++++++++++++++++- lunary/anthrophic.py | 8 ++---- 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/examples/anthrophic.py b/examples/anthrophic.py index f40a65d..fa46e3c 100644 --- a/examples/anthrophic.py +++ b/examples/anthrophic.py @@ -2,7 +2,8 @@ import asyncio from anthropic import Anthropic, AsyncAnthropic -from lunary.anthrophic import monitor +import lunary +from lunary.anthrophic import monitor, parse_message def sync_non_streaming(): client = Anthropic() @@ -228,6 +229,65 @@ async def async_tool_calls(): print(f"snapshot: {event.snapshot}") +def reconcilliation_tool_calls(): + from anthropic import Anthropic + from anthropic.types import ToolParam, MessageParam + + thread = lunary.open_thread() + client = monitor(Anthropic()) + + user_message: MessageParam = { + "role": "user", + "content": "What is the weather in San Francisco, California?", + } + tools: list[ToolParam] = [ + { + "name": "get_weather", + "description": "Get the weather for a specific location", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + } + ] + + message_id = thread.track_message(user_message) + + with lunary.parent(message_id): + message = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + messages=[user_message], + tools=tools, + ) + print(f"Initial response: {message.model_dump_json(indent=2)}") + + assert message.stop_reason == "tool_use" + + tool = next(c for c in message.content if c.type == "tool_use") + response = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + messages=[ + user_message, + {"role": message.role, "content": message.content}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool.id, + "content": [{"type": "text", "text": "The weather is 73f"}], + } + ], + }, + ], + tools=tools, + ) + print(f"\nFinal response: {response.model_dump_json(indent=2)}") + + + # sync_non_streaming() # asyncio.run(async_non_streaming()) @@ -243,3 +303,5 @@ async def async_tool_calls(): # tool_calls() # asyncio.run(async_tool_calls()) + +reconcilliation_tool_calls() \ No newline at end of file diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py index c46a684..b8a4603 100644 --- a/lunary/anthrophic.py +++ b/lunary/anthrophic.py @@ -74,7 +74,7 @@ def __params_parser(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: } -def __parse_message_content(message: MessageParam): +def parse_message(message: MessageParam): role = message.get("role") content = message.get("content") @@ -120,7 +120,7 @@ def __input_parser(kwargs: t.Dict): inputs.append({ "role": "system", "content": item.get("text") }) for message in kwargs.get("messages", []): - inputs.extend(__parse_message_content(message)) + inputs.extend(parse_message(message)) return {"input": inputs, "name": kwargs.get("model")} @@ -176,7 +176,6 @@ def __iter__(self): def __iterator__(self): for event in self.__stream.__stream__(): - print("\n", event) if event.type == "message_start": self.__messages.append( { @@ -320,7 +319,6 @@ def __aiter__(self): async def __iterator__(self): async for event in self.__stream.__stream__(): - print("\n", event) if event.type == "message_start": self.__messages.append( { @@ 
From 06defc2a58bbbc0a416f009bd40c7156092a78f3 Mon Sep 17 00:00:00 2001
From: Thraize
Date: Fri, 2 Aug 2024 22:29:28 -0700
Subject: [PATCH 6/8] feat: code cleanup and tested agent, reconciliation and
 threads

Signed-off-by: Thraize
---
 examples/anthrophic.py |  98 +++++++++++++++++-------
 lunary/anthrophic.py   | 166 +++++++++++++++++++++++++++--------------
 2 files changed, 180 insertions(+), 84 deletions(-)

diff --git a/examples/anthrophic.py b/examples/anthrophic.py
index fa46e3c..0c655de 100644
--- a/examples/anthrophic.py
+++ b/examples/anthrophic.py
@@ -3,7 +3,8 @@
 from anthropic import Anthropic, AsyncAnthropic
 
 import lunary
-from lunary.anthrophic import monitor, parse_message
+from lunary.anthrophic import monitor, parse_message, agent
+
 
 def sync_non_streaming():
     client = Anthropic()
@@ -11,10 +12,12 @@ def sync_non_streaming():
 
     message = client.messages.create(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
     )
     print(message.content)
@@ -25,10 +28,12 @@ async def async_non_streaming():
 
     message = await client.messages.create(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
     )
     print(message.content)
@@ -39,10 +44,12 @@ def sync_streaming():
 
     stream = client.messages.create(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
         stream=True,
     )
@@ -55,10 +62,12 @@ async def async_streaming():
 
     stream = await client.messages.create(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
         stream=True,
     )
@@ -72,15 +81,18 @@ def sync_stream_helper():
 
     with client.messages.stream(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
     ) as stream:
         for event in stream:
             print(event)
 
+
 async def async_stream_helper():
     client = monitor(AsyncAnthropic())
@@ -108,10 +120,12 @@ def extra_arguments():
 
     message = client.messages.create(
         max_tokens=1024,
-        messages=[{
-            "role": "user",
-            "content": "Hello, Claude",
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": "Hello, Claude",
+            }
+        ],
         model="claude-3-opus-20240229",
         tags=["translate"],
         user_id="user123",
@@ -144,6 +158,7 @@ def anthrophic_bedrock():
     )
     print(message)
 
+
 def tool_calls():
     from anthropic import Anthropic
     from anthropic.types import ToolParam, MessageParam
@@ -210,7 +225,10 @@ async def async_tool_calls():
                 "input_schema": {
                     "type": "object",
                     "properties": {
-                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
                         "unit": {
                             "type": "string",
                             "enum": ["celsius", "fahrenheit"],
@@ -229,6 +247,7 @@ async def async_tool_calls():
             print(f"snapshot: {event.snapshot}")
 
 
+@agent("DemoAgent")
 def reconcilliation_tool_calls():
     from anthropic import Anthropic
     from anthropic.types import ToolParam, MessageParam
@@ -259,15 +278,36 @@ def reconcilliation_tool_calls():
             max_tokens=1024,
             messages=[user_message],
             tools=tools,
+            parent=message_id,
         )
         print(f"Initial response: {message.model_dump_json(indent=2)}")
 
         assert message.stop_reason == "tool_use"
 
         tool = next(c for c in message.content if c.type == "tool_use")
+
+        for item in (
+            [
+                user_message,
+                {"role": message.role, "content": message.content},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "tool_result",
+                            "tool_use_id": tool.id,
+                            "content": [{"type": "text", "text": "The weather is 73f"}],
+                        }
+                    ],
+                },
+            ]
+        ):
+            thread.track_message(item)
+
         response = client.messages.create(
             model="claude-3-opus-20240229",
             max_tokens=1024,
+            parent=message_id,
             messages=[
                 user_message,
                 {"role": message.role, "content": message.content},
@@ -286,6 +326,10 @@ def reconcilliation_tool_calls():
         )
         print(f"\nFinal response: {response.model_dump_json(indent=2)}")
 
+        for item in parse_message(response):
+            thread.track_message(item)
+
+    return response
 
 
 # sync_non_streaming()
@@ -304,4 +348,4 @@ def reconcilliation_tool_calls():
 
 # tool_calls()
 # asyncio.run(async_tool_calls())
-reconcilliation_tool_calls()
\ No newline at end of file
+reconcilliation_tool_calls()
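
The reworked example above wires three pieces together: `thread.track_message` returns a message id, `parent=message_id` (or the `lunary.parent` context manager) attaches the Anthropic calls to that message, and `@agent(...)` records the whole function as an agent run. Stripped to a skeleton — `answer` is a hypothetical function name, the helpers are the ones used in the example:

```python
import lunary
from anthropic import Anthropic
from lunary.anthrophic import monitor, agent

@agent("DemoAgent")
def answer(question: str):
    thread = lunary.open_thread()
    client = monitor(Anthropic())

    # Record the user turn, then parent the LLM call to it so the
    # call is reconciled under the right thread message.
    message_id = thread.track_message({"role": "user", "content": question})
    with lunary.parent(message_id):
        return client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=1024,
            messages=[{"role": "user", "content": question}],
        )
```
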
diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py
index b8a4603..b4c9244 100644
--- a/lunary/anthrophic.py
+++ b/lunary/anthrophic.py
@@ -52,20 +52,36 @@
 ]
 
 
+def __prop(
+    target: t.Union[t.Dict, t.Any],
+    property_or_keys: t.Union[t.List, str],
+    default_value: t.Any = None
+):
+    if isinstance(property_or_keys, list):
+        value = target
+        for key in property_or_keys:
+            value = __prop(value, key)
+            if not value: return default_value
+        return value
+
+    if isinstance(target, dict):
+        return target.get(property_or_keys, default_value)
+    return getattr(target, property_or_keys, default_value)
+
+
 def __parse_tools(tools: list[ToolParam]):
     return [
         {
             "type": "function",
             "function": {
-                "name": tool.get("name"),
-                "description": tool.get("description"),
-                "inputSchema": tool.get("input_schema"),
+                "name": __prop(tool, "name"),
+                "description": __prop(tool, "description"),
+                "inputSchema": __prop(tool, "input_schema"),
             },
         } for tool in tools
     ]
 
-
 def __params_parser(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
     return {
         key: __parse_tools(value) if key == "tools" else value
@@ -75,32 +91,31 @@ def __params_parser(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
 
 
 def parse_message(message: MessageParam):
-    role = message.get("role")
-    content = message.get("content")
+    role = __prop(message, "role", "system")
+    content = __prop(message, "content", message)
 
-    # print({"role": role, "content": content})
+    print(role, content)
 
     if isinstance(content, str):
         yield {"role": role, "content": content}
 
     elif isinstance(content, list):
         for item in content:
-            if isinstance(item, dict):
-                if item.get("type") == "text":
-                    yield {"content": item.get("text"), "role": role}
-                elif item.get("type") == "tool_use":
-                    yield {
-                        "functionCall": {
-                            "name": item.get("name"),
-                            "arguments": item.get("input"),
-                        },
-                        "toolCallId": item.get("id"),
-                    }
-                elif item.get("type") == "tool_result":
-                    yield {
-                        "role": "tool",
-                        "tool_call_id": item.get("tool_use_id"),
-                        "content": item.get("content"),
-                    }
+            if __prop(item, "type") == "text":
+                yield {"content": __prop(item, "text"), "role": role}
+            elif __prop(item, "type") == "tool_use":
+                yield {
+                    "functionCall": {
+                        "name": __prop(item, "name"),
+                        "arguments": __prop(item, "input"),
+                    },
+                    "toolCallId": __prop(item, "id"),
+                }
+            elif __prop(item, "type") == "tool_result":
+                yield {
+                    "role": "tool",
+                    "tool_call_id": __prop(item, "tool_use_id"),
+                    "content": __prop(item, "content"),
+                }
 
     else:
         error = f"Invalid 'content' type for message: {message}"
@@ -113,47 +128,39 @@ def __input_parser(kwargs: t.Dict):
     if kwargs.get("system"):
         system = kwargs.get("system")
         if isinstance(system, str):
-            inputs.append({ "role": "system", "content": kwargs["system"] })
+            inputs.append({"role": "system", "content": kwargs["system"]})
         elif isinstance(system, list):
             for item in kwargs["system"]:
-                if item.get("type") == "text":
-                    inputs.append({ "role": "system", "content": item.get("text") })
+                if __prop(item, "type") == "text":
+                    inputs.append({"role": "system", "content": __prop(item, "text")})
 
     for message in kwargs.get("messages", []):
         inputs.extend(parse_message(message))
 
     return {"input": inputs, "name": kwargs.get("model")}
 
-
-def __output_parser(output: t.Union[Message], stream: bool = False):
-    if isinstance(output, Message):
-        return {
-            "name": output.model,
-            "tokensUsage": {
-                "completion": output.usage.output_tokens,
-                "prompt": output.usage.input_tokens,
-            },
-            "output": [
-                (
-                    {"content": content.text, "role": output.role}
-                    if content.type == "text"
-                    else {
-                        "functionCall": {
-                            "name": content.name,
-                            "arguments": content.input,
-                        },
-                        "toolCallId": content.id,
-                    }
-                )
-                for content in output.content
-            ],
-        }
-    else:
-        return {
-            "name": None,
-            "tokenUsage": None,
-            "output": getattr(output, "content", output),
-        }
+def __output_parser(output: t.Any, stream: bool = False):
+    return {
+        "name": __prop(output, "model"),
+        "tokensUsage": {
+            "completion": __prop(output, ["usage", "output_tokens"]),
+            "prompt": __prop(output, ["usage", "input_tokens"]),
+        },
+        "output": [
+            (
+                {"content": __prop(content, "text"), "role": __prop(output, "role")}
+                if __prop(content, "type") == "text"
+                else {
+                    "functionCall": {
+                        "name": __prop(content, "name"),
+                        "arguments": __prop(content, "input"),
+                    },
+                    "toolCallId": __prop(content, "id"),
+                }
+            )
+            for content in __prop(output, "content", output)
+        ],
+    }
 
 
 class Stream:
@@ -597,6 +604,8 @@ def __wrap_sync(
     try:
         parsed_output = output_parser(output, kwargs.get("stream", False))
 
+        print(parsed_input, parsed_output, output)
+
         track_event(
             type,
             "end",
@@ -738,3 +747,46 @@ def monitor(client: "ClientType") -> "ClientType":
     else:
         raise Exception("Invalid argument. Expected instance of Anthropic Client")
 
     return client
+
+
+def agent(name=None, user_id=None, user_props=None, tags=None):
+    def decorator(fn):
+        return partial(
+            __wrap_sync,
+            fn, "agent",
+            name=name or fn.__name__,
+            user_id=user_id,
+            user_props=user_props,
+            tags=tags
+        )
+
+    return decorator
+
+
+def chain(name=None, user_id=None, user_props=None, tags=None):
+    def decorator(fn):
+        return partial(
+            __wrap_sync,
+            fn, "chain",
+            name=name or fn.__name__,
+            user_id=user_id,
+            user_props=user_props,
+            tags=tags
+        )
+
+    return decorator
+
+
+def tool(name=None, user_id=None, user_props=None, tags=None):
+    def decorator(fn):
+        return partial(
+            __wrap_sync,
+            fn, "tool",
+            name=name or fn.__name__,
+            user_id=user_id,
+            user_props=user_props,
+            tags=tags
+        )
+
+    return decorator
+
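
`agent`, `chain`, and `tool` all reduce to `partial(__wrap_sync, fn, <type>, ...)`, so a decorated function is tracked the same way a wrapped client method is. A hedged usage sketch — the weather stub and function names are illustrative, not taken from the examples:

```python
from lunary.anthrophic import chain, tool

@tool(name="get_weather", tags=["weather"])
def get_weather(location: str):
    return "The weather is 73f"  # stub result, as in the example data

@chain()  # the factories take keyword options, so call them even when empty
def forecast(location: str):
    # get_weather is tracked as its own "tool" run inside this "chain" run
    return get_weather(location)

print(forecast("San Francisco, CA"))
```
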
From 2923100c52031c608b78ce59003bfc690e5bab7f Mon Sep 17 00:00:00 2001
From: Thraize
Date: Fri, 2 Aug 2024 23:11:22 -0700
Subject: [PATCH 7/8] refactor: cleaned up anthrophic reconciliation example

---
 examples/anthrophic.py | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

diff --git a/examples/anthrophic.py b/examples/anthrophic.py
index 0c655de..bf2d794 100644
--- a/examples/anthrophic.py
+++ b/examples/anthrophic.py
@@ -286,23 +286,7 @@ def reconcilliation_tool_calls():
 
         tool = next(c for c in message.content if c.type == "tool_use")
 
-        for item in (
-            [
-                user_message,
-                {"role": message.role, "content": message.content},
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "tool_result",
-                            "tool_use_id": tool.id,
-                            "content": [{"type": "text", "text": "The weather is 73f"}],
-                        }
-                    ],
-                },
-            ]
-        ):
-            thread.track_message(item)
+        for item in parse_message(message): thread.track_message(item)
 
         response = client.messages.create(
             model="claude-3-opus-20240229",

From 29e1cfaf79f15f38b2ed5f4a5757e675d649621e Mon Sep 17 00:00:00 2001
From: 7HR4IZ3
Date: Sun, 1 Sep 2024 14:54:00 -0700
Subject: [PATCH 8/8] fix: minor bug fixes

---
 examples/openai/basic.py | 2 +-
 examples/threads.py      | 4 ++--
 lunary/anthrophic.py     | 8 ++++----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/openai/basic.py b/examples/openai/basic.py
index bc43142..512f36a 100644
--- a/examples/openai/basic.py
+++ b/examples/openai/basic.py
@@ -1,5 +1,5 @@
 import lunary
-from openai import OpenAI
+from openai import OpenAI, Client
 import os
 
 client = OpenAI(
diff --git a/examples/threads.py b/examples/threads.py
index f387aba..ec3123d 100644
--- a/examples/threads.py
+++ b/examples/threads.py
@@ -9,7 +9,7 @@
 thread.track_message({
     "role": "user",
     "content": "Hello!"
-})
+}, "user123")
 
 time.sleep(0.5)
 
@@ -23,7 +23,7 @@
 thread.track_message({
     "role": "user",
     "content": "I have a question about your product."
-})
+}, "user123")
 
 time.sleep(0.5)
 
 msg_id = thread.track_message({
diff --git a/lunary/anthrophic.py b/lunary/anthrophic.py
index b4c9244..d1f02e2 100644
--- a/lunary/anthrophic.py
+++ b/lunary/anthrophic.py
@@ -198,7 +198,7 @@ def __iterator__(self):
         if event.type == "message_delta":
             if len(self.__messages) >= 1:
                 message = self.__messages[-1]
-                message["usage"]["tokens"] = event.usage.output_tokens
+                message["usage"]["output"] = event.usage.output_tokens
 
         if event.type == "message_stop":
             # print("\n\n ** ", list(__parse_message_content(event.message)))
@@ -292,7 +292,7 @@ def __iterator__(self):
             name=self.__handler.__name__,
             token_usage={
                 "completion": sum(
-                    [message["usage"]["tokens"] for message in self.__messages]
+                    [message["usage"]["output"] for message in self.__messages]
                 ),
                 "prompt": sum(
                     [message["usage"]["input"] for message in self.__messages]
@@ -341,7 +341,7 @@ async def __iterator__(self):
         if event.type == "message_delta":
             if len(self.__messages) >= 1:
                 message = self.__messages[-1]
-                message["usage"]["tokens"] = event.usage.output_tokens
+                message["usage"]["output"] = event.usage.output_tokens
 
         if event.type == "message_stop":
             # print("\n\n ** ", list(__parse_message_content(event.message)))
@@ -435,7 +435,7 @@ async def __iterator__(self):
             name=self.__handler.__name__,
             token_usage={
                 "completion": sum(
-                    [message["usage"]["output"] for message in self.__messages]
                 ),
                 "prompt": sum(
                     [message["usage"]["input"] for message in self.__messages]
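
With the `tokens` key renamed to `output`, the stream handlers report the same completion/prompt split as the non-streaming `__output_parser`. The arithmetic, restated outside the diff with illustrative numbers:

```python
# Per-message usage dicts in the shape the stream handlers accumulate
# (values are illustrative, not taken from a real run).
messages = [
    {"usage": {"input": 12, "output": 42}},
    {"usage": {"input": 3, "output": 7}},
]

token_usage = {
    "completion": sum(m["usage"]["output"] for m in messages),
    "prompt": sum(m["usage"]["input"] for m in messages),
}
assert token_usage == {"completion": 49, "prompt": 15}
```
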