Skip to content

Commit

Permalink
Add first token time
Browse files (browse the repository at this point in the history)
  • Loading branch information
hinthornw committed Feb 13, 2024
1 parent edf779d commit 7c2e4a5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
43 changes: 32 additions & 11 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import contextvars
import datetime
import functools
import inspect
import logging
Expand Down Expand Up @@ -117,14 +118,15 @@ def _container_end(
container: _TraceableContainer,
outputs: Optional[Any] = None,
error: Optional[str] = None,
events: Optional[List[dict]] = None,
):
"""End the run."""
run_tree = container.get("new_run")
if run_tree is None:
# Tracing disabled
return
outputs_ = outputs if isinstance(outputs, dict) else {"output": outputs}
run_tree.end(outputs=outputs_, error=error)
run_tree.end(outputs=outputs_, error=error, events=events)
run_tree.patch()


Expand Down Expand Up @@ -245,15 +247,13 @@ def __call__(
*args: Any,
langsmith_extra: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> R:
...
) -> R: ...


@overload
def traceable(
func: Callable[..., R],
) -> Callable[..., R]:
...
) -> Callable[..., R]: ...


@overload
Expand All @@ -266,8 +266,7 @@ def traceable(
client: Optional[client.Client] = None,
extra: Optional[Dict] = None,
reduce_fn: Optional[Callable] = None,
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]:
...
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]: ...


def traceable(
Expand Down Expand Up @@ -350,6 +349,7 @@ async def async_wrapper(
async def async_generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> AsyncGenerator:
events: List[dict] = []
context_run = _PARENT_RUN_TREE.get()
run_container = _setup_run(
func,
Expand Down Expand Up @@ -385,11 +385,21 @@ async def async_generator_wrapper(
if inspect.iscoroutine(async_gen_result):
async_gen_result = await async_gen_result
async for item in async_gen_result:
if run_type == "llm":
events.append(
{
"name": "new_token",
"time": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
"kwargs": {"token": item},
},
)
results.append(item)
yield item
except BaseException as e:
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
_container_end(run_container, error=stacktrace, events=events)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
Expand All @@ -407,7 +417,7 @@ async def async_generator_wrapper(
function_result = results
else:
function_result = None
_container_end(run_container, outputs=function_result)
_container_end(run_container, outputs=function_result, events=events)

@functools.wraps(func)
def wrapper(
Expand Down Expand Up @@ -456,6 +466,7 @@ def generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> Any:
context_run = _PARENT_RUN_TREE.get()
events: List[dict] = []
run_container = _setup_run(
func,
run_type=run_type,
Expand Down Expand Up @@ -483,14 +494,24 @@ def generator_wrapper(
# around this.
generator_result = func(*args, **kwargs)
for item in generator_result:
if run_type == "llm":
events.append(
{
"name": "new_token",
"time": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
"kwargs": {"token": item},
},
)
results.append(item)
try:
yield item
except GeneratorExit:
break
except BaseException as e:
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
_container_end(run_container, error=stacktrace, events=events)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
Expand All @@ -508,7 +529,7 @@ def generator_wrapper(
function_result = results
else:
function_result = None
_container_end(run_container, outputs=function_result)
_container_end(run_container, outputs=function_result, events=events)

if inspect.isasyncgenfunction(func):
selected_wrapper: Callable = async_generator_wrapper
Expand Down
8 changes: 6 additions & 2 deletions python/langsmith/run_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def infer_defaults(cls, values: dict) -> dict:
values["trace_id"] = values["parent_run"].trace_id
else:
values["trace_id"] = values["id"]
else:
print(values["trace_id"])
cast(dict, values.setdefault("extra", {}))
return values

Expand All @@ -100,13 +98,16 @@ def end(
outputs: Optional[Dict] = None,
error: Optional[str] = None,
end_time: Optional[datetime] = None,
events: Optional[List[Dict]] = None,
) -> None:
"""Set the end time of the run and all child runs."""
self.end_time = end_time or datetime.now(timezone.utc)
if outputs is not None:
self.outputs = outputs
if error is not None:
self.error = error
if events is not None:
self.events = events

def create_child(
self,
Expand Down Expand Up @@ -181,6 +182,9 @@ def patch(self) -> None:
end_time=self.end_time,
dotted_order=self.dotted_order,
trace_id=self.trace_id,
events=self.events,
tags=self.tags,
extra=self.extra,
)

def wait(self) -> None:
Expand Down

0 comments on commit 7c2e4a5

Please sign in to comment.