Skip to content

Commit

Permalink
Trace if explicitly set (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 27, 2023
1 parent 07b4ceb commit 7bd7a85
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 91 deletions.
142 changes: 57 additions & 85 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,28 @@ class LangSmithExtra(TypedDict, total=False):
class _TraceableContainer(TypedDict, total=False):
"""Typed response when initializing a run a traceable."""

new_run: run_trees.RunTree
new_run: Optional[run_trees.RunTree]
project_name: Optional[str]
outer_project: Optional[str]
outer_metadata: Optional[Dict[str, Any]]
outer_tags: Optional[List[str]]


def _container_end(
container: _TraceableContainer,
outputs: Optional[Any] = None,
error: Optional[str] = None,
):
"""End the run."""
run_tree = container.get("new_run")
if run_tree is None:
# Tracing disabled
return
outputs_ = outputs if isinstance(outputs, dict) else {"output": outputs}
run_tree.end(outputs=outputs_, error=error)
run_tree.patch()


def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict:
run_extra = langsmith_extra.get("run_extra", None)
if run_extra:
Expand Down Expand Up @@ -119,6 +134,19 @@ def _setup_run(
)
langsmith_extra = langsmith_extra or LangSmithExtra()
parent_run_ = langsmith_extra.get("run_tree") or _PARENT_RUN_TREE.get()
if not parent_run_ and not utils.tracing_is_enabled():
utils.log_once(
logging.DEBUG, "LangSmith tracing is disabled, returning original function."
)
return _TraceableContainer(
new_run=None,
project_name=outer_project,
outer_project=outer_project,
outer_metadata=None,
outer_tags=None,
)
# Else either the env var is set OR a parent run was explicitly set,
# which occurs in the `as_runnable()` flow
project_name_ = langsmith_extra.get("project_name", outer_project)
signature = inspect.signature(func)
name_ = name or func.__name__
Expand Down Expand Up @@ -172,14 +200,18 @@ def _setup_run(
executor=executor,
client=client_,
)

new_run.post()
return _TraceableContainer(
response_container = _TraceableContainer(
new_run=new_run,
project_name=project_name_,
outer_project=outer_project,
outer_metadata=outer_metadata,
outer_tags=outer_tags,
)
_PROJECT_NAME.set(response_container["project_name"])
_PARENT_RUN_TREE.set(response_container["new_run"])
return response_container


def traceable(
Expand Down Expand Up @@ -217,36 +249,6 @@ def traceable(
extra_outer = extra or {}

def decorator(func: Callable):
if not utils.tracing_is_enabled():
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def anoop_wrapper(
*args: Any,
langsmith_extra: Optional[LangSmithExtra] = None,
**kwargs: Any,
) -> Any:
return await func(*args, **kwargs)

fn = anoop_wrapper
else:

@functools.wraps(func)
def noop_wrapper(
*args: Any,
langsmith_extra: Optional[LangSmithExtra] = None,
**kwargs: Any,
) -> Any:
return func(*args, **kwargs)

fn = noop_wrapper

utils.log_once(
logging.DEBUG, "Tracing is disabled, returning original function"
)
setattr(fn, "__langsmith_traceable__", True)
return fn

@functools.wraps(func)
async def async_wrapper(
*args: Any,
Expand All @@ -268,8 +270,6 @@ async def async_wrapper(
args=args,
kwargs=kwargs,
)
_PROJECT_NAME.set(run_container["project_name"])
_PARENT_RUN_TREE.set(run_container["new_run"])
func_accepts_parent_run = (
inspect.signature(func).parameters.get("run_tree", None) is not None
)
Expand All @@ -281,18 +281,15 @@ async def async_wrapper(
else:
function_result = await func(*args, **kwargs)
except Exception as e:
run_container["new_run"].end(error=str(e))
run_container["new_run"].patch()
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
raise e
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
if isinstance(function_result, dict):
run_container["new_run"].end(outputs=function_result)
else:
run_container["new_run"].end(outputs={"output": function_result})
run_container["new_run"].patch()
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
_container_end(run_container, outputs=function_result)
return function_result

@functools.wraps(func)
Expand All @@ -313,8 +310,6 @@ async def async_generator_wrapper(
args=args,
kwargs=kwargs,
)
_PROJECT_NAME.set(run_container["project_name"])
_PARENT_RUN_TREE.set(run_container["new_run"])
func_accepts_parent_run = (
inspect.signature(func).parameters.get("run_tree", None) is not None
)
Expand All @@ -336,15 +331,15 @@ async def async_generator_wrapper(
async for item in async_gen_result:
results.append(item)
yield item
except (BaseException, Exception, KeyboardInterrupt) as e:
except BaseException as e:
stacktrace = traceback.format_exc()
run_container["new_run"].end(error=stacktrace)
run_container["new_run"].patch()
_container_end(run_container, error=stacktrace)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
raise e
if results:
if reduce_fn:
try:
Expand All @@ -356,11 +351,7 @@ async def async_generator_wrapper(
function_result = results
else:
function_result = None
if isinstance(function_result, dict):
run_container["new_run"].end(outputs=function_result)
else:
run_container["new_run"].end(outputs={"output": function_result})
run_container["new_run"].patch()
_container_end(run_container, outputs=function_result)

@functools.wraps(func)
def wrapper(
Expand All @@ -383,8 +374,6 @@ def wrapper(
args=args,
kwargs=kwargs,
)
_PROJECT_NAME.set(run_container["project_name"])
_PARENT_RUN_TREE.set(run_container["new_run"])
func_accepts_parent_run = (
inspect.signature(func).parameters.get("run_tree", None) is not None
)
Expand All @@ -395,24 +384,16 @@ def wrapper(
)
else:
function_result = func(*args, **kwargs)
except (BaseException, Exception, KeyboardInterrupt) as e:
except BaseException as e:
stacktrace = traceback.format_exc()
run_container["new_run"].end(error=stacktrace)
run_container["new_run"].patch()
_container_end(run_container, error=stacktrace)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
raise e
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
if isinstance(function_result, dict):
run_container["new_run"].end(outputs=function_result)
else:
run_container["new_run"].end(outputs={"output": function_result})
run_container["new_run"].patch()
_container_end(run_container, outputs=function_result)
return function_result

@functools.wraps(func)
Expand All @@ -433,8 +414,7 @@ def generator_wrapper(
args=args,
kwargs=kwargs,
)
_PROJECT_NAME.set(run_container["project_name"])
_PARENT_RUN_TREE.set(run_container["new_run"])

func_accepts_parent_run = (
inspect.signature(func).parameters.get("run_tree", None) is not None
)
Expand All @@ -449,22 +429,18 @@ def generator_wrapper(
# called mid-generation. Need to explicitly accept run_tree to get
# around this.
generator_result = func(*args, **kwargs)
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
for item in generator_result:
results.append(item)
yield item
except (BaseException, Exception, KeyboardInterrupt) as e:
except BaseException as e:
stacktrace = traceback.format_exc()
run_container["new_run"].end(error=stacktrace)
run_container["new_run"].patch()
_container_end(run_container, error=stacktrace)
raise e
finally:
_PARENT_RUN_TREE.set(context_run)
_PROJECT_NAME.set(run_container["outer_project"])
_TAGS.set(run_container["outer_tags"])
_METADATA.set(run_container["outer_metadata"])
raise e
if results:
if reduce_fn:
try:
Expand All @@ -476,11 +452,7 @@ def generator_wrapper(
function_result = results
else:
function_result = None
if isinstance(function_result, dict) or function_result is None:
run_container["new_run"].end(outputs=function_result)
else:
run_container["new_run"].end(outputs={"output": function_result})
run_container["new_run"].patch()
_container_end(run_container, outputs=function_result)

if inspect.isasyncgenfunction(func):
selected_wrapper: Callable = async_generator_wrapper
Expand Down
12 changes: 6 additions & 6 deletions python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ def test_persist_update_run(
run["outputs"] = {"output": ["Hi"]}
run["extra"]["foo"] = "bar"
langchain_client.update_run(run["id"], **run)
for _ in range(3):
for _ in range(5):
try:
stored_run = langchain_client.read_run(run["id"])
break
except LangSmithError:
time.sleep(1)
time.sleep(2)

assert stored_run.id == run["id"]
assert stored_run.outputs == run["outputs"]
Expand Down Expand Up @@ -217,8 +217,8 @@ def grader(run_input: str, run_output: str, answer: Optional[str]) -> dict:
return dict(score=score, value=value)

evaluator = StringEvaluator(evaluation_name="Jaccard", grading_function=grader)

for _ in range(3):
runs = None
for _ in range(5):
try:
runs = list(
langchain_client.list_runs(
Expand All @@ -229,8 +229,8 @@ def grader(run_input: str, run_output: str, answer: Optional[str]) -> dict:
)
break
except LangSmithError:
time.sleep(1)

time.sleep(2)
assert runs is not None
all_eval_results: List[EvaluationResult] = []
for run in runs:
all_eval_results.append(langchain_client.evaluate_run(run, evaluator))
Expand Down

0 comments on commit 7bd7a85

Please sign in to comment.