diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 71effb6d6..cb19cbbdf 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -85,13 +85,28 @@ class LangSmithExtra(TypedDict, total=False): class _TraceableContainer(TypedDict, total=False): """Typed response when initializing a run a traceable.""" - new_run: run_trees.RunTree + new_run: Optional[run_trees.RunTree] project_name: Optional[str] outer_project: Optional[str] outer_metadata: Optional[Dict[str, Any]] outer_tags: Optional[List[str]] +def _container_end( + container: _TraceableContainer, + outputs: Optional[Any] = None, + error: Optional[str] = None, +): + """End the run.""" + run_tree = container.get("new_run") + if run_tree is None: + # Tracing disabled + return + outputs_ = outputs if isinstance(outputs, dict) else {"output": outputs} + run_tree.end(outputs=outputs_, error=error) + run_tree.patch() + + def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict: run_extra = langsmith_extra.get("run_extra", None) if run_extra: @@ -119,6 +134,19 @@ def _setup_run( ) langsmith_extra = langsmith_extra or LangSmithExtra() parent_run_ = langsmith_extra.get("run_tree") or _PARENT_RUN_TREE.get() + if not parent_run_ and not utils.tracing_is_enabled(): + utils.log_once( + logging.DEBUG, "LangSmith tracing is disabled, returning original function." + ) + return _TraceableContainer( + new_run=None, + project_name=outer_project, + outer_project=outer_project, + outer_metadata=None, + outer_tags=None, + ) + # Else either the env var is set OR a parent run was explicitly set, + # which occurs in the `as_runnable()` flow project_name_ = langsmith_extra.get("project_name", outer_project) signature = inspect.signature(func) name_ = name or func.__name__ @@ -172,14 +200,18 @@ def _setup_run( executor=executor, client=client_, ) + new_run.post() - return _TraceableContainer( + response_container = _TraceableContainer( new_run=new_run, project_name=project_name_, outer_project=outer_project, outer_metadata=outer_metadata, outer_tags=outer_tags, ) + _PROJECT_NAME.set(response_container["project_name"]) + _PARENT_RUN_TREE.set(response_container["new_run"]) + return response_container def traceable( @@ -217,36 +249,6 @@ def traceable( extra_outer = extra or {} def decorator(func: Callable): - if not utils.tracing_is_enabled(): - if inspect.iscoroutinefunction(func): - - @functools.wraps(func) - async def anoop_wrapper( - *args: Any, - langsmith_extra: Optional[LangSmithExtra] = None, - **kwargs: Any, - ) -> Any: - return await func(*args, **kwargs) - - fn = anoop_wrapper - else: - - @functools.wraps(func) - def noop_wrapper( - *args: Any, - langsmith_extra: Optional[LangSmithExtra] = None, - **kwargs: Any, - ) -> Any: - return func(*args, **kwargs) - - fn = noop_wrapper - - utils.log_once( - logging.DEBUG, "Tracing is disabled, returning original function" - ) - setattr(fn, "__langsmith_traceable__", True) - return fn - @functools.wraps(func) async def async_wrapper( *args: Any, @@ -268,8 +270,6 @@ async def async_wrapper( args=args, kwargs=kwargs, ) - _PROJECT_NAME.set(run_container["project_name"]) - _PARENT_RUN_TREE.set(run_container["new_run"]) func_accepts_parent_run = ( inspect.signature(func).parameters.get("run_tree", None) is not None ) @@ -281,18 +281,15 @@ async def async_wrapper( else: function_result = await func(*args, **kwargs) except Exception as e: - run_container["new_run"].end(error=str(e)) - run_container["new_run"].patch() + stacktrace = traceback.format_exc() + _container_end(run_container, error=stacktrace) + raise e + finally: _PARENT_RUN_TREE.set(context_run) _PROJECT_NAME.set(run_container["outer_project"]) - raise e - _PARENT_RUN_TREE.set(context_run) - _PROJECT_NAME.set(run_container["outer_project"]) - if isinstance(function_result, dict): - run_container["new_run"].end(outputs=function_result) - else: - run_container["new_run"].end(outputs={"output": function_result}) - run_container["new_run"].patch() + _TAGS.set(run_container["outer_tags"]) + _METADATA.set(run_container["outer_metadata"]) + _container_end(run_container, outputs=function_result) return function_result @functools.wraps(func) @@ -313,8 +310,6 @@ async def async_generator_wrapper( args=args, kwargs=kwargs, ) - _PROJECT_NAME.set(run_container["project_name"]) - _PARENT_RUN_TREE.set(run_container["new_run"]) func_accepts_parent_run = ( inspect.signature(func).parameters.get("run_tree", None) is not None ) @@ -336,15 +331,15 @@ async def async_generator_wrapper( async for item in async_gen_result: results.append(item) yield item - except (BaseException, Exception, KeyboardInterrupt) as e: + except BaseException as e: stacktrace = traceback.format_exc() - run_container["new_run"].end(error=stacktrace) - run_container["new_run"].patch() + _container_end(run_container, error=stacktrace) + raise e + finally: _PARENT_RUN_TREE.set(context_run) _PROJECT_NAME.set(run_container["outer_project"]) _TAGS.set(run_container["outer_tags"]) _METADATA.set(run_container["outer_metadata"]) - raise e if results: if reduce_fn: try: @@ -356,11 +351,7 @@ async def async_generator_wrapper( function_result = results else: function_result = None - if isinstance(function_result, dict): - run_container["new_run"].end(outputs=function_result) - else: - run_container["new_run"].end(outputs={"output": function_result}) - run_container["new_run"].patch() + _container_end(run_container, outputs=function_result) @functools.wraps(func) def wrapper( @@ -383,8 +374,6 @@ def wrapper( args=args, kwargs=kwargs, ) - _PROJECT_NAME.set(run_container["project_name"]) - _PARENT_RUN_TREE.set(run_container["new_run"]) func_accepts_parent_run = ( inspect.signature(func).parameters.get("run_tree", None) is not None ) @@ -395,24 +384,16 @@ def wrapper( ) else: function_result = func(*args, **kwargs) - except (BaseException, Exception, KeyboardInterrupt) as e: + except BaseException as e: stacktrace = traceback.format_exc() - run_container["new_run"].end(error=stacktrace) - run_container["new_run"].patch() + _container_end(run_container, error=stacktrace) + raise e + finally: _PARENT_RUN_TREE.set(context_run) _PROJECT_NAME.set(run_container["outer_project"]) _TAGS.set(run_container["outer_tags"]) _METADATA.set(run_container["outer_metadata"]) - raise e - _PARENT_RUN_TREE.set(context_run) - _PROJECT_NAME.set(run_container["outer_project"]) - _TAGS.set(run_container["outer_tags"]) - _METADATA.set(run_container["outer_metadata"]) - if isinstance(function_result, dict): - run_container["new_run"].end(outputs=function_result) - else: - run_container["new_run"].end(outputs={"output": function_result}) - run_container["new_run"].patch() + _container_end(run_container, outputs=function_result) return function_result @functools.wraps(func) @@ -433,8 +414,7 @@ def generator_wrapper( args=args, kwargs=kwargs, ) - _PROJECT_NAME.set(run_container["project_name"]) - _PARENT_RUN_TREE.set(run_container["new_run"]) + func_accepts_parent_run = ( inspect.signature(func).parameters.get("run_tree", None) is not None ) @@ -449,22 +429,18 @@ def generator_wrapper( # called mid-generation. Need to explicitly accept run_tree to get # around this. generator_result = func(*args, **kwargs) - _PARENT_RUN_TREE.set(context_run) - _PROJECT_NAME.set(run_container["outer_project"]) - _TAGS.set(run_container["outer_tags"]) - _METADATA.set(run_container["outer_metadata"]) for item in generator_result: results.append(item) yield item - except (BaseException, Exception, KeyboardInterrupt) as e: + except BaseException as e: stacktrace = traceback.format_exc() - run_container["new_run"].end(error=stacktrace) - run_container["new_run"].patch() + _container_end(run_container, error=stacktrace) + raise e + finally: _PARENT_RUN_TREE.set(context_run) _PROJECT_NAME.set(run_container["outer_project"]) _TAGS.set(run_container["outer_tags"]) _METADATA.set(run_container["outer_metadata"]) - raise e if results: if reduce_fn: try: @@ -476,11 +452,7 @@ def generator_wrapper( function_result = results else: function_result = None - if isinstance(function_result, dict) or function_result is None: - run_container["new_run"].end(outputs=function_result) - else: - run_container["new_run"].end(outputs={"output": function_result}) - run_container["new_run"].patch() + _container_end(run_container, outputs=function_result) if inspect.isasyncgenfunction(func): selected_wrapper: Callable = async_generator_wrapper diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index d0b92ca60..638aaf7ec 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -150,12 +150,12 @@ def test_persist_update_run( run["outputs"] = {"output": ["Hi"]} run["extra"]["foo"] = "bar" langchain_client.update_run(run["id"], **run) - for _ in range(3): + for _ in range(5): try: stored_run = langchain_client.read_run(run["id"]) break except LangSmithError: - time.sleep(1) + time.sleep(2) assert stored_run.id == run["id"] assert stored_run.outputs == run["outputs"] @@ -217,8 +217,8 @@ def grader(run_input: str, run_output: str, answer: Optional[str]) -> dict: return dict(score=score, value=value) evaluator = StringEvaluator(evaluation_name="Jaccard", grading_function=grader) - - for _ in range(3): + runs = None + for _ in range(5): try: runs = list( langchain_client.list_runs( @@ -229,8 +229,8 @@ def grader(run_input: str, run_output: str, answer: Optional[str]) -> dict: ) break except LangSmithError: - time.sleep(1) - + time.sleep(2) + assert runs is not None all_eval_results: List[EvaluationResult] = [] for run in runs: all_eval_results.append(langchain_client.evaluate_run(run, evaluator))