diff --git a/js/package.json b/js/package.json
index c694a93d6..a9506707d 100644
--- a/js/package.json
+++ b/js/package.json
@@ -121,4 +121,4 @@
     },
     "./package.json": "./package.json"
   }
-}
\ No newline at end of file
+}
diff --git a/python/langsmith/client.py b/python/langsmith/client.py
index c0056d027..144a22545 100644
--- a/python/langsmith/client.py
+++ b/python/langsmith/client.py
@@ -458,25 +458,37 @@ def _headers(self) -> Dict[str, str]:
         return headers
 
     @property
-    @ls_utils.ttl_cache(maxsize=1)
     def info(self) -> Optional[ls_schemas.LangSmithInfo]:
         """Get the information about the LangSmith API.
 
         Returns
        -------
-        dict
-            The information about the LangSmith API.
+        Optional[ls_schemas.LangSmithInfo]
+            The information about the LangSmith API, or None if the API is
+            not available.
         """
+        info = Client._get_info(self.session, self.api_url, self.timeout_ms)
+        if info is None and self.tracing_queue is not None:
+            self.tracing_queue = None
+        return cast(Optional[ls_schemas.LangSmithInfo], info)
+
+    @staticmethod
+    @functools.lru_cache(maxsize=1)
+    def _get_info(
+        session: requests.Session, api_url: str, timeout_ms: int
+    ) -> Optional[ls_schemas.LangSmithInfo]:
         try:
-            response = self.session.get(
-                self.api_url + "/info",
-                headers=self._headers,
-                timeout=self.timeout_ms / 1000,
+            response = session.get(
+                api_url + "/info",
+                headers={"Accept": "application/json"},
+                timeout=timeout_ms / 1000,
             )
             ls_utils.raise_for_status_with_text(response)
             return ls_schemas.LangSmithInfo(**response.json())
-        except ls_utils.LangSmithAPIError as e:
-            logger.debug("Failed to get info: %s", e)
+        except requests.HTTPError:
+            return None
+        except BaseException as e:
+            logger.warning(f"Failed to get info from {api_url}: {repr(e)}")
             return None
 
     def request_with_retries(
@@ -829,6 +841,15 @@ def upload_csv(
 def _run_transform(
     run: Union[ls_schemas.Run, dict, ls_schemas.RunLikeDict],
 ) -> dict:
+    """
+    Transforms the given run object into a dictionary representation.
+
+    Args:
+        run (Union[ls_schemas.Run, dict, ls_schemas.RunLikeDict]): The run object to transform.
+
+    Returns:
+        dict: The transformed run object as a dictionary.
+ """ if hasattr(run, "dict") and callable(getattr(run, "dict")): run_create = run.dict() # type: ignore else: @@ -935,6 +956,8 @@ def create_run( # batch ingest requires trace_id and dotted_order to be set and run_create.get("trace_id") is not None and run_create.get("dotted_order") is not None + # Checked last since it makes a (cached) API call + and self.info is not None # Older versions don't support batch ingest ): return self.tracing_queue.put( TracingQueueItem(run_create["dotted_order"], "create", run_create) @@ -1109,6 +1132,8 @@ def update_run( # batch ingest requires trace_id and dotted_order to be set and data["trace_id"] is not None and data["dotted_order"] is not None + # Checked last since it makes an API call + and self.info is not None # Older versions don't support batch ingest ): return self.tracing_queue.put( TracingQueueItem(data["dotted_order"], "update", data) @@ -3250,40 +3275,72 @@ def _tracing_thread_handle_batch( _AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4 +def _ensure_ingest_config( + info: Optional[ls_schemas.LangSmithInfo], +) -> ls_schemas.BatchIngestConfig: + default_config = ls_schemas.BatchIngestConfig( + size_limit=100, + scale_up_nthreads_limit=_AUTO_SCALE_UP_NTHREADS_LIMIT, + scale_up_qsize_trigger=_AUTO_SCALE_UP_QSIZE_TRIGGER, + scale_down_nempty_trigger=_AUTO_SCALE_DOWN_NEMPTY_TRIGGER, + ) + if not info: + return default_config + try: + if not info.batch_ingest_config: + return default_config + return info.batch_ingest_config + except BaseException: + return default_config + + def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None: client = client_ref() if client is None: return + try: + if not client.info: + print(f"no info: {client.info}", file=sys.stderr, flush=True) + return + except BaseException as e: + logger.debug("Error in tracing control thread: %s", e) + return tracing_queue = client.tracing_queue assert tracing_queue is not None + batch_ingest_config = _ensure_ingest_config(client.info) + size_limit: int = batch_ingest_config["size_limit"] + scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"] + scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"] sub_threads: List[threading.Thread] = [] + # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached + num_known_refs = 3 # loop until while ( # the main thread dies threading.main_thread().is_alive() # or we're the only remaining reference to the client - and sys.getrefcount(client) > 3 + len(sub_threads) - # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached + and sys.getrefcount(client) > num_known_refs + len(sub_threads) ): for thread in sub_threads: if not thread.is_alive(): sub_threads.remove(thread) if ( - len(sub_threads) < _AUTO_SCALE_UP_NTHREADS_LIMIT - and tracing_queue.qsize() > _AUTO_SCALE_UP_QSIZE_TRIGGER + len(sub_threads) < scale_up_nthreads_limit + and tracing_queue.qsize() > scale_up_qsize_trigger ): new_thread = threading.Thread( target=_tracing_sub_thread_func, args=(weakref.ref(client),) ) sub_threads.append(new_thread) new_thread.start() - if next_batch := _tracing_thread_drain_queue(tracing_queue): + if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit): _tracing_thread_handle_batch(client, tracing_queue, next_batch) - # drain the queue on exit - while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False): + while next_batch := _tracing_thread_drain_queue( + tracing_queue, limit=size_limit, block=False + ): _tracing_thread_handle_batch(client, 
         _tracing_thread_handle_batch(client, tracing_queue, next_batch)
@@ -3291,9 +3348,16 @@ def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
     client = client_ref()
     if client is None:
         return
+    try:
+        if not client.info:
+            return
+    except BaseException as e:
+        logger.debug("Error in tracing sub thread: %s", e)
+        return
     tracing_queue = client.tracing_queue
     assert tracing_queue is not None
-
+    batch_ingest_config = _ensure_ingest_config(client.info)
+    size_limit = batch_ingest_config.get("size_limit", 100)
     seen_successive_empty_queues = 0
 
     # loop until
@@ -3301,14 +3365,17 @@ def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
     while (
         # the main thread dies
         threading.main_thread().is_alive()
         # or we've seen the queue empty 4 times in a row
-        and seen_successive_empty_queues <= _AUTO_SCALE_DOWN_NEMPTY_TRIGGER
+        and seen_successive_empty_queues
+        <= batch_ingest_config["scale_down_nempty_trigger"]
     ):
-        if next_batch := _tracing_thread_drain_queue(tracing_queue):
+        if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit):
             seen_successive_empty_queues = 0
             _tracing_thread_handle_batch(client, tracing_queue, next_batch)
         else:
             seen_successive_empty_queues += 1
     # drain the queue on exit
-    while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
+    while next_batch := _tracing_thread_drain_queue(
+        tracing_queue, limit=size_limit, block=False
+    ):
         _tracing_thread_handle_batch(client, tracing_queue, next_batch)
diff --git a/python/langsmith/utils.py b/python/langsmith/utils.py
index 3e04063f6..abf7edb4c 100644
--- a/python/langsmith/utils.py
+++ b/python/langsmith/utils.py
@@ -5,8 +5,6 @@
 import logging
 import os
 import subprocess
-import threading
-import time
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import requests
@@ -287,38 +285,3 @@ def filter(self, record) -> bool:
         return (
             "Connection pool is full, discarding connection"
             not in record.getMessage()
         )
-
-
-def ttl_cache(
-    ttl_seconds: Optional[int] = None, maxsize: Optional[int] = None
-) -> Callable:
-    """LRU cache with an optional TTL."""
-
-    def decorator(func: Callable) -> Callable:
-        cache: Dict[Tuple, Tuple] = {}
-        cache_lock = threading.RLock()
-
-        @functools.wraps(func)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            key = (args, frozenset(kwargs.items()))
-            with cache_lock:
-                if key in cache:
-                    result, timestamp = cache[key]
-                    if ttl_seconds is None or time.time() - timestamp < ttl_seconds:
-                        # Refresh the timestamp
-                        cache[key] = (result, time.time())
-                        return result
-            result = func(*args, **kwargs)
-            with cache_lock:
-                cache[key] = (result, time.time())
-
-                if maxsize is not None:
-                    if len(cache) > maxsize:
-                        oldest_key = min(cache, key=lambda k: cache[k][1])
-                        del cache[oldest_key]
-
-            return result
-
-        return wrapper
-
-    return decorator
diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py
index d53d53b7e..54dbd7f83 100644
--- a/python/tests/integration_tests/test_client.py
+++ b/python/tests/integration_tests/test_client.py
@@ -1,4 +1,5 @@
 """LangSmith langchain_client Integration Tests."""
+
 import io
 import os
 import random
@@ -501,6 +502,6 @@ def test_get_info() -> None:
     langchain_client = Client(api_key="not-a-real-key")
     info = langchain_client.info
     assert info
-    assert info.version is not None
-    assert info.batch_ingest_config is not None
-    assert info.batch_ingest_config["size_limit"] > 0
+    assert info.version is not None  # type: ignore
+    assert info.batch_ingest_config is not None  # type: ignore
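+    # the client reads its batch upload limits (such as size_limit) from this config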
assert info.batch_ingest_config["size_limit"] > 0 # type: ignore diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index fc5c80b5e..19e264fd6 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -1,4 +1,5 @@ """Test the LangSmith client.""" + import asyncio import dataclasses import gc @@ -212,28 +213,27 @@ def __call__(self, *args: object, **kwargs: object) -> None: self.counter += 1 +@pytest.mark.parametrize("supports_batch_endpoint", [True, False]) @pytest.mark.parametrize("auto_batch_tracing", [True, False]) -def test_client_gc_empty(auto_batch_tracing: bool) -> None: - client = Client( - api_url="http://localhost:1984", - api_key="123", - auto_batch_tracing=auto_batch_tracing, - ) - tracker = CallTracker() - weakref.finalize(client, tracker) - assert tracker.counter == 0 - - del client - time.sleep(1) # Give the background thread time to stop - gc.collect() # Force garbage collection - assert tracker.counter == 1, "Client was not garbage collected" +def test_client_gc(auto_batch_tracing: bool, supports_batch_endpoint: bool) -> None: + session = mock.MagicMock(spec=requests.Session) + api_url = "http://localhost:1984" + def mock_get(*args, **kwargs): + if args[0] == f"{api_url}/info": + response = mock.Mock() + if supports_batch_endpoint: + response.json.return_value = {} + else: + response.raise_for_status.side_effect = HTTPError() + response.status_code = 404 + return response + else: + return MagicMock() -@pytest.mark.parametrize("auto_batch_tracing", [True, False]) -def test_client_gc(auto_batch_tracing: bool) -> None: - session = mock.MagicMock(spec=requests.Session) + session.get.side_effect = mock_get client = Client( - api_url="http://localhost:1984", + api_url=api_url, api_key="123", auto_batch_tracing=auto_batch_tracing, session=session, @@ -253,22 +253,32 @@ def test_client_gc(auto_batch_tracing: bool) -> None: dotted_order=id, ) - if auto_batch_tracing: + if auto_batch_tracing and supports_batch_endpoint: assert client.tracing_queue client.tracing_queue.join() request_calls = [call for call in session.request.mock_calls if call.args] - assert len(request_calls) == 1 + assert len(request_calls) >= 1 + for call in request_calls: assert call.args[0] == "post" assert call.args[1] == "http://localhost:1984/runs/batch" + get_calls = [call for call in session.get.mock_calls if call.args] + # assert len(get_calls) == 1 + for call in get_calls: + assert call.args[0] == f"{api_url}/info" else: request_calls = [call for call in session.request.mock_calls if call.args] + assert len(request_calls) == 10 for call in request_calls: assert call.args[0] == "post" assert call.args[1] == "http://localhost:1984/runs" - + if auto_batch_tracing: + get_calls = [call for call in session.get.mock_calls if call.args] + # assert len(get_calls) == 1 + for call in get_calls: + assert call.args[0] == f"{api_url}/info" del client time.sleep(1) # Give the background thread time to stop gc.collect() # Force garbage collection @@ -342,14 +352,34 @@ def test_client_gc_after_autoscale() -> None: assert call.args[1] == "http://localhost:1984/runs/batch" +@pytest.mark.parametrize("supports_batch_endpoint", [True, False]) @pytest.mark.parametrize("auto_batch_tracing", [True, False]) def test_create_run_includes_langchain_env_var_metadata( + supports_batch_endpoint: bool, auto_batch_tracing: bool, ) -> None: + session = mock.Mock() + session.request = mock.Mock() + api_url = "http://localhost:1984" + + def mock_get(*args, 
+        if args[0] == f"{api_url}/info":
+            response = mock.Mock()
+            if supports_batch_endpoint:
+                response.json.return_value = {}
+            else:
+                response.raise_for_status.side_effect = HTTPError()
+                response.status_code = 404
+            return response
+        else:
+            return MagicMock()
+
+    session.get.side_effect = mock_get
     client = Client(
-        api_url="http://localhost:1984",
+        api_url=api_url,
         api_key="123",
         auto_batch_tracing=auto_batch_tracing,
+        session=session,
     )
     inputs = {
         "foo": "これは私の友達です",
@@ -358,39 +388,34 @@ def test_create_run_includes_langchain_env_var_metadata(
         "qux": "나는\u3000밥을\u3000먹었습니다.",
         "는\u3000밥": "나는\u3000밥을\u3000먹었습니다.",
     }
-    session = mock.Mock()
-    session.request = mock.Mock()
+
     # Set the environment variables just for this test
     with patch.dict(os.environ, {"LANGCHAIN_REVISION": "abcd2234"}):
         # Clear the cache to ensure the environment variables are re-read
         ls_env.get_langchain_env_var_metadata.cache_clear()
-        with patch.object(client, "session", session):
-            id_ = uuid.uuid4()
-            start_time = datetime.now()
-            client.create_run(
-                "my_run",
-                inputs=inputs,
-                run_type="llm",
-                id=id_,
-                trace_id=id_,
-                dotted_order=f"{start_time.strftime('%Y%m%dT%H%M%S%fZ')}{id_}",
-                start_time=start_time,
+        id_ = uuid.uuid4()
+        start_time = datetime.now()
+        client.create_run(
+            "my_run",
+            inputs=inputs,
+            run_type="llm",
+            id=id_,
+            trace_id=id_,
+            dotted_order=f"{start_time.strftime('%Y%m%dT%H%M%S%fZ')}{id_}",
+            start_time=start_time,
+        )
+        if tracing_queue := client.tracing_queue:
+            tracing_queue.join()
+        # Check the posted value in the request
+        posted_value = json.loads(session.request.call_args[1]["data"])
+        if auto_batch_tracing and supports_batch_endpoint:
+            assert (
+                posted_value["post"][0]["extra"]["metadata"]["LANGCHAIN_REVISION"]
+                == "abcd2234"
             )
-            if tracing_queue := client.tracing_queue:
-                tracing_queue.join()
-            # Check the posted value in the request
-            posted_value = json.loads(session.request.call_args[1]["data"])
-            if not auto_batch_tracing:
-                assert (
-                    posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"]
-                    == "abcd2234"
-                )
-                assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"]
-            else:
-                assert (
-                    posted_value["post"][0]["extra"]["metadata"]["LANGCHAIN_REVISION"]
-                    == "abcd2234"
-                )
+        else:
+            assert posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"] == "abcd2234"
+        assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"]
 
 
 @pytest.mark.parametrize("source_type", ["api", "model"])
diff --git a/python/tests/unit_tests/test_utils.py b/python/tests/unit_tests/test_utils.py
index c7348d6a3..4a7029a2a 100644
--- a/python/tests/unit_tests/test_utils.py
+++ b/python/tests/unit_tests/test_utils.py
@@ -1,4 +1,3 @@
-import time
 import unittest
 
 import pytest
@@ -72,25 +71,3 @@ def test_correct_get_tracer_project(self):
                 else ls_utils.get_tracer_project(case.return_default_value)
             )
             self.assertEqual(project, case.expected_project_name)
-
-
-def test_ttl_cache():
-    test_function_val = 0
-
-    class MyClass:
-        @property
-        @ls_utils.ttl_cache(ttl_seconds=0.1)
-        def test_function(self):
-            nonlocal test_function_val
-            test_function_val += 1
-            return test_function_val
-
-    some_class = MyClass()
-    for _ in range(3):
-        assert some_class.test_function == 1
-    time.sleep(0.1)
-    for _ in range(3):
-        assert some_class.test_function == 2
-    time.sleep(0.1)
-    for _ in range(3):
-        assert some_class.test_function == 3