From bef4dfa99329e51e6e95c84cdd71655f05beda81 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 5 Feb 2024 14:02:18 -0800 Subject: [PATCH] Add info endpoint (#403) --- python/langsmith/client.py | 22 +++++++++++ python/langsmith/schemas.py | 20 +++++++++- python/langsmith/utils.py | 38 +++++++++++++++++++ python/tests/integration_tests/test_client.py | 10 +++++ python/tests/unit_tests/test_utils.py | 23 +++++++++++ 5 files changed, 112 insertions(+), 1 deletion(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index ec7ee63d7..c0056d027 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -457,6 +457,28 @@ def _headers(self) -> Dict[str, str]: headers["x-api-key"] = self.api_key return headers + @property + @ls_utils.ttl_cache(maxsize=1) + def info(self) -> Optional[ls_schemas.LangSmithInfo]: + """Get the information about the LangSmith API. + + Returns + ------- + dict + The information about the LangSmith API. + """ + try: + response = self.session.get( + self.api_url + "/info", + headers=self._headers, + timeout=self.timeout_ms / 1000, + ) + ls_utils.raise_for_status_with_text(response) + return ls_schemas.LangSmithInfo(**response.json()) + except ls_utils.LangSmithAPIError as e: + logger.debug("Failed to get info: %s", e) + return None + def request_with_retries( self, request_method: str, diff --git a/python/langsmith/schemas.py b/python/langsmith/schemas.py index 62d193f3a..7ca4c8baf 100644 --- a/python/langsmith/schemas.py +++ b/python/langsmith/schemas.py @@ -9,12 +9,13 @@ List, Optional, Protocol, - TypedDict, Union, runtime_checkable, ) from uuid import UUID +from typing_extensions import TypedDict + try: from pydantic.v1 import ( # type: ignore[import] BaseModel, @@ -505,4 +506,21 @@ class AnnotationQueue(BaseModel): tenant_id: UUID +class BatchIngestConfig(TypedDict, total=False): + scale_up_qsize_trigger: int + scale_up_nthreads_limit: int + scale_down_nempty_trigger: int + size_limit: int + + +class LangSmithInfo(BaseModel): + """Information about the LangSmith server.""" + + version: str = "" + """The version of the LangSmith server.""" + license_expiration_time: Optional[datetime] = None + """The time the license will expire.""" + batch_ingest_config: Optional[BatchIngestConfig] = None + + Example.update_forward_refs() diff --git a/python/langsmith/utils.py b/python/langsmith/utils.py index 382dc6190..3e04063f6 100644 --- a/python/langsmith/utils.py +++ b/python/langsmith/utils.py @@ -1,9 +1,12 @@ """Generic utility functions.""" + import enum import functools import logging import os import subprocess +import threading +import time from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import requests @@ -284,3 +287,38 @@ def filter(self, record) -> bool: return ( "Connection pool is full, discarding connection" not in record.getMessage() ) + + +def ttl_cache( + ttl_seconds: Optional[int] = None, maxsize: Optional[int] = None +) -> Callable: + """LRU cache with an optional TTL.""" + + def decorator(func: Callable) -> Callable: + cache: Dict[Tuple, Tuple] = {} + cache_lock = threading.RLock() + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + key = (args, frozenset(kwargs.items())) + with cache_lock: + if key in cache: + result, timestamp = cache[key] + if ttl_seconds is None or time.time() - timestamp < ttl_seconds: + # Refresh the timestamp + cache[key] = (result, time.time()) + return result + result = func(*args, **kwargs) + with cache_lock: + cache[key] = (result, time.time()) + + if maxsize is not None: + if len(cache) > maxsize: + oldest_key = min(cache, key=lambda k: cache[k][1]) + del cache[oldest_key] + + return result + + return wrapper + + return decorator diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index 04c6708ea..d53d53b7e 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -494,3 +494,13 @@ def test_batch_ingest_runs(langchain_client: Client) -> None: assert run2.outputs == {"output1": 7, "output2": 8} langchain_client.delete_project(project_name=_session) + + +@freeze_time("2023-01-01") +def test_get_info() -> None: + langchain_client = Client(api_key="not-a-real-key") + info = langchain_client.info + assert info + assert info.version is not None + assert info.batch_ingest_config is not None + assert info.batch_ingest_config["size_limit"] > 0 diff --git a/python/tests/unit_tests/test_utils.py b/python/tests/unit_tests/test_utils.py index 4a7029a2a..c7348d6a3 100644 --- a/python/tests/unit_tests/test_utils.py +++ b/python/tests/unit_tests/test_utils.py @@ -1,3 +1,4 @@ +import time import unittest import pytest @@ -71,3 +72,25 @@ def test_correct_get_tracer_project(self): else ls_utils.get_tracer_project(case.return_default_value) ) self.assertEqual(project, case.expected_project_name) + + +def test_ttl_cache(): + test_function_val = 0 + + class MyClass: + @property + @ls_utils.ttl_cache(ttl_seconds=0.1) + def test_function(self): + nonlocal test_function_val + test_function_val += 1 + return test_function_val + + some_class = MyClass() + for _ in range(3): + assert some_class.test_function == 1 + time.sleep(0.1) + for _ in range(3): + assert some_class.test_function == 2 + time.sleep(0.1) + for _ in range(3): + assert some_class.test_function == 3