Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wfh/strip not given #472

Merged
merged 10 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.PHONY: tests lint format

tests:
poetry run pytest tests/unit_tests
poetry run pytest -n auto --durations=10 tests/unit_tests

tests_watch:
poetry run ptw --now . -- -vv -x tests/unit_tests
Expand Down
10 changes: 10 additions & 0 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class _ContainerInput(TypedDict, total=False):
reduce_fn: Optional[Callable]
project_name: Optional[str]
run_type: ls_client.RUN_TYPE_T
process_inputs: Optional[Callable[[dict], dict]]


def _container_end(
Expand Down Expand Up @@ -207,6 +208,12 @@ def _setup_run(
except TypeError as e:
logger.debug(f"Failed to infer inputs for {name_}: {e}")
inputs = {"args": args, "kwargs": kwargs}
process_inputs = container_input.get("process_inputs")
if process_inputs:
try:
inputs = process_inputs(inputs)
except Exception as e:
logger.error(f"Failed to filter inputs for {name_}: {e}")
outer_tags = _TAGS.get()
tags_ = (langsmith_extra.get("tags") or []) + (outer_tags or [])
_TAGS.set(tags_)
Expand Down Expand Up @@ -325,6 +332,7 @@ def traceable(
client: Optional[ls_client.Client] = None,
reduce_fn: Optional[Callable] = None,
project_name: Optional[str] = None,
process_inputs: Optional[Callable[[dict], dict]] = None,
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]:
...

Expand All @@ -350,6 +358,7 @@ def traceable(
called, and the run itself will be stuck in a pending state.
project_name: The name of the project to log the run to. Defaults to None,
which will use the default project.
process_inputs: A function to filter the inputs to the run. Defaults to None.


Returns:
Expand Down Expand Up @@ -492,6 +501,7 @@ def manual_extra_function(x):
client=kwargs.pop("client", None),
project_name=kwargs.pop("project_name", None),
run_type=run_type,
process_inputs=kwargs.pop("process_inputs", None),
)
if kwargs:
warnings.warn(
Expand Down
47 changes: 44 additions & 3 deletions python/langsmith/wrappers/_openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from __future__ import annotations

import functools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)

from langsmith import run_helpers

Expand All @@ -16,6 +28,28 @@
from openai.types.completion import Completion

C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI"])
logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_not_given() -> Optional[Type]:
try:
from openai._types import NotGiven

return NotGiven
except ImportError:
return None


def _strip_not_given(d: dict) -> dict:
try:
not_given = _get_not_given()
if not_given is None:
return d
return {k: v for k, v in d.items() if not isinstance(v, not_given)}
except Exception as e:
logger.error(f"Error stripping NotGiven: {e}")
return d


def _reduce_choices(choices: List[Choice]) -> dict:
Expand Down Expand Up @@ -110,15 +144,22 @@ def _get_wrapper(original_create: Callable, name: str, reduce_fn: Callable) -> C
@functools.wraps(original_create)
def create(*args, stream: bool = False, **kwargs):
decorator = run_helpers.traceable(
name=name, run_type="llm", reduce_fn=reduce_fn if stream else None
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
)

return decorator(original_create)(*args, stream=stream, **kwargs)

@functools.wraps(original_create)
async def acreate(*args, stream: bool = False, **kwargs):
kwargs = _strip_not_given(kwargs)
decorator = run_helpers.traceable(
name=name, run_type="llm", reduce_fn=reduce_fn if stream else None
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
)
if stream:
# TODO: This slightly alters the output to be a generator instead of the
Expand Down
7 changes: 6 additions & 1 deletion python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ def my_iterator_fn(a, b, d):
async def test_traceable_async_iterator(use_next: bool, mock_client: Client) -> None:
with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}):

@traceable(client=mock_client)
def filter_inputs(kwargs: dict):
return {"a": "FOOOOOO", "b": kwargs["b"], "d": kwargs["d"]}

@traceable(client=mock_client, process_inputs=filter_inputs)
async def my_iterator_fn(a, b, d):
for i in range(a + b + d):
yield i
Expand All @@ -234,6 +237,8 @@ async def my_iterator_fn(a, b, d):
body = json.loads(call.kwargs["data"])
assert body["post"]
assert body["post"][0]["outputs"]["output"] == expected
# Assert the inputs are filtered as expected
assert body["post"][0]["inputs"] == {"a": "FOOOOOO", "b": 2, "d": 3}


@patch("langsmith.run_trees.Client", autospec=True)
Expand Down
Loading