Skip to content

Commit

Permalink
Fix clone dataset from other server (#1175)
Browse files Browse the repository at this point in the history
Right now if you make a client pointing to another server but try to
clone from the public saas server it fails.

---------

Co-authored-by: Brendan Maginnis <[email protected]>
  • Loading branch information
hinthornw and brendanator authored Nov 5, 2024
1 parent 21e4486 commit 8781a12
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 26 deletions.
33 changes: 14 additions & 19 deletions python/langsmith/_internal/_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import pathlib
import re
import uuid
from typing import (
Any,
)
from typing import Any

import orjson

Expand All @@ -33,14 +31,8 @@ def _simple_default(obj):
# https://github.com/ijl/orjson#serialize
if isinstance(obj, datetime.datetime):
return obj.isoformat()
if isinstance(obj, uuid.UUID):
elif isinstance(obj, uuid.UUID):
return str(obj)
if hasattr(obj, "model_dump") and callable(obj.model_dump):
return obj.model_dump()
elif hasattr(obj, "dict") and callable(obj.dict):
return obj.dict()
elif hasattr(obj, "_asdict") and callable(obj._asdict):
return obj._asdict()
elif isinstance(obj, BaseException):
return {"error": type(obj).__name__, "message": str(obj)}
elif isinstance(obj, (set, frozenset, collections.deque)):
Expand Down Expand Up @@ -77,6 +69,16 @@ def _simple_default(obj):
return str(obj)


_serialization_methods = [
(
"model_dump",
{"exclude_none": True, "mode": "json"},
), # Pydantic V2 with non-serializable fields
("dict", {}), # Pydantic V1 with non-serializable field
("to_dict", {}), # dataclasses-json
]


def _serialize_json(obj: Any) -> Any:
try:
if isinstance(obj, (set, tuple)):
Expand All @@ -85,22 +87,15 @@ def _serialize_json(obj: Any) -> Any:
return obj._asdict()
return list(obj)

serialization_methods = [
("model_dump", True), # Pydantic V2 with non-serializable fields
("dict", False), # Pydantic V1 with non-serializable field
("to_dict", False), # dataclasses-json
]
for attr, exclude_none in serialization_methods:
for attr, kwargs in _serialization_methods:
if (
hasattr(obj, attr)
and callable(getattr(obj, attr))
and not isinstance(obj, type)
):
try:
method = getattr(obj, attr)
response = (
method(exclude_none=exclude_none) if exclude_none else method()
)
response = method(**kwargs)
if not isinstance(response, dict):
return str(response)
return response
Expand Down
5 changes: 5 additions & 0 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,13 @@ def _parse_token_or_url(
path_parts = parsed_url.path.split("/")
if len(path_parts) >= num_parts:
token_uuid = path_parts[-num_parts]
_as_uuid(token_uuid, var="token parts")
else:
raise ls_utils.LangSmithUserError(f"Invalid public {kind} URL: {url_or_token}")
if parsed_url.netloc == "smith.langchain.com":
api_url = "https://api.smith.langchain.com"
elif parsed_url.netloc == "beta.smith.langchain.com":
api_url = "https://beta.api.smith.langchain.com"
return api_url, token_uuid


Expand Down
11 changes: 7 additions & 4 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,13 @@ def manual_extra_function(x):
manual_extra_function(5, langsmith_extra={"metadata": {"version": "1.0"}})
"""
run_type: ls_client.RUN_TYPE_T = (
args[0]
if args and isinstance(args[0], str)
else (kwargs.pop("run_type", None) or "chain")
run_type = cast(
ls_client.RUN_TYPE_T,
(
args[0]
if args and isinstance(args[0], str)
else (kwargs.pop("run_type", None) or "chain")
),
)
if run_type not in _VALID_RUN_TYPES:
warnings.warn(
Expand Down
73 changes: 70 additions & 3 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import math
import pathlib
import sys
import time
import uuid
Expand Down Expand Up @@ -37,7 +38,9 @@
Client,
_dumps_json,
_is_langchain_hosted,
_parse_token_or_url,
)
from langsmith.utils import LangSmithUserError

_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)

Expand Down Expand Up @@ -719,16 +722,24 @@ def test_pydantic_serialize() -> None:

class ChildPydantic(BaseModel):
uid: uuid.UUID
child_path_keys: Dict[pathlib.Path, pathlib.Path]

class MyPydantic(BaseModel):
foo: str
uid: uuid.UUID
tim: datetime
ex: Optional[str] = None
child: Optional[ChildPydantic] = None
path_keys: Dict[pathlib.Path, pathlib.Path]

obj = MyPydantic(
foo="bar", uid=test_uuid, tim=test_time, child=ChildPydantic(uid=test_uuid)
foo="bar",
uid=test_uuid,
tim=test_time,
child=ChildPydantic(
uid=test_uuid, child_path_keys={pathlib.Path("foo"): pathlib.Path("bar")}
),
path_keys={pathlib.Path("foo"): pathlib.Path("bar")},
)
res = json.loads(json.dumps(obj, default=_serialize_json))
expected = {
Expand All @@ -737,7 +748,9 @@ class MyPydantic(BaseModel):
"tim": test_time.isoformat(),
"child": {
"uid": str(test_uuid),
"child_path_keys": {"foo": "bar"},
},
"path_keys": {"foo": "bar"},
}
assert res == expected

Expand Down Expand Up @@ -777,6 +790,7 @@ def __repr__(self):
class MyPydantic(BaseModel):
foo: str
bar: int
path_keys: Dict[pathlib.Path, "MyPydantic"]

@dataclasses.dataclass
class MyDataclass:
Expand Down Expand Up @@ -816,7 +830,11 @@ class MyNamedTuple(NamedTuple):
"class_with_tee": ClassWithTee(),
"my_dataclass": MyDataclass("foo", 1),
"my_enum": MyEnum.FOO,
"my_pydantic": MyPydantic(foo="foo", bar=1),
"my_pydantic": MyPydantic(
foo="foo",
bar=1,
path_keys={pathlib.Path("foo"): MyPydantic(foo="foo", bar=1, path_keys={})},
),
"my_pydantic_class": MyPydantic,
"person": Person(name="foo_person"),
"a_bool": True,
Expand All @@ -842,7 +860,11 @@ class MyNamedTuple(NamedTuple):
"class_with_tee": "tee_a, tee_b",
"my_dataclass": {"foo": "foo", "bar": 1},
"my_enum": "foo",
"my_pydantic": {"foo": "foo", "bar": 1},
"my_pydantic": {
"foo": "foo",
"bar": 1,
"path_keys": {"foo": {"foo": "foo", "bar": 1, "path_keys": {}}},
},
"my_pydantic_class": lambda x: "MyPydantic" in x,
"person": {"name": "foo_person"},
"a_bool": True,
Expand Down Expand Up @@ -1182,3 +1204,48 @@ def test_validate_api_key_if_hosted(
# Check no warning is raised here.
warnings.simplefilter("error")
client_cls(api_url="http://localhost:1984")


def test_parse_token_or_url():
# Test with URL
url = "https://smith.langchain.com/public/419dcab2-1d66-4b94-8901-0357ead390df/d"
api_url = "https://api.smith.langchain.com"
assert _parse_token_or_url(url, api_url) == (
api_url,
"419dcab2-1d66-4b94-8901-0357ead390df",
)

url = "https://smith.langchain.com/public/419dcab2-1d66-4b94-8901-0357ead390df/d"
beta_api_url = "https://beta.api.smith.langchain.com"
# Should still point to the correct public one
assert _parse_token_or_url(url, beta_api_url) == (
api_url,
"419dcab2-1d66-4b94-8901-0357ead390df",
)

token = "419dcab2-1d66-4b94-8901-0357ead390df"
assert _parse_token_or_url(token, api_url) == (
api_url,
token,
)

# Test with UUID object
token_uuid = uuid.UUID("419dcab2-1d66-4b94-8901-0357ead390df")
assert _parse_token_or_url(token_uuid, api_url) == (
api_url,
str(token_uuid),
)

# Test with custom num_parts
url_custom = (
"https://smith.langchain.com/public/419dcab2-1d66-4b94-8901-0357ead390df/p/q"
)
assert _parse_token_or_url(url_custom, api_url, num_parts=3) == (
api_url,
"419dcab2-1d66-4b94-8901-0357ead390df",
)

# Test with invalid URL
invalid_url = "https://invalid.com/419dcab2-1d66-4b94-8901-0357ead390df"
with pytest.raises(LangSmithUserError):
_parse_token_or_url(invalid_url, api_url)

0 comments on commit 8781a12

Please sign in to comment.