Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip #3657

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft

wip #3657

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions weave/trace/concurrent/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import logging
from concurrent.futures import Future, wait
from contextvars import ContextVar
import random
from threading import Lock
from typing import Any, Callable, TypeVar

Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
max_workers: int | None = None,
thread_name_prefix: str = THREAD_NAME_PREFIX,
):
self._id = random.randint(0, 1000000)
self._max_workers = max_workers
self._executor: ContextAwareThreadPoolExecutor | None = None
if max_workers != 0:
Expand Down
37 changes: 34 additions & 3 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Iterator, Sequence
from concurrent.futures import Future
from functools import lru_cache
import threading
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -670,6 +671,9 @@ def to_dict(self) -> CallDict:
)


images = set()


def make_client_call(
entity: str, project: str, server_call: CallSchema, server: TraceServerInterface
) -> WeaveObject:
Expand Down Expand Up @@ -752,7 +756,7 @@ def __repr__(self) -> str:

class WeaveClient:
server: TraceServerInterface
future_executor: FutureExecutor
# future_executor: FutureExecutor

"""
A client for interacting with the Weave trace server.
Expand All @@ -775,7 +779,8 @@ def __init__(
self.project = project
self.server = server
self._anonymous_ops: dict[str, Op] = {}
self.future_executor = FutureExecutor(max_workers=client_parallelism())
self._future_executor_local = threading.local()
self._future_executor = FutureExecutor(max_workers=client_parallelism())
self.ensure_project_exists = ensure_project_exists

if ensure_project_exists:
Expand All @@ -787,6 +792,16 @@ def __init__(
if isinstance(self.server, RemoteHTTPTraceServer):
self._server_is_flushable = self.server.should_batch

@property
def future_executor(self) -> FutureExecutor:
# return self._future_executor

if not hasattr(self._future_executor_local, "future_executor"):
self._future_executor_local.future_executor = FutureExecutor(
max_workers=client_parallelism()
)
return self._future_executor_local.future_executor

################ High Level Convenience Methods ################

@trace_sentry.global_trace_sentry.watch()
Expand Down Expand Up @@ -1873,7 +1888,23 @@ def _flush(self) -> None:
self.server.call_processor.wait_until_all_processed() # type: ignore

def _send_file_create(self, req: FileCreateReq) -> Future[FileCreateRes]:
return self.future_executor.defer(self.server.file_create, req)
import base64
from threading import current_thread
from datetime import datetime

thread = current_thread()
name = base64.b64encode(req.content).decode("utf-8")
print(
"send_file_create",
datetime.now().strftime("%M:%S.%f")[:-3],
name[-30:],
thread.name,
"seen={}".format(name in images),
"len={}".format(len(images)),
)
images.add(name)
out = self.future_executor.defer(self.server.file_create, req)
return out


def safe_current_wb_run_id() -> str | None:
Expand Down
21 changes: 13 additions & 8 deletions weave/trace_server/requests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Helpers for printing HTTP requests and responses."""

from contextlib import contextmanager
import datetime
import json
import os
import threading
from time import time
from typing import Any, Optional, Union
from typing import Any, Iterator, Optional, Union

from requests import HTTPError as HTTPError
from requests import PreparedRequest, Response, Session
Expand All @@ -14,6 +15,8 @@
from rich.syntax import Syntax
from rich.text import Text

from weave.trace.settings import client_parallelism

console = Console()

# See https://rich.readthedocs.io/en/stable/appendix/colors.html
Expand Down Expand Up @@ -147,16 +150,18 @@ def send(self, request: PreparedRequest, **kwargs: Any) -> Response: # type: ig
return response


session = Session()
if os.environ.get("WEAVE_DEBUG_HTTP") == "1":
adapter = LoggingHTTPAdapter()
session.mount("http://", adapter)
session.mount("https://", adapter)
thread_local = threading.local()


def get_session() -> Session:
if not hasattr(thread_local, "session"):
thread_local.session = Session()
return thread_local.session


def get(url: str, params: Optional[dict[str, str]] = None, **kwargs: Any) -> Response:
"""Send a GET request with optional logging."""
return session.get(url, params=params, **kwargs)
return get_session().get(url, params=params, **kwargs)


def post(
Expand All @@ -166,4 +171,4 @@ def post(
**kwargs: Any,
) -> Response:
"""Send a POST request with optional logging."""
return session.post(url, data=data, json=json, **kwargs)
return get_session().post(url, data=data, json=json, **kwargs)
Loading