From bfded0f09167ebcb02852e2e8686450314530e75 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 00:58:15 +0530 Subject: [PATCH 01/20] update opentelemetry --- packages/grid/backend/backend.dockerfile | 3 +- packages/grid/backend/grid/start.sh | 25 ++++++-- packages/grid/default.env | 5 -- .../backend/backend-statefulset.yaml | 14 ++-- packages/grid/helm/syft/values.yaml | 1 + packages/syft/setup.cfg | 20 ++++-- packages/syft/src/syft/util/telemetry.py | 64 ++++++------------- scripts/dev_tools.sh | 2 +- tox.ini | 10 +++ 9 files changed, 71 insertions(+), 73 deletions(-) diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index 606569c49f4..6037dd167f3 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -41,7 +41,7 @@ COPY syft/src/syft/VERSION ./syft/src/syft/ RUN --mount=type=cache,target=/root/.cache,sharing=locked \ # remove torch because we already have the cpu version pre-installed sed --in-place /torch==/d ./syft/setup.cfg && \ - uv pip install -e ./syft[data_science] + uv pip install -e ./syft[data_science,telemetry] # ==================== [Final] Setup Syft Server ==================== # @@ -75,7 +75,6 @@ ENV \ APPDIR="/root/app" \ NODE_NAME="default_node_name" \ NODE_TYPE="domain" \ - SERVICE_NAME="backend" \ RELEASE="production" \ DEV_MODE="False" \ DEBUGGER_ENABLED="False" \ diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 4b3d5de4cf2..4d6571c4570 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -10,19 +10,36 @@ PORT=${PORT:-80} NODE_TYPE=${NODE_TYPE:-domain} APPDIR=${APPDIR:-$HOME/app} RELOAD="" -DEBUG_CMD="" +ROOT_PROC="" if [[ ${DEV_MODE} == "True" ]]; then - echo "DEV_MODE Enabled" + echo "Hot-reload Enabled" RELOAD="--reload" fi # only set by kubernetes to avoid conflict with docker tests if [[ ${DEBUGGER_ENABLED} == "True" ]]; then + echo "Debugger 
Enabled" uv pip install debugpy - DEBUG_CMD="python -m debugpy --listen 0.0.0.0:5678 -m" + ROOT_PROC="python -m debugpy --listen 0.0.0.0:5678 -m" +fi + +if [[ ${TRACING} == "True" ]]; +then + echo "OpenTelemetry Enabled" + + # TODOs: + # ! Handle case when OTEL_EXPORTER_OTLP_ENDPOINT is not set. + # ! syft-signoz-otel-collector.platform:4317 should be plumbed through helm charts + # ? Kubernetes OTel operator is recommended by signoz + export OTEL_PYTHON_LOG_CORRELATION=${OTEL_PYTHON_LOG_CORRELATION:-true} + export OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_EXPORTER_OTLP_ENDPOINT:-"http://syft-signoz-otel-collector.platform:4317"} + export OTEL_EXPORTER_OTLP_PROTOCOL=${OTEL_EXPORTER_OTLP_PROTOCOL:-grpc} + + # TODO: Finalize if uvicorn is stable with OpenTelemetry + # ROOT_PROC="opentelemetry-instrument" fi export CREDENTIALS_PATH=${CREDENTIALS_PATH:-$HOME/data/creds/credentials.json} @@ -33,4 +50,4 @@ export NODE_TYPE=$NODE_TYPE echo "NODE_UID=$NODE_UID" echo "NODE_TYPE=$NODE_TYPE" -exec $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-config=$APPDIR/grid/logging.yaml --log-level $LOG_LEVEL "$APP_MODULE" +exec $ROOT_PROC uvicorn $RELOAD --host $HOST --port $PORT --log-config=$APPDIR/grid/logging.yaml "$APP_MODULE" diff --git a/packages/grid/default.env b/packages/grid/default.env index 6ae9748bfef..902fdb9c5b5 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -91,11 +91,6 @@ REDIS_HOST=redis CONTAINER_HOST=docker RELATIVE_PATH="" -# Jaeger -TRACE=False -JAEGER_HOST=localhost -JAEGER_PORT=14268 - # Syft SYFT_TUTORIAL_MODE=False ENABLE_WARNINGS=True diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index be0a35d6245..5aab3ab34d6 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -85,7 +85,7 @@ spec: {{- end }} {{- if 
.Values.node.debuggerEnabled }} - name: DEBUGGER_ENABLED - value: "true" + value: "True" - name: ASSOCIATION_REQUEST_AUTO_APPROVAL value: {{ .Values.node.associationRequestAutoApproval | quote }} {{- end }} @@ -130,14 +130,10 @@ spec: key: s3RootPassword {{- end }} # Tracing - - name: TRACE - value: "false" - - name: SERVICE_NAME - value: "backend" - - name: JAEGER_HOST - value: "localhost" - - name: JAEGER_PORT - value: "14268" + {{- if .Values.node.tracing }} + - name: TRACING + value: "True" + {{- end }} # Enclave attestation {{- if .Values.attestation.enabled }} - name: ENCLAVE_ATTESTATION_ENABLED diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 631475b2462..f61f7b6c9da 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -166,6 +166,7 @@ node: debuggerEnabled: false associationRequestAutoApproval: false useInternalRegistry: true + tracing: true # Default Worker pool settings defaultWorkerPool: diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 56d52e14f74..877c5ebface 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -100,12 +100,20 @@ dev = safety>=2.4.0b2 telemetry = - opentelemetry-api==1.14.0 - opentelemetry-sdk==1.14.0 - opentelemetry-exporter-jaeger==1.14.0 - opentelemetry-instrumentation==0.35b0 - opentelemetry-instrumentation-requests==0.35b0 - ; opentelemetry-instrumentation-digma==0.9.0 + opentelemetry-api==1.25.0 + opentelemetry-sdk==1.25.0 + opentelemetry-exporter-otlp==1.25.0 + opentelemetry-instrumentation==0.46b0 + opentelemetry-instrumentation-requests==0.46b0 + opentelemetry-instrumentation-fastapi==0.46b0 + opentelemetry-instrumentation-pymongo==0.46b0 + opentelemetry-instrumentation-botocore==0.46b0 + opentelemetry-instrumentation-logging==0.46b0 + ; opentelemetry-instrumentation-asyncio==0.46b0 + ; opentelemetry-instrumentation-sqlite3==0.46b0 + ; opentelemetry-instrumentation-threading==0.46b0 + ; 
opentelemetry-instrumentation-jinja2==0.46b0 + ; opentelemetry-instrumentation-system-metrics==0.46b0 # pytest>=8.0 broke pytest-lazy-fixture which doesn't seem to be actively maintained # temporarily pin to pytest<8 diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index d03f240a1de..0eb13108707 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -5,19 +5,14 @@ from typing import Any from typing import TypeVar -logger = logging.getLogger(__name__) - - -def str_to_bool(bool_str: str | None) -> bool: - result = False - bool_str = str(bool_str).lower() - if bool_str == "true" or bool_str == "1": - result = True - return result +# relative +from .util import str_to_bool +__all__ = ["TRACING_ENABLED", "instrument"] -TRACE_MODE = str_to_bool(os.environ.get("TRACE", "False")) +logger = logging.getLogger(__name__) +TRACING_ENABLED = str_to_bool(os.environ.get("TRACING", "False")) T = TypeVar("T", bound=Callable | type) @@ -26,54 +21,31 @@ def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: return __func_or_class -if not TRACE_MODE: +if not TRACING_ENABLED: instrument = noop else: try: - service_name = os.environ.get("SERVICE_NAME", "client") - jaeger_host = os.environ.get("JAEGER_HOST", "localhost") - jaeger_port = int(os.environ.get("JAEGER_PORT", "14268")) - # third party from opentelemetry import trace - from opentelemetry.exporter.jaeger.thrift import JaegerExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter, + ) from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.resources import SERVICE_NAME from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor - trace.set_tracer_provider( - TracerProvider(resource=Resource.create({SERVICE_NAME: service_name})) - ) - jaeger_exporter = JaegerExporter( - # agent_host_name=jaeger_host, - # 
agent_port=jaeger_port, - collector_endpoint=f"http://{jaeger_host}:{jaeger_port}/api/traces?format=jaeger.thrift", - # udp_split_oversized_batches=True, - ) + # relative + from .trace_decorator import instrument as _instrument - trace.get_tracer_provider().add_span_processor( - BatchSpanProcessor(jaeger_exporter) + service_name = os.environ.get("OTEL_SERVICE_NAME", "syft-backend") + trace.set_tracer_provider( + TracerProvider(resource=Resource.create({"service.name": service_name})) ) - # from opentelemetry.sdk.trace.export import ConsoleSpanExporter - # console_exporter = ConsoleSpanExporter() - # span_processor = BatchSpanProcessor(console_exporter) - # trace.get_tracer_provider().add_span_processor(span_processor) - - # third party - import opentelemetry.instrumentation.requests - - opentelemetry.instrumentation.requests.RequestsInstrumentor().instrument() - - # relative - # from opentelemetry.instrumentation.digma.trace_decorator import ( - # instrument as _instrument, - # ) - # - # until this is merged: - # https://github.com/digma-ai/opentelemetry-instrumentation-digma/pull/41 - from .trace_decorator import instrument as _instrument + # configured through env:OTEL_EXPORTER_OTLP_ENDPOINT + otlp_exporter = OTLPSpanExporter() + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) instrument = _instrument except Exception as e: diff --git a/scripts/dev_tools.sh b/scripts/dev_tools.sh index 763e602cd28..20a74b597e0 100755 --- a/scripts/dev_tools.sh +++ b/scripts/dev_tools.sh @@ -23,7 +23,7 @@ function docker_list_exposed_ports() { if [[ -z "$1" ]]; then # list db, redis, rabbitmq, and seaweedfs ports - docker_list_exposed_ports "db\|redis\|queue\|seaweedfs\|jaeger\|mongo" + docker_list_exposed_ports "db\|seaweedfs\|mongo" else PORT=$1 if docker ps | grep ":${PORT}" | grep -q 'redis'; then diff --git a/tox.ini b/tox.ini index d15fd50c6da..583ea6bae83 100644 --- a/tox.ini +++ b/tox.ini @@ -793,6 +793,16 
@@ commands = ; restarts coredns bash -c 'kubectl delete pod -n kube-system -l k8s-app=kube-dns --context k3d-${CLUSTER_NAME}' +[testenv:dev.k8s.install.signoz] +description = Install Signoz on local Kubernetes cluster +changedir = {toxinidir} +passenv=HOME,USER +allowlist_externals = + bash +commands = + bash -c 'helm repo add signoz https://charts.signoz.io && helm repo update' + bash -c 'helm install syft signoz/signoz --namespace platform --create-namespace || true' + [testenv:dev.k8s.start] description = Start local Kubernetes registry & cluster with k3d changedir = {toxinidir} From 941db20fa15be1743efcd3d012f56da62def6f91 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 01:00:18 +0530 Subject: [PATCH 02/20] use FastAPIInstrumentor --- packages/grid/backend/grid/core/config.py | 2 + packages/grid/backend/grid/main.py | 7 ++++ packages/syft/src/syft/node/routes.py | 50 ++--------------------- packages/syft/src/syft/node/server.py | 7 ++++ 4 files changed, 20 insertions(+), 46 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 8c55b8cd3f7..5af696301a8 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,6 +155,8 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) + + TRACING_ENABLED: bool = str_to_bool(os.getenv("TRACING", "False")) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/main.py b/packages/grid/backend/grid/main.py index 459448c5f01..fca35ac4cc9 100644 --- a/packages/grid/backend/grid/main.py +++ b/packages/grid/backend/grid/main.py @@ -77,3 +77,10 @@ def healthcheck() -> dict[str, str]: probe on the pods backing the Service. 
""" return {"status": "ok"} + + +if settings.TRACING_ENABLED: + # third party + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 8be45245190..fca70160d69 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -29,7 +29,6 @@ from ..service.user.user import UserPrivateKey from ..service.user.user_service import UserService from ..types.uid import UID -from ..util.telemetry import TRACE_MODE from .credentials import SyftVerifyKey from .credentials import UserLoginCredentials from .worker import Worker @@ -38,15 +37,6 @@ def make_routes(worker: Worker) -> APIRouter: - if TRACE_MODE: - # third party - try: - # third party - from opentelemetry import trace - from opentelemetry.propagate import extract - except Exception as e: - logger.error("Failed to import opentelemetry", exc_info=e) - router = APIRouter() async def get_body(request: Request) -> bytes: @@ -129,15 +119,7 @@ def syft_new_api( request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE ) -> Response: user_verify_key: SyftVerifyKey = SyftVerifyKey.from_string(verify_key) - if TRACE_MODE: - with trace.get_tracer(syft_new_api.__module__).start_as_current_span( - syft_new_api.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_syft_new_api(user_verify_key, communication_protocol) - else: - return handle_syft_new_api(user_verify_key, communication_protocol) + return handle_syft_new_api(user_verify_key, communication_protocol) def handle_new_api_call(data: bytes) -> Response: obj_msg = deserialize(blob=data, from_bytes=True) @@ -152,15 +134,7 @@ def handle_new_api_call(data: bytes) -> Response: def syft_new_api_call( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: - if TRACE_MODE: - with 
trace.get_tracer(syft_new_api_call.__module__).start_as_current_span( - syft_new_api_call.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_new_api_call(data) - else: - return handle_new_api_call(data) + return handle_new_api_call(data) def handle_login(email: str, password: str, node: AbstractNode) -> Response: try: @@ -217,28 +191,12 @@ def login( email: Annotated[str, Body(example="info@openmined.org")], password: Annotated[str, Body(example="changethis")], ) -> Response: - if TRACE_MODE: - with trace.get_tracer(login.__module__).start_as_current_span( - login.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_login(email, password, worker) - else: - return handle_login(email, password, worker) + return handle_login(email, password, worker) @router.post("/register", name="register", status_code=200) def register( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: - if TRACE_MODE: - with trace.get_tracer(register.__module__).start_as_current_span( - register.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_register(data, worker) - else: - return handle_register(data, worker) + return handle_register(data, worker) return router diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 43b8359a1f9..ab11e78ae7d 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -20,6 +20,7 @@ from ..abstract_node import NodeSideType from ..client.client import API_PATH from ..util.constants import DEFAULT_TIMEOUT +from ..util.telemetry import TRACING_ENABLED from ..util.util import os_name from .domain import Domain from .enclave import Enclave @@ -53,6 +54,12 @@ def make_app(name: str, router: APIRouter) -> FastAPI: allow_headers=["*"], ) + if TRACING_ENABLED: + # third party + from opentelemetry.instrumentation.fastapi 
import FastAPIInstrumentor + + FastAPIInstrumentor().instrument_app(app) + return app From 1d9bccbbddcd7be9abc9ba522c24bb4e95bb2abd Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 03:04:50 +0530 Subject: [PATCH 03/20] remove noisy traces because of @instrument --- packages/syft/src/syft/client/api.py | 4 ---- packages/syft/src/syft/client/client.py | 6 ------ packages/syft/src/syft/node/node.py | 10 ++++++++-- packages/syft/src/syft/service/api/api_service.py | 2 -- packages/syft/src/syft/service/code/status_service.py | 3 --- .../syft/src/syft/service/code/user_code_service.py | 2 -- packages/syft/src/syft/service/code/user_code_stash.py | 2 -- .../syft/service/code_history/code_history_service.py | 2 -- .../data_subject/data_subject_member_service.py | 3 --- .../syft/service/data_subject/data_subject_service.py | 3 --- .../syft/src/syft/service/dataset/dataset_service.py | 2 -- .../syft/src/syft/service/dataset/dataset_stash.py | 2 -- packages/syft/src/syft/service/job/job_service.py | 2 -- packages/syft/src/syft/service/job/job_stash.py | 2 -- packages/syft/src/syft/service/log/log_service.py | 2 -- packages/syft/src/syft/service/log/log_stash.py | 2 -- .../syft/src/syft/service/metadata/metadata_service.py | 2 -- .../syft/src/syft/service/network/network_service.py | 3 --- .../syft/service/notification/notification_service.py | 2 -- .../syft/service/notification/notification_stash.py | 2 -- .../syft/src/syft/service/notifier/notifier_stash.py | 2 -- .../syft/src/syft/service/output/output_service.py | 3 --- .../syft/src/syft/service/project/project_service.py | 2 -- .../syft/src/syft/service/project/project_stash.py | 2 -- packages/syft/src/syft/service/queue/queue_service.py | 2 -- packages/syft/src/syft/service/queue/queue_stash.py | 2 -- .../syft/src/syft/service/request/request_service.py | 2 -- .../syft/src/syft/service/request/request_stash.py | 2 -- .../syft/src/syft/service/settings/settings_stash.py | 2 -- 
packages/syft/src/syft/service/sync/sync_stash.py | 2 -- packages/syft/src/syft/service/user/user_service.py | 2 -- packages/syft/src/syft/service/user/user_stash.py | 2 -- .../syft/src/syft/service/worker/worker_service.py | 2 -- packages/syft/src/syft/service/worker/worker_stash.py | 2 -- packages/syft/src/syft/store/document_store.py | 5 ----- 35 files changed, 8 insertions(+), 84 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 9c8b244b129..2012623a1fe 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -63,7 +63,6 @@ from ..util.autoreload import autoreload_enabled from ..util.markdown import as_markdown_python_code from ..util.notebook_ui.components.tabulator_template import build_tabulator_table -from ..util.telemetry import instrument from ..util.util import prompt_warning_message from .connection import NodeConnection @@ -203,7 +202,6 @@ def is_valid(self) -> Result[SyftSuccess, SyftError]: return SyftSuccess(message="Credentials are valid") -@instrument @serializable() class SyftAPICall(SyftObject): # version @@ -230,7 +228,6 @@ def __repr__(self) -> str: return f"SyftAPICall(path={self.path}, args={self.args}, kwargs={self.kwargs}, blocking={self.blocking})" -@instrument @serializable() class SyftAPIData(SyftBaseObject): # version @@ -864,7 +861,6 @@ class SyftAPIV2(SyftObject): __syft_allow_autocomplete__ = ["services"] -@instrument @serializable( attrs=[ "endpoints", diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index f5bde266ae9..42576c5467c 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -50,7 +50,6 @@ from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.uid import UID -from ..util.telemetry import instrument from ..util.util import prompt_warning_message from ..util.util import 
thread_ident from ..util.util import verify_tls @@ -509,7 +508,6 @@ def get_client_type(self) -> type[SyftClient] | SyftError: return SyftError(message=f"Unknown node type {metadata.node_type}") -@instrument @serializable() class SyftClient: connection: NodeConnection @@ -971,7 +969,6 @@ def refresh_callback() -> SyftAPI: return _api -@instrument def connect( url: str | GridURL = DEFAULT_PYGRID_ADDRESS, node: AbstractNode | None = None, @@ -993,7 +990,6 @@ def connect( return client_type(connection=connection) -@instrument def register( url: str | GridURL, port: int, @@ -1013,7 +1009,6 @@ def register( ) -@instrument def login_as_guest( # HTTPConnection url: str | GridURL = DEFAULT_PYGRID_ADDRESS, @@ -1040,7 +1035,6 @@ def login_as_guest( return _client.guest() -@instrument def login( email: str, # HTTPConnection diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index ea9ab3f8a47..a1cca3fa1bb 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -313,7 +313,6 @@ def auth_context_for_user( return cls.__node_context_registry__.get(key) -@instrument class Node(AbstractNode): signing_key: SyftSigningKey | None required_signed_calls: bool = True @@ -956,7 +955,7 @@ def _construct_services(self) -> None: {"svc": SyftImageRegistryService}, {"svc": SyncService}, {"svc": OutputService}, - {"svc": UserCodeStatusService}, # this is lazy + {"svc": UserCodeStatusService}, ] for svc_kwargs in default_services: @@ -1113,6 +1112,7 @@ def await_future( return res sleep(0.1) + @instrument def resolve_future( self, credentials: SyftVerifyKey, uid: UID ) -> QueueItem | None | SyftError: @@ -1127,6 +1127,7 @@ def resolve_future( return queue_obj return result.err() + @instrument def forward_message( self, api_call: SyftAPICall | SignedSyftAPICall ) -> Result | QueueItem | SyftObject | SyftError | Any: @@ -1197,6 +1198,7 @@ def get_role_for_credentials(self, credentials: SyftVerifyKey) -> ServiceRole: ) return 
role + @instrument def handle_api_call( self, api_call: SyftAPICall | SignedSyftAPICall, @@ -1354,6 +1356,7 @@ def get_worker_pool_ref_by_name( ) return worker_pool_ref + @instrument def add_action_to_queue( self, action: Action, @@ -1406,6 +1409,7 @@ def add_action_to_queue( user_id=user_id, ) + @instrument def add_queueitem_to_queue( self, *, @@ -1506,6 +1510,7 @@ def _is_usercode_call_on_owned_kwargs( context, user_code_id, api_call.kwargs ) + @instrument def add_api_call_to_queue( self, api_call: SyftAPICall, parent_job_id: UID | None = None ) -> Job | SyftError: @@ -1627,6 +1632,7 @@ def get_worker_pool_by_name(self, name: str) -> WorkerPool | None | SyftError: worker_pool = result.ok() return worker_pool + @instrument def get_api( self, for_user: SyftVerifyKey | None = None, diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 0f8405db166..83e4acb249d 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -14,7 +14,6 @@ from ...service.action.action_object import ActionObject from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_service import ActionService from ..context import AuthedServiceContext from ..response import SyftError @@ -36,7 +35,6 @@ from .api_stash import TwinAPIEndpointStash -@instrument @serializable() class APIService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index e787ae6e096..fc87aef2497 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -12,7 +12,6 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.uid import UID -from ...util.telemetry import instrument from 
..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -24,7 +23,6 @@ from .user_code import UserCodeStatusCollection -@instrument @serializable() class StatusStash(BaseUIDStoreStash): object_type = UserCodeStatusCollection @@ -46,7 +44,6 @@ def get_by_uid( return self.query_one(credentials=credentials, qks=qks) -@instrument @serializable() class UserCodeStatusService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 58ec982ac2c..b0a5ab34b8b 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -18,7 +18,6 @@ from ...types.syft_metaclass import Empty from ...types.twin_object import TwinObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission @@ -50,7 +49,6 @@ from .user_code_stash import UserCodeStash -@instrument @serializable() class UserCodeService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index fa9fad49b82..208170048c6 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -10,7 +10,6 @@ from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys -from ...util.telemetry import instrument from .user_code import CodeHashPartitionKey from .user_code import ServiceFuncNamePartitionKey from .user_code import SubmitTimePartitionKey @@ -18,7 +17,6 @@ from .user_code import UserVerifyKeyPartitionKey -@instrument @serializable() class 
UserCodeStash(BaseUIDStoreStash): object_type = UserCode diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index adfd6dbee5d..39cfceee653 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -5,7 +5,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..code.user_code import SubmitUserCode from ..code.user_code import UserCode from ..context import AuthedServiceContext @@ -22,7 +21,6 @@ from .code_history_stash import CodeHistoryStash -@instrument @serializable() class CodeHistoryService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index 57f38f445ec..e8af5deb2eb 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -10,7 +10,6 @@ from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -22,7 +21,6 @@ from .data_subject_member import ParentPartitionKey -@instrument @serializable() class DataSubjectMemberStash(BaseUIDStoreStash): object_type = DataSubjectMemberRelationship @@ -47,7 +45,6 @@ def get_all_for_child( return self.query_all(credentials=credentials, qks=qks) -@instrument @serializable() class DataSubjectMemberService(AbstractService): store: DocumentStore diff --git 
a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index 02027aebd7c..ba52d259faa 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -10,7 +10,6 @@ from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -24,7 +23,6 @@ from .data_subject_member_service import DataSubjectMemberService -@instrument @serializable() class DataSubjectStash(BaseUIDStoreStash): object_type = DataSubject @@ -54,7 +52,6 @@ def update( return super().update(credentials=credentials, obj=res.ok()) -@instrument @serializable() class DataSubjectService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 451746fa15a..78d04de86ca 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -7,7 +7,6 @@ from ...store.document_store import DocumentStore from ...types.dicttuple import DictTuple from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext @@ -65,7 +64,6 @@ def _paginate_dataset_collection( ) -@instrument @serializable() class DatasetService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index ee99a4411c7..a36be5ff151 100644 --- 
a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -12,7 +12,6 @@ from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys from ...types.uid import UID -from ...util.telemetry import instrument from .dataset import Dataset from .dataset import DatasetUpdate @@ -20,7 +19,6 @@ ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable() class DatasetStash(BaseUIDStoreStash): object_type = Dataset diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 368992ceaa5..99953746fe6 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -10,7 +10,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission @@ -44,7 +43,6 @@ def wait_until( return SyftError(message=f"Timeout reached for predicate {code_string}") -@instrument @serializable() class JobService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index e69c539cb5c..03e6f719c61 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -39,7 +39,6 @@ from ...util import options from ...util.colors import SURFACE from ...util.markdown import as_markdown_code -from ...util.telemetry import instrument from ...util.util import prompt_warning_message from ..action.action_object import Action from ..action.action_object import ActionObject @@ -825,7 +824,6 @@ def from_job( return info -@instrument 
@serializable() class JobStash(BaseUIDStoreStash): object_type = Job diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index 3390171d1a4..c9230366f84 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -2,7 +2,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext from ..response import SyftError @@ -16,7 +15,6 @@ from .log_stash import LogStash -@instrument @serializable() class LogService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index f1c37d9f6b2..1325a340f5b 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -3,11 +3,9 @@ from ...store.document_store import BaseUIDStoreStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings -from ...util.telemetry import instrument from .log import SyftLog -@instrument @serializable() class LogStash(BaseUIDStoreStash): object_type = SyftLog diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index a1e7eebb799..e10535096c9 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -3,7 +3,6 @@ # relative from ...serde.serializable import serializable from ...store.document_store import DocumentStore -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method @@ -11,7 +10,6 @@ from .node_metadata 
import NodeMetadata -@instrument @serializable() class MetadataService(AbstractService): def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index c32374ade31..a2085890b92 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -29,7 +29,6 @@ from ...types.transforms import transform from ...types.transforms import transform_method from ...types.uid import UID -from ...util.telemetry import instrument from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey @@ -70,7 +69,6 @@ class NodePeerAssociationStatus(Enum): PEER_NOT_FOUND = "PEER_NOT_FOUND" -@instrument @serializable() class NetworkStash(BaseUIDStoreStash): object_type = NodePeer @@ -145,7 +143,6 @@ def get_by_node_type( ) -@instrument @serializable() class NetworkService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index bfdc303c904..7c3adf06d4b 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -4,7 +4,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext from ..notifier.notifier import NotifierSettings @@ -25,7 +24,6 @@ from .notifications import ReplyNotification -@instrument @serializable() class NotificationService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/notification/notification_stash.py 
b/packages/syft/src/syft/service/notification/notification_stash.py index 84aafb33849..6ebd9029091 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -15,7 +15,6 @@ from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.uid import UID -from ...util.telemetry import instrument from .notifications import Notification from .notifications import NotificationStatus @@ -32,7 +31,6 @@ LinkedObjectPartitionKey = PartitionKey(key="linked_obj", type_=LinkedObject) -@instrument @serializable() class NotificationStash(BaseUIDStoreStash): object_type = Notification diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 9c02e153d74..7b7ac17a7db 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -14,7 +14,6 @@ from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from .notifier import NotifierSettings @@ -22,7 +21,6 @@ ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable() class NotifierStash(BaseStash): object_type = NotifierSettings diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 4efe75ec618..fa053ddc93a 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -21,7 +21,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object 
import ActionObject from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext @@ -189,7 +188,6 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return res -@instrument @serializable() class OutputStash(BaseUIDStoreStash): object_type = ExecutionOutput @@ -244,7 +242,6 @@ def get_by_output_policy_id( ) -@instrument @serializable() class OutputService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 3b8cef606ac..a3413554c97 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -5,7 +5,6 @@ from ...store.document_store import DocumentStore from ...store.linked_obj import LinkedObject from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..notification.notification_service import NotificationService from ..notification.notifications import CreateNotification @@ -27,7 +26,6 @@ from .project_stash import ProjectStash -@instrument @serializable() class ProjectService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index 0866db4b252..521951424dd 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -12,7 +12,6 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.uid import UID -from ...util.telemetry import instrument from ..request.request import Request from ..response import SyftError from .project import Project @@ -21,7 +20,6 @@ NamePartitionKey = PartitionKey(key="name", type_=str) -@instrument @serializable() class ProjectStash(BaseUIDStoreStash): object_type = Project 
diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index d1cf119076a..6ee51af2579 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -4,7 +4,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftError from ..service import AbstractService @@ -14,7 +13,6 @@ from .queue_stash import QueueStash -@instrument @serializable() class QueueService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index dc38735eff1..c6ab1e9be40 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -22,7 +22,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SyftObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..response import SyftError from ..response import SyftSuccess @@ -97,7 +96,6 @@ class APIEndpointQueueItem(QueueItem): service: str = "apiservice" -@instrument @serializable() class QueueStash(BaseStash): object_type = QueueItem diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 919effe4fcc..4391c6e6d4e 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -7,7 +7,6 @@ from ...store.document_store import DocumentStore from ...store.linked_obj import LinkedObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import 
ActionObjectPermission from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext @@ -37,7 +36,6 @@ from .request_stash import RequestStash -@instrument @serializable() class RequestService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index dedee590357..6ae364c3e53 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -13,7 +13,6 @@ from ...store.document_store import QueryKeys from ...types.datetime import DateTime from ...types.uid import UID -from ...util.telemetry import instrument from .request import Request RequestingUserVerifyKeyPartitionKey = PartitionKey( @@ -23,7 +22,6 @@ OrderByRequestTimeStampPartitionKey = PartitionKey(key="request_time", type_=DateTime) -@instrument @serializable() class RequestStash(BaseUIDStoreStash): object_type = Request diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index 4abe81d0a92..4a750580b1d 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -11,7 +11,6 @@ from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from .settings import NodeSettings @@ -19,7 +18,6 @@ ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable() class SettingsStash(BaseUIDStoreStash): object_type = NodeSettings diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index e91f7b1c485..d87ac7993c6 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ 
b/packages/syft/src/syft/service/sync/sync_stash.py @@ -13,14 +13,12 @@ from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.datetime import DateTime -from ...util.telemetry import instrument from ..context import AuthedServiceContext from .sync_state import SyncState OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime) -@instrument @serializable() class SyncStash(BaseUIDStoreStash): object_type = SyncState diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 25d85bccd9f..464be66c640 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -10,7 +10,6 @@ from ...store.linked_obj import LinkedObject from ...types.syft_metaclass import Empty from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext @@ -45,7 +44,6 @@ from .user_stash import UserStash -@instrument @serializable() class UserService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 3bc8ed2dcfe..a8ba8294f1f 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -15,7 +15,6 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..response import SyftSuccess from .user import User @@ -28,7 +27,6 @@ VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) -@instrument @serializable() class UserStash(BaseStash): object_type = User diff --git 
a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 6c574b99735..8916e5d0dd7 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -16,7 +16,6 @@ from ...store.document_store import DocumentStore from ...store.document_store import SyftSuccess from ...types.uid import UID -from ...util.telemetry import instrument from ..service import AbstractService from ..service import AuthedServiceContext from ..service import SyftError @@ -36,7 +35,6 @@ from .worker_stash import WorkerStash -@instrument @serializable() class WorkerService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 77e7dfd281a..81e2845ab2e 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -14,7 +14,6 @@ from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from .worker_pool import ConsumerState @@ -23,7 +22,6 @@ WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str) -@instrument @serializable() class WorkerStash(BaseUIDStoreStash): object_type = SyftWorker diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index e71110667c3..fbebcca84a7 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -28,7 +28,6 @@ from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftObject from ..types.uid import UID -from ..util.telemetry import instrument from .locks import 
LockingConfig from .locks import NoLockingConfig from .locks import SyftLock @@ -291,7 +290,6 @@ def searchable_keys(self) -> PartitionKeys: return PartitionKeys.from_dict(self.object_type._syft_searchable_keys_dict()) -@instrument @serializable(attrs=["settings", "store_config", "unique_cks", "searchable_cks"]) class StorePartition: """Base StorePartition @@ -551,7 +549,6 @@ def _migrate_data( raise NotImplementedError -@instrument @serializable() class DocumentStore: """Base Document Store @@ -588,7 +585,6 @@ def partition(self, settings: PartitionSettings) -> StorePartition: return self.partitions[settings.name] -@instrument class BaseStash: object_type: type[SyftObject] settings: PartitionSettings @@ -752,7 +748,6 @@ def update( return res -@instrument class BaseUIDStoreStash(BaseStash): def delete_by_uid( self, credentials: SyftVerifyKey, uid: UID From 26078be9fc4194c66f8690204ca6b41deec7511a Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 11:42:37 +0530 Subject: [PATCH 04/20] add trace to service_method --- packages/syft/src/syft/service/service.py | 5 +++ packages/syft/src/syft/util/telemetry.py | 24 +++++++++----- .../syft/src/syft/util/trace_decorator.py | 31 +++++++++---------- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index c92695e2f6a..a25784f8e9e 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -37,6 +37,7 @@ from ..types.syft_object import SyftObject from ..types.syft_object import attach_attribute_to_syft_object from ..types.uid import UID +from ..util.telemetry import instrument from .context import AuthedServiceContext from .context import ChangeContext from .response import SyftError @@ -351,6 +352,10 @@ def wrapper(func: Any) -> Callable: input_signature = deepcopy(signature) + @instrument( # type: ignore + span_name=f"service_method::{_path}", + 
attributes={"service.name": name, "service.path": path}, + ) def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable: communication_protocol = kwargs.pop("communication_protocol", None) diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index 0eb13108707..ef376d78fc0 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -17,8 +17,14 @@ T = TypeVar("T", bound=Callable | type) -def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: - return __func_or_class +def noop(__func_or_class: T | None = None, /, *args: Any, **kwargs: Any) -> T: + def noop_wrapper(__func_or_class: T) -> T: + return __func_or_class + + if __func_or_class is None: + return noop_wrapper # type: ignore + else: + return __func_or_class if not TRACING_ENABLED: @@ -37,16 +43,20 @@ def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: # relative from .trace_decorator import instrument as _instrument + # create a provider service_name = os.environ.get("OTEL_SERVICE_NAME", "syft-backend") - trace.set_tracer_provider( - TracerProvider(resource=Resource.create({"service.name": service_name})) - ) + resource = Resource.create({"service.name": service_name}) + provider = TracerProvider(resource=resource) - # configured through env:OTEL_EXPORTER_OTLP_ENDPOINT + # create a span processor otlp_exporter = OTLPSpanExporter() span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) + provider.add_span_processor(span_processor) + + # set the global trace provider + trace.set_tracer_provider(provider) + # expose the instrument decorator instrument = _instrument except Exception as e: logger.error("Failed to import opentelemetry", exc_info=e) diff --git a/packages/syft/src/syft/util/trace_decorator.py b/packages/syft/src/syft/util/trace_decorator.py index 87486b0cda4..719ca822afa 100644 --- 
a/packages/syft/src/syft/util/trace_decorator.py +++ b/packages/syft/src/syft/util/trace_decorator.py @@ -9,7 +9,6 @@ from typing import Any from typing import ClassVar from typing import TypeVar -from typing import cast # third party from opentelemetry import trace @@ -44,7 +43,7 @@ def set_default_attributes(cls, attributes: dict[str, str] | None = None) -> Non def instrument( - _func_or_class: T, + _func_or_class: T | None = None, /, *, span_name: str = "", @@ -99,16 +98,12 @@ def decorate_class(cls: T) -> T: return cls - # Check if this is a span or class decorator - if inspect.isclass(_func_or_class): - return decorate_class(_func_or_class) - def span_decorator(func_or_class: T) -> T: - if inspect.isclass(func_or_class): + if ignore: + return func_or_class + elif inspect.isclass(func_or_class): return decorate_class(func_or_class) - # sig = inspect.signature(func_or_class) - # Check if already decorated (happens if both class and function # decorated). If so, we keep the function decorator settings only undecorated_func = getattr(func_or_class, "__tracing_unwrapped__", None) @@ -155,16 +150,20 @@ async def wrap_with_span_async(*args: Any, **kwargs: Any) -> Callable: _set_attributes(span, attributes) return await func_or_class(*args, **kwargs) - if ignore: - return func_or_class - - wrapper = ( + span_wrapper = ( wrap_with_span_async if asyncio.iscoroutinefunction(func_or_class) else wrap_with_span_sync ) - wrapper.__signature__ = inspect.signature(func_or_class) + span_wrapper.__signature__ = inspect.signature(func_or_class) - return cast(T, wrapper) + return span_wrapper # type: ignore - return span_decorator(_func_or_class) + # decorator factory on a class or func + # @instrument or @instrument(span_name="my_span", ...) 
+ if _func_or_class and inspect.isclass(_func_or_class): + return decorate_class(_func_or_class) + elif _func_or_class: + return span_decorator(_func_or_class) + else: + return span_decorator # type: ignore From d9741a271f8d48d23f9f25efaddcbfdc660000f9 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 11:42:51 +0530 Subject: [PATCH 05/20] add trace for pymongo --- packages/syft/src/syft/store/mongo_client.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index 55208d2e614..3f1899ab050 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -13,11 +13,21 @@ # relative from ..serde.serializable import serializable +from ..util.telemetry import TRACING_ENABLED from .document_store import PartitionSettings from .document_store import StoreClientConfig from .document_store import StoreConfig from .mongo_codecs import SYFT_CODEC_OPTIONS +if TRACING_ENABLED: + try: + # third party + from opentelemetry.instrumentation.pymongo import PymongoInstrumentor + + PymongoInstrumentor().instrument() + except Exception: + pass + @serializable() class MongoStoreClientConfig(StoreClientConfig): From 133bec6ffe6a8625de55c37c237119fc9fbf7c02 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 11:43:24 +0530 Subject: [PATCH 06/20] start.sh tweaks --- packages/grid/backend/grid/start.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 4d6571c4570..80aa1ab0089 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -1,7 +1,7 @@ #! 
/usr/bin/env bash set -e -echo "Running Syft with RELEASE=${RELEASE} and $(id)" +echo "Running Syft with RELEASE=${RELEASE}" APP_MODULE=grid.main:app LOG_LEVEL=${LOG_LEVEL:-info} @@ -38,7 +38,7 @@ then export OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_EXPORTER_OTLP_ENDPOINT:-"http://syft-signoz-otel-collector.platform:4317"} export OTEL_EXPORTER_OTLP_PROTOCOL=${OTEL_EXPORTER_OTLP_PROTOCOL:-grpc} - # TODO: Finalize if uvicorn is stable with OpenTelemetry + # TODO: uvicorn postfork is not stable with OpenTelemetry # ROOT_PROC="opentelemetry-instrument" fi From 12f004712d930b76d8c20bd33909295a6608f940 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 25 Jun 2024 11:54:24 +0530 Subject: [PATCH 07/20] add trace for botocore --- packages/syft/src/syft/store/blob_storage/seaweedfs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 03c6f442c26..f2715f87f3e 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -40,6 +40,7 @@ from ...types.grid_url import GridURL from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...util.constants import DEFAULT_TIMEOUT +from ...util.telemetry import TRACING_ENABLED logger = logging.getLogger(__name__) @@ -48,6 +49,15 @@ DEFAULT_FILE_PART_SIZE = 1024**3 # 1GB DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 800 # 800KB +if TRACING_ENABLED: + try: + # third party + from opentelemetry.instrumentation.botocore import BotocoreInstrumentor + + BotocoreInstrumentor().instrument() + except Exception: + pass + @serializable() class SeaweedFSBlobDeposit(BlobDeposit): From b1d5e267ff2d86a260fd3ab2643d37f491f7bcfc Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 2 Jul 2024 18:40:41 +0530 Subject: [PATCH 08/20] refactor zmq code to separate files --- packages/grid/backend/grid/core/node.py | 4 +- packages/syft/src/syft/node/node.py | 6 +- 
.../syft/src/syft/service/queue/zmq_client.py | 188 +++++++ .../syft/src/syft/service/queue/zmq_common.py | 109 ++++ .../src/syft/service/queue/zmq_consumer.py | 243 ++++++++ .../queue/{zmq_queue.py => zmq_producer.py} | 522 +----------------- packages/syft/tests/syft/zmq_queue_test.py | 10 +- 7 files changed, 570 insertions(+), 512 deletions(-) create mode 100644 packages/syft/src/syft/service/queue/zmq_client.py create mode 100644 packages/syft/src/syft/service/queue/zmq_common.py create mode 100644 packages/syft/src/syft/service/queue/zmq_consumer.py rename packages/syft/src/syft/service/queue/{zmq_queue.py => zmq_producer.py} (54%) diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index cde36f8c5fe..c134a99f435 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -10,8 +10,8 @@ from syft.node.node import get_node_side_type from syft.node.node import get_node_type from syft.node.node import get_node_uid_env -from syft.service.queue.zmq_queue import ZMQClientConfig -from syft.service.queue.zmq_queue import ZMQQueueConfig +from syft.service.queue.zmq_client import ZMQClientConfig +from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig from syft.store.mongo_client import MongoStoreClientConfig diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index d7730358c3a..81ece7e2ff6 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -89,9 +89,9 @@ from ..service.queue.queue_stash import ActionQueueItem from ..service.queue.queue_stash import QueueItem from ..service.queue.queue_stash import QueueStash -from ..service.queue.zmq_queue import QueueConfig -from ..service.queue.zmq_queue import ZMQClientConfig -from ..service.queue.zmq_queue import ZMQQueueConfig +from 
..service.queue.zmq_client import QueueConfig +from ..service.queue.zmq_client import ZMQClientConfig +from ..service.queue.zmq_client import ZMQQueueConfig from ..service.request.request_service import RequestService from ..service.response import SyftError from ..service.service import AbstractService diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py new file mode 100644 index 00000000000..932056002de --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -0,0 +1,188 @@ +# stdlib +from collections import defaultdict +import socketserver + +# relative +from ...serde.serializable import serializable +from ...service.context import AuthedServiceContext +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.util import get_queue_address +from ..response import SyftError +from ..response import SyftSuccess +from ..worker.worker_stash import WorkerStash +from .base_queue import AbstractMessageHandler +from .base_queue import QueueClient +from .base_queue import QueueClientConfig +from .base_queue import QueueConfig +from .queue_stash import QueueStash +from .zmq_consumer import ZMQConsumer +from .zmq_producer import ZMQProducer + + +@serializable() +class ZMQClientConfig(SyftObject, QueueClientConfig): + __canonical_name__ = "ZMQClientConfig" + __version__ = SYFT_OBJECT_VERSION_4 + + id: UID | None = None # type: ignore[assignment] + hostname: str = "127.0.0.1" + queue_port: int | None = None + # TODO: setting this to false until we can fix the ZMQ + # port issue causing tests to randomly fail + create_producer: bool = False + n_consumers: int = 0 + consumer_service: str | None = None + + +@serializable(attrs=["host"]) +class ZMQClient(QueueClient): + """ZMQ Client for creating producers and consumers.""" + + producers: dict[str, ZMQProducer] + consumers: defaultdict[str, list[ZMQConsumer]] + + 
def __init__(self, config: ZMQClientConfig) -> None: + self.host = config.hostname + self.producers = {} + self.consumers = defaultdict(list) + self.config = config + + @staticmethod + def _get_free_tcp_port(host: str) -> int: + with socketserver.TCPServer((host, 0), None) as s: + free_port = s.server_address[1] + + return free_port + + def add_producer( + self, + queue_name: str, + port: int | None = None, + queue_stash: QueueStash | None = None, + worker_stash: WorkerStash | None = None, + context: AuthedServiceContext | None = None, + ) -> ZMQProducer: + """Add a producer of a queue. + + A queue can have at most one producer attached to it. + """ + + if port is None: + if self.config.queue_port is None: + self.config.queue_port = self._get_free_tcp_port(self.host) + port = self.config.queue_port + else: + port = self.config.queue_port + + producer = ZMQProducer( + queue_name=queue_name, + queue_stash=queue_stash, + port=port, + context=context, + worker_stash=worker_stash, + ) + self.producers[queue_name] = producer + return producer + + def add_consumer( + self, + queue_name: str, + message_handler: AbstractMessageHandler, + service_name: str, + address: str | None = None, + worker_stash: WorkerStash | None = None, + syft_worker_id: UID | None = None, + ) -> ZMQConsumer: + """Add a consumer to a queue + + A queue should have at least one producer attached to the group. 
+ + """ + + if address is None: + address = get_queue_address(self.config.queue_port) + + consumer = ZMQConsumer( + queue_name=queue_name, + message_handler=message_handler, + address=address, + service_name=service_name, + syft_worker_id=syft_worker_id, + worker_stash=worker_stash, + ) + self.consumers[queue_name].append(consumer) + + return consumer + + def send_message( + self, + message: bytes, + queue_name: str, + worker: bytes | None = None, + ) -> SyftSuccess | SyftError: + producer = self.producers.get(queue_name) + if producer is None: + return SyftError( + message=f"No producer attached for queue: {queue_name}. Please add a producer for it." + ) + try: + producer.send(message=message, worker=worker) + except Exception as e: + # stdlib + return SyftError( + message=f"Failed to send message to: {queue_name} with error: {e}" + ) + return SyftSuccess( + message=f"Successfully queued message to : {queue_name}", + ) + + def close(self) -> SyftError | SyftSuccess: + try: + for consumers in self.consumers.values(): + for consumer in consumers: + # make sure look is stopped + consumer.close() + + for producer in self.producers.values(): + # make sure loop is stopped + producer.close() + # close existing connection. + except Exception as e: + return SyftError(message=f"Failed to close connection: {e}") + + return SyftSuccess(message="All connections closed.") + + def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: + if queue_name not in self.producers: + return SyftError(message=f"No producer running for : {queue_name}") + + producer = self.producers[queue_name] + + # close existing connection. 
+ producer.close() + + # add a new connection + self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore + + return SyftSuccess(message=f"Queue: {queue_name} successfully purged") + + def purge_all(self) -> SyftError | SyftSuccess: + for queue_name in self.producers: + self.purge_queue(queue_name=queue_name) + + return SyftSuccess(message="Successfully purged all queues.") + + +@serializable() +class ZMQQueueConfig(QueueConfig): + def __init__( + self, + client_type: type[ZMQClient] | None = None, + client_config: ZMQClientConfig | None = None, + thread_workers: bool = False, + ): + self.client_type = client_type or ZMQClient + self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() + self.thread_workers = thread_workers diff --git a/packages/syft/src/syft/service/queue/zmq_common.py b/packages/syft/src/syft/service/queue/zmq_common.py new file mode 100644 index 00000000000..1b47e3edcfc --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_common.py @@ -0,0 +1,109 @@ +# stdlib +import threading +import time +from typing import Any + +# third party +from pydantic import field_validator + +# relative +from ...types.base import SyftBaseModel +from ...types.uid import UID + +# Producer/Consumer heartbeat interval (in seconds) +HEARTBEAT_INTERVAL_SEC = 2 + +# Thread join timeout (in seconds) +THREAD_TIMEOUT_SEC = 5 + +# Max duration (in ms) to wait for ZMQ poller to return +ZMQ_POLLER_TIMEOUT_MSEC = 1000 + +# Duration (in seconds) after which a worker without a heartbeat will be marked as expired +WORKER_TIMEOUT_SEC = 60 + +# Duration (in seconds) after which producer without a heartbeat will be marked as expired +PRODUCER_TIMEOUT_SEC = 60 + +# Lock for working on ZMQ socket +ZMQ_SOCKET_LOCK = threading.Lock() + +MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 + + +class ZMQHeader: + """Enum for ZMQ headers""" + + W_WORKER = b"MDPW01" + + +class ZMQCommand: + """Enum for ZMQ commands""" + + W_READY = b"0x01" + W_REQUEST = b"0x02" 
+ W_REPLY = b"0x03" + W_HEARTBEAT = b"0x04" + W_DISCONNECT = b"0x05" + + +class Timeout: + def __init__(self, offset_sec: float): + self.__offset = float(offset_sec) + self.__next_ts: float = 0.0 + + self.reset() + + @property + def next_ts(self) -> float: + return self.__next_ts + + def reset(self) -> None: + self.__next_ts = self.now() + self.__offset + + def has_expired(self) -> bool: + return self.now() >= self.__next_ts + + @staticmethod + def now() -> float: + return time.time() + + +class Service: + def __init__(self, name: str) -> None: + self.name = name + self.requests: list[bytes] = [] + self.waiting: list[Worker] = [] # List of waiting workers + + +class Worker(SyftBaseModel): + address: bytes + identity: bytes + service: Service | None = None + syft_worker_id: UID | None = None + expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. 
+ @field_validator("syft_worker_id", mode="before") + @classmethod + def set_syft_worker_id(cls, v: Any) -> Any: + if isinstance(v, str): + return UID(v) + return v + + def has_expired(self) -> bool: + return self.expiry_t.has_expired() + + def get_expiry(self) -> float: + return self.expiry_t.next_ts + + def reset_expiry(self) -> None: + self.expiry_t.reset() + + def __str__(self) -> str: + svc = self.service.name if self.service else None + return ( + f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " + f"syft_worker_id={self.syft_worker_id!r})" + ) diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py new file mode 100644 index 00000000000..82acf6cdb4a --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -0,0 +1,243 @@ +# stdlib +import logging +import threading + +# third party +import zmq +from zmq import Frame +from zmq.error import ContextTerminated + +# relative +from ...serde.deserialize import _deserialize +from ...serde.serializable import serializable +from ...types.uid import UID +from ..worker.worker_pool import ConsumerState +from ..worker.worker_stash import WorkerStash +from .base_queue import AbstractMessageHandler +from .base_queue import QueueConsumer +from .zmq_common import HEARTBEAT_INTERVAL_SEC +from .zmq_common import PRODUCER_TIMEOUT_SEC +from .zmq_common import THREAD_TIMEOUT_SEC +from .zmq_common import Timeout +from .zmq_common import ZMQCommand +from .zmq_common import ZMQHeader +from .zmq_common import ZMQ_POLLER_TIMEOUT_MSEC +from .zmq_common import ZMQ_SOCKET_LOCK + +logger = logging.getLogger(__name__) + + +@serializable(attrs=["_subscriber"]) +class ZMQConsumer(QueueConsumer): + def __init__( + self, + message_handler: AbstractMessageHandler, + address: str, + queue_name: str, + service_name: str, + syft_worker_id: UID | None = None, + worker_stash: WorkerStash | None = None, + verbose: bool = False, + ) -> None: + 
self.address = address + self.message_handler = message_handler + self.service_name = service_name + self.queue_name = queue_name + self.context = zmq.Context() + self.poller = zmq.Poller() + self.socket = None + self.verbose = verbose + self.id = UID().short() + self._stop = threading.Event() + self.syft_worker_id = syft_worker_id + self.worker_stash = worker_stash + self.post_init() + + def reconnect_to_producer(self) -> None: + """Connect or reconnect to producer""" + if self.socket: + self.poller.unregister(self.socket) # type: ignore[unreachable] + self.socket.close() + self.socket = self.context.socket(zmq.DEALER) + self.socket.linger = 0 + self.socket.setsockopt_string(zmq.IDENTITY, self.id) + self.socket.connect(self.address) + self.poller.register(self.socket, zmq.POLLIN) + + logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") + + # Register queue with the producer + self.send_to_producer( + ZMQCommand.W_READY, + [self.service_name.encode(), str(self.syft_worker_id).encode()], + ) + + def post_init(self) -> None: + self.thread: threading.Thread | None = None + self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) + self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) + self.reconnect_to_producer() + + def disconnect_from_producer(self) -> None: + self.send_to_producer(ZMQCommand.W_DISCONNECT) + + def close(self) -> None: + self.disconnect_from_producer() + self._stop.set() + try: + self.poller.unregister(self.socket) + except Exception as e: + logger.exception("Failed to unregister worker.", exc_info=e) + finally: + if self.thread is not None: + self.thread.join(timeout=THREAD_TIMEOUT_SEC) + self.thread = None + self.socket.close() + self.context.destroy() + self._stop.clear() + + def send_to_producer( + self, + command: bytes, + msg: bytes | list | None = None, + ) -> None: + """Send message to producer. + + If no msg is provided, creates one internally + """ + if self.socket.closed: + logger.warning("Socket is closed. 
Cannot send message.") + return + + if msg is None: + msg = [] + elif not isinstance(msg, list): + msg = [msg] + + # ZMQConsumer send frames: [empty, header, command, ...data] + core = [b"", ZMQHeader.W_WORKER, command] + msg = core + msg + + if command != ZMQCommand.W_HEARTBEAT: + logger.info(f"ZMQ Consumer send: {core}") + + with ZMQ_SOCKET_LOCK: + try: + self.socket.send_multipart(msg) + except zmq.ZMQError as e: + logger.error("ZMQConsumer send error", exc_info=e) + + def _run(self) -> None: + """Send reply, if any, to producer and wait for next request.""" + try: + while True: + if self._stop.is_set(): + logger.info("ZMQConsumer thread stopped") + return + + try: + items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except ContextTerminated: + logger.info("Context terminated") + return + except Exception as e: + logger.error("ZMQ poll error", exc_info=e) + continue + + if items: + msg = self.socket.recv_multipart() + + # mark as alive + self.set_producer_alive() + + if len(msg) < 3: + logger.error(f"ZMQConsumer invalid recv: {msg}") + continue + + # Message frames recieved by consumer: + # [empty, header, command, ...data] + (_, _, command, *data) = msg + + if command != ZMQCommand.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQConsumer recv: {msg[:-4]}") + + if command == ZMQCommand.W_REQUEST: + # Call Message Handler + try: + message = data.pop() + self.associate_job(message) + self.message_handler.handle_message( + message=message, + syft_worker_id=self.syft_worker_id, + ) + except Exception as e: + logger.exception("Couldn't handle message", exc_info=e) + finally: + self.clear_job() + elif command == ZMQCommand.W_HEARTBEAT: + self.set_producer_alive() + elif command == ZMQCommand.W_DISCONNECT: + self.reconnect_to_producer() + else: + logger.error(f"ZMQConsumer invalid command: {command}") + else: + if not self.is_producer_alive(): + logger.info("Producer check-alive timed out. 
Reconnecting.") + self.reconnect_to_producer() + self.set_producer_alive() + + self.send_heartbeat() + + except zmq.ZMQError as e: + if e.errno == zmq.ETERM: + logger.info("zmq.ETERM") + else: + logger.exception("zmq.ZMQError", exc_info=e) + except Exception as e: + logger.exception("ZMQConsumer thread exception", exc_info=e) + + def set_producer_alive(self) -> None: + self.producer_ping_t.reset() + + def is_producer_alive(self) -> bool: + # producer timer is within timeout + return not self.producer_ping_t.has_expired() + + def send_heartbeat(self) -> None: + if self.heartbeat_t.has_expired() and self.is_producer_alive(): + self.send_to_producer(ZMQCommand.W_HEARTBEAT) + self.heartbeat_t.reset() + + def run(self) -> None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + def associate_job(self, message: Frame) -> None: + try: + queue_item = _deserialize(message, from_bytes=True) + self._set_worker_job(queue_item.job_id) + except Exception as e: + logger.exception("Could not associate job", exc_info=e) + + def clear_job(self) -> None: + self._set_worker_job(None) + + def _set_worker_job(self, job_id: UID | None) -> None: + if self.worker_stash is not None: + consumer_state = ( + ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING + ) + res = self.worker_stash.update_consumer_state( + credentials=self.worker_stash.partition.root_verify_key, + worker_uid=self.syft_worker_id, + consumer_state=consumer_state, + ) + if res.is_err(): + logger.error( + f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" + ) + + @property + def alive(self) -> bool: + return not self.socket.closed and self.is_producer_alive() diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_producer.py similarity index 54% rename from packages/syft/src/syft/service/queue/zmq_queue.py rename to packages/syft/src/syft/service/queue/zmq_producer.py index 08ff386696e..84efaf74493 
100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -1,143 +1,44 @@ # stdlib from binascii import hexlify -from collections import defaultdict import itertools import logging -import socketserver import sys import threading -import time from time import sleep from typing import Any # third party -from pydantic import field_validator import zmq -from zmq import Frame from zmq import LINGER -from zmq.error import ContextTerminated # relative -from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize from ...service.action.action_object import ActionObject from ...service.context import AuthedServiceContext -from ...types.base import SyftBaseModel -from ...types.syft_object import SYFT_OBJECT_VERSION_4 -from ...types.syft_object import SyftObject from ...types.uid import UID from ...util.util import get_queue_address from ..response import SyftError -from ..response import SyftSuccess from ..service import AbstractService from ..worker.worker_pool import ConsumerState from ..worker.worker_stash import WorkerStash -from .base_queue import AbstractMessageHandler -from .base_queue import QueueClient -from .base_queue import QueueClientConfig -from .base_queue import QueueConfig -from .base_queue import QueueConsumer from .base_queue import QueueProducer from .queue_stash import ActionQueueItem from .queue_stash import QueueStash from .queue_stash import Status - -# Producer/Consumer heartbeat interval (in seconds) -HEARTBEAT_INTERVAL_SEC = 2 - -# Thread join timeout (in seconds) -THREAD_TIMEOUT_SEC = 5 - -# Max duration (in ms) to wait for ZMQ poller to return -ZMQ_POLLER_TIMEOUT_MSEC = 1000 - -# Duration (in seconds) after which a worker without a heartbeat will be marked as expired -WORKER_TIMEOUT_SEC = 60 - -# Duration (in seconds) after which producer without a heartbeat will be marked as expired 
-PRODUCER_TIMEOUT_SEC = 60 - -# Lock for working on ZMQ socket -ZMQ_SOCKET_LOCK = threading.Lock() +from .zmq_common import HEARTBEAT_INTERVAL_SEC +from .zmq_common import Service +from .zmq_common import THREAD_TIMEOUT_SEC +from .zmq_common import Timeout +from .zmq_common import Worker +from .zmq_common import ZMQCommand +from .zmq_common import ZMQHeader +from .zmq_common import ZMQ_POLLER_TIMEOUT_MSEC +from .zmq_common import ZMQ_SOCKET_LOCK logger = logging.getLogger(__name__) -class QueueMsgProtocol: - W_WORKER = b"MDPW01" - W_READY = b"0x01" - W_REQUEST = b"0x02" - W_REPLY = b"0x03" - W_HEARTBEAT = b"0x04" - W_DISCONNECT = b"0x05" - - -MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 - - -class Timeout: - def __init__(self, offset_sec: float): - self.__offset = float(offset_sec) - self.__next_ts: float = 0.0 - - self.reset() - - @property - def next_ts(self) -> float: - return self.__next_ts - - def reset(self) -> None: - self.__next_ts = self.now() + self.__offset - - def has_expired(self) -> bool: - return self.now() >= self.__next_ts - - @staticmethod - def now() -> float: - return time.time() - - -class Service: - def __init__(self, name: str) -> None: - self.name = name - self.requests: list[bytes] = [] - self.waiting: list[Worker] = [] # List of waiting workers - - -class Worker(SyftBaseModel): - address: bytes - identity: bytes - service: Service | None = None - syft_worker_id: UID | None = None - expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. 
- @field_validator("syft_worker_id", mode="before") - @classmethod - def set_syft_worker_id(cls, v: Any) -> Any: - if isinstance(v, str): - return UID(v) - return v - - def has_expired(self) -> bool: - return self.expiry_t.has_expired() - - def get_expiry(self) -> float: - return self.expiry_t.next_ts - - def reset_expiry(self) -> None: - self.expiry_t.reset() - - def __str__(self) -> str: - svc = self.service.name if self.service else None - return ( - f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " - f"syft_worker_id={self.syft_worker_id!r})" - ) - - @serializable() class ZMQProducer(QueueProducer): INTERNAL_SERVICE_PREFIX = b"mmi." @@ -403,7 +304,7 @@ def run(self) -> None: def send(self, worker: bytes, message: bytes | list[bytes]) -> None: worker_obj = self.require_worker(worker) - self.send_to_worker(worker_obj, QueueMsgProtocol.W_REQUEST, message) + self.send_to_worker(worker_obj, ZMQCommand.W_REQUEST, message) def bind(self, endpoint: str) -> None: """Bind producer to endpoint.""" @@ -414,7 +315,7 @@ def send_heartbeats(self) -> None: """Send heartbeats to idle workers if it's time""" if self.heartbeat_t.has_expired(): for worker in self.waiting: - self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT) + self.send_to_worker(worker, ZMQCommand.W_HEARTBEAT) self.heartbeat_t.reset() def purge_workers(self) -> None: @@ -484,7 +385,7 @@ def dispatch(self, service: Service, msg: bytes) -> None: msg = service.requests.pop(0) worker = service.waiting.pop(0) self.waiting.remove(worker) - self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, msg) + self.send_to_worker(worker, ZMQCommand.W_REQUEST, msg) def send_to_worker( self, @@ -507,10 +408,10 @@ def send_to_worker( msg = [msg] # ZMQProducer send frames: [address, empty, header, command, ...data] - core = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] + core = [worker.address, b"", ZMQHeader.W_WORKER, command] msg = core + msg - if command != QueueMsgProtocol.W_HEARTBEAT: + 
if command != ZMQCommand.W_HEARTBEAT: # log everything except the last frame which contains serialized data logger.info(f"ZMQProducer send: {core}") @@ -547,11 +448,11 @@ def _run(self) -> None: # ZMQProducer recv frames: [address, empty, header, command, ...data] (address, _, header, command, *data) = msg - if command != QueueMsgProtocol.W_HEARTBEAT: + if command != ZMQCommand.W_HEARTBEAT: # log everything except the last frame which contains serialized data logger.info(f"ZMQProducer recv: {msg[:4]}") - if header == QueueMsgProtocol.W_WORKER: + if header == ZMQHeader.W_WORKER: self.process_worker(address, command, data) else: logger.error(f"Invalid message header: {header}") @@ -574,7 +475,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N worker_ready = hexlify(address) in self.workers worker = self.require_worker(address) - if QueueMsgProtocol.W_READY == command: + if ZMQCommand.W_READY == command: service_name = data.pop(0).decode() syft_worker_id = data.pop(0).decode() if worker_ready: @@ -596,7 +497,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N worker.syft_worker_id = UID(syft_worker_id) self.worker_waiting(worker) - elif QueueMsgProtocol.W_HEARTBEAT == command: + elif ZMQCommand.W_HEARTBEAT == command: if worker_ready: # If worker is ready then reset expiry # and add it to worker waiting list @@ -605,7 +506,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N else: logger.info(f"Got heartbeat, but worker not ready. 
{worker}")
                 self.delete_worker(worker, True)
-        elif QueueMsgProtocol.W_DISCONNECT == command:
+        elif ZMQCommand.W_DISCONNECT == command:
             logger.info(f"Removing disconnected worker: {worker}")
             self.delete_worker(worker, False)
         else:
@@ -614,7 +515,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N
     def delete_worker(self, worker: Worker, disconnect: bool) -> None:
         """Deletes worker from all data structures, and deletes worker."""
         if disconnect:
-            self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT)
+            self.send_to_worker(worker, ZMQCommand.W_DISCONNECT)
 
         if worker.service and worker in worker.service.waiting:
             worker.service.waiting.remove(worker)
@@ -632,386 +533,3 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None:
     @property
     def alive(self) -> bool:
         return not self.socket.closed
-
-
-@serializable(attrs=["_subscriber"])
-class ZMQConsumer(QueueConsumer):
-    def __init__(
-        self,
-        message_handler: AbstractMessageHandler,
-        address: str,
-        queue_name: str,
-        service_name: str,
-        syft_worker_id: UID | None = None,
-        worker_stash: WorkerStash | None = None,
-        verbose: bool = False,
-    ) -> None:
-        self.address = address
-        self.message_handler = message_handler
-        self.service_name = service_name
-        self.queue_name = queue_name
-        self.context = zmq.Context()
-        self.poller = zmq.Poller()
-        self.socket = None
-        self.verbose = verbose
-        self.id = UID().short()
-        self._stop = threading.Event()
-        self.syft_worker_id = syft_worker_id
-        self.worker_stash = worker_stash
-        self.post_init()
-
-    def reconnect_to_producer(self) -> None:
-        """Connect or reconnect to producer"""
-        if self.socket:
-            self.poller.unregister(self.socket)  # type: ignore[unreachable]
-            self.socket.close()
-        self.socket = self.context.socket(zmq.DEALER)
-        self.socket.linger = 0
-        self.socket.setsockopt_string(zmq.IDENTITY, self.id)
-        self.socket.connect(self.address)
-        self.poller.register(self.socket, zmq.POLLIN)
-
-        logger.info(f"Connecting Worker 
id={self.id} to broker addr={self.address}") - - # Register queue with the producer - self.send_to_producer( - QueueMsgProtocol.W_READY, - [self.service_name.encode(), str(self.syft_worker_id).encode()], - ) - - def post_init(self) -> None: - self.thread: threading.Thread | None = None - self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) - self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) - self.reconnect_to_producer() - - def disconnect_from_producer(self) -> None: - self.send_to_producer(QueueMsgProtocol.W_DISCONNECT) - - def close(self) -> None: - self.disconnect_from_producer() - self._stop.set() - try: - self.poller.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister worker.", exc_info=e) - finally: - if self.thread is not None: - self.thread.join(timeout=THREAD_TIMEOUT_SEC) - self.thread = None - self.socket.close() - self.context.destroy() - self._stop.clear() - - def send_to_producer( - self, - command: bytes, - msg: bytes | list | None = None, - ) -> None: - """Send message to producer. - - If no msg is provided, creates one internally - """ - if self.socket.closed: - logger.warning("Socket is closed. 
Cannot send message.") - return - - if msg is None: - msg = [] - elif not isinstance(msg, list): - msg = [msg] - - # ZMQConsumer send frames: [empty, header, command, ...data] - core = [b"", QueueMsgProtocol.W_WORKER, command] - msg = core + msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - logger.info(f"ZMQ Consumer send: {core}") - - with ZMQ_SOCKET_LOCK: - try: - self.socket.send_multipart(msg) - except zmq.ZMQError as e: - logger.error("ZMQConsumer send error", exc_info=e) - - def _run(self) -> None: - """Send reply, if any, to producer and wait for next request.""" - try: - while True: - if self._stop.is_set(): - logger.info("ZMQConsumer thread stopped") - return - - try: - items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except ContextTerminated: - logger.info("Context terminated") - return - except Exception as e: - logger.error("ZMQ poll error", exc_info=e) - continue - - if items: - msg = self.socket.recv_multipart() - - # mark as alive - self.set_producer_alive() - - if len(msg) < 3: - logger.error(f"ZMQConsumer invalid recv: {msg}") - continue - - # Message frames recieved by consumer: - # [empty, header, command, ...data] - (_, _, command, *data) = msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQConsumer recv: {msg[:-4]}") - - if command == QueueMsgProtocol.W_REQUEST: - # Call Message Handler - try: - message = data.pop() - self.associate_job(message) - self.message_handler.handle_message( - message=message, - syft_worker_id=self.syft_worker_id, - ) - except Exception as e: - logger.exception("Couldn't handle message", exc_info=e) - finally: - self.clear_job() - elif command == QueueMsgProtocol.W_HEARTBEAT: - self.set_producer_alive() - elif command == QueueMsgProtocol.W_DISCONNECT: - self.reconnect_to_producer() - else: - logger.error(f"ZMQConsumer invalid command: {command}") - else: - if not self.is_producer_alive(): - logger.info("Producer 
check-alive timed out. Reconnecting.") - self.reconnect_to_producer() - self.set_producer_alive() - - self.send_heartbeat() - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - logger.info("zmq.ETERM") - else: - logger.exception("zmq.ZMQError", exc_info=e) - except Exception as e: - logger.exception("ZMQConsumer thread exception", exc_info=e) - - def set_producer_alive(self) -> None: - self.producer_ping_t.reset() - - def is_producer_alive(self) -> bool: - # producer timer is within timeout - return not self.producer_ping_t.has_expired() - - def send_heartbeat(self) -> None: - if self.heartbeat_t.has_expired() and self.is_producer_alive(): - self.send_to_producer(QueueMsgProtocol.W_HEARTBEAT) - self.heartbeat_t.reset() - - def run(self) -> None: - self.thread = threading.Thread(target=self._run) - self.thread.start() - - def associate_job(self, message: Frame) -> None: - try: - queue_item = _deserialize(message, from_bytes=True) - self._set_worker_job(queue_item.job_id) - except Exception as e: - logger.exception("Could not associate job", exc_info=e) - - def clear_job(self) -> None: - self._set_worker_job(None) - - def _set_worker_job(self, job_id: UID | None) -> None: - if self.worker_stash is not None: - consumer_state = ( - ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING - ) - res = self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, - worker_uid=self.syft_worker_id, - consumer_state=consumer_state, - ) - if res.is_err(): - logger.error( - f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" - ) - - @property - def alive(self) -> bool: - return not self.socket.closed and self.is_producer_alive() - - -@serializable() -class ZMQClientConfig(SyftObject, QueueClientConfig): - __canonical_name__ = "ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_4 - - id: UID | None = None # type: ignore[assignment] - hostname: str = "127.0.0.1" - queue_port: int | None = 
None - # TODO: setting this to false until we can fix the ZMQ - # port issue causing tests to randomly fail - create_producer: bool = False - n_consumers: int = 0 - consumer_service: str | None = None - - -@serializable(attrs=["host"]) -class ZMQClient(QueueClient): - """ZMQ Client for creating producers and consumers.""" - - producers: dict[str, ZMQProducer] - consumers: defaultdict[str, list[ZMQConsumer]] - - def __init__(self, config: ZMQClientConfig) -> None: - self.host = config.hostname - self.producers = {} - self.consumers = defaultdict(list) - self.config = config - - @staticmethod - def _get_free_tcp_port(host: str) -> int: - with socketserver.TCPServer((host, 0), None) as s: - free_port = s.server_address[1] - - return free_port - - def add_producer( - self, - queue_name: str, - port: int | None = None, - queue_stash: QueueStash | None = None, - worker_stash: WorkerStash | None = None, - context: AuthedServiceContext | None = None, - ) -> ZMQProducer: - """Add a producer of a queue. - - A queue can have at most one producer attached to it. - """ - - if port is None: - if self.config.queue_port is None: - self.config.queue_port = self._get_free_tcp_port(self.host) - port = self.config.queue_port - else: - port = self.config.queue_port - - producer = ZMQProducer( - queue_name=queue_name, - queue_stash=queue_stash, - port=port, - context=context, - worker_stash=worker_stash, - ) - self.producers[queue_name] = producer - return producer - - def add_consumer( - self, - queue_name: str, - message_handler: AbstractMessageHandler, - service_name: str, - address: str | None = None, - worker_stash: WorkerStash | None = None, - syft_worker_id: UID | None = None, - ) -> ZMQConsumer: - """Add a consumer to a queue - - A queue should have at least one producer attached to the group. 
- - """ - - if address is None: - address = get_queue_address(self.config.queue_port) - - consumer = ZMQConsumer( - queue_name=queue_name, - message_handler=message_handler, - address=address, - service_name=service_name, - syft_worker_id=syft_worker_id, - worker_stash=worker_stash, - ) - self.consumers[queue_name].append(consumer) - - return consumer - - def send_message( - self, - message: bytes, - queue_name: str, - worker: bytes | None = None, - ) -> SyftSuccess | SyftError: - producer = self.producers.get(queue_name) - if producer is None: - return SyftError( - message=f"No producer attached for queue: {queue_name}. Please add a producer for it." - ) - try: - producer.send(message=message, worker=worker) - except Exception as e: - # stdlib - return SyftError( - message=f"Failed to send message to: {queue_name} with error: {e}" - ) - return SyftSuccess( - message=f"Successfully queued message to : {queue_name}", - ) - - def close(self) -> SyftError | SyftSuccess: - try: - for consumers in self.consumers.values(): - for consumer in consumers: - # make sure look is stopped - consumer.close() - - for producer in self.producers.values(): - # make sure loop is stopped - producer.close() - # close existing connection. - except Exception as e: - return SyftError(message=f"Failed to close connection: {e}") - - return SyftSuccess(message="All connections closed.") - - def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: - if queue_name not in self.producers: - return SyftError(message=f"No producer running for : {queue_name}") - - producer = self.producers[queue_name] - - # close existing connection. 
- producer.close() - - # add a new connection - self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore - - return SyftSuccess(message=f"Queue: {queue_name} successfully purged") - - def purge_all(self) -> SyftError | SyftSuccess: - for queue_name in self.producers: - self.purge_queue(queue_name=queue_name) - - return SyftSuccess(message="Successfully purged all queues.") - - -@serializable() -class ZMQQueueConfig(QueueConfig): - def __init__( - self, - client_type: type[ZMQClient] | None = None, - client_config: ZMQClientConfig | None = None, - thread_workers: bool = False, - ): - self.client_type = client_type or ZMQClient - self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.thread_workers = thread_workers diff --git a/packages/syft/tests/syft/zmq_queue_test.py b/packages/syft/tests/syft/zmq_queue_test.py index 8c5b8dedebe..d9af3e3222f 100644 --- a/packages/syft/tests/syft/zmq_queue_test.py +++ b/packages/syft/tests/syft/zmq_queue_test.py @@ -13,11 +13,11 @@ import syft from syft.service.queue.base_queue import AbstractMessageHandler from syft.service.queue.queue import QueueManager -from syft.service.queue.zmq_queue import ZMQClient -from syft.service.queue.zmq_queue import ZMQClientConfig -from syft.service.queue.zmq_queue import ZMQConsumer -from syft.service.queue.zmq_queue import ZMQProducer -from syft.service.queue.zmq_queue import ZMQQueueConfig +from syft.service.queue.zmq_client import ZMQClient +from syft.service.queue.zmq_client import ZMQClientConfig +from syft.service.queue.zmq_client import ZMQQueueConfig +from syft.service.queue.zmq_consumer import ZMQConsumer +from syft.service.queue.zmq_producer import ZMQProducer from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.util.util import get_queue_address From dab848dcdbf63f8c9efbfe3f6bc03c0632ce7bfe Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 9 Jul 2024 19:42:43 +0530 Subject: [PATCH 09/20] 
add zmq code for rebasing --- .../syft/src/syft/service/queue/zmq_queue.py | 1054 +++++++++++++++++ 1 file changed, 1054 insertions(+) create mode 100644 packages/syft/src/syft/service/queue/zmq_queue.py diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py new file mode 100644 index 00000000000..de409dda676 --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -0,0 +1,1054 @@ +# stdlib +from binascii import hexlify +from collections import defaultdict +import itertools +import logging +import socketserver +import sys +import threading +from threading import Event +import time +from time import sleep +from typing import Any +from typing import cast + +# third party +from pydantic import field_validator +from result import Result +import zmq +from zmq import Frame +from zmq import LINGER +from zmq.error import ContextTerminated + +# relative +from ...node.credentials import SyftVerifyKey +from ...serde.deserialize import _deserialize +from ...serde.serializable import serializable +from ...serde.serialize import _serialize as serialize +from ...service.action.action_object import ActionObject +from ...service.context import AuthedServiceContext +from ...types.base import SyftBaseModel +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.util import get_queue_address +from ..response import SyftError +from ..response import SyftSuccess +from ..service import AbstractService +from ..worker.worker_pool import ConsumerState +from ..worker.worker_pool import SyftWorker +from ..worker.worker_stash import WorkerStash +from .base_queue import AbstractMessageHandler +from .base_queue import QueueClient +from .base_queue import QueueClientConfig +from .base_queue import QueueConfig +from .base_queue import QueueConsumer +from .base_queue import QueueProducer +from .queue_stash import ActionQueueItem +from 
.queue_stash import QueueStash +from .queue_stash import Status + +# Producer/Consumer heartbeat interval (in seconds) +HEARTBEAT_INTERVAL_SEC = 2 + +# Thread join timeout (in seconds) +THREAD_TIMEOUT_SEC = 30 + +# Max duration (in ms) to wait for ZMQ poller to return +ZMQ_POLLER_TIMEOUT_MSEC = 1000 + +# Duration (in seconds) after which a worker without a heartbeat will be marked as expired +WORKER_TIMEOUT_SEC = 60 + +# Duration (in seconds) after which producer without a heartbeat will be marked as expired +PRODUCER_TIMEOUT_SEC = 60 + +# Lock for working on ZMQ socket +ZMQ_SOCKET_LOCK = threading.Lock() + +logger = logging.getLogger(__name__) + + +class QueueMsgProtocol: + W_WORKER = b"MDPW01" + W_READY = b"0x01" + W_REQUEST = b"0x02" + W_REPLY = b"0x03" + W_HEARTBEAT = b"0x04" + W_DISCONNECT = b"0x05" + + +MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 + + +class Timeout: + def __init__(self, offset_sec: float): + self.__offset = float(offset_sec) + self.__next_ts: float = 0.0 + + self.reset() + + @property + def next_ts(self) -> float: + return self.__next_ts + + def reset(self) -> None: + self.__next_ts = self.now() + self.__offset + + def has_expired(self) -> bool: + return self.now() >= self.__next_ts + + @staticmethod + def now() -> float: + return time.time() + + +class Service: + def __init__(self, name: str) -> None: + self.name = name + self.requests: list[bytes] = [] + self.waiting: list[Worker] = [] # List of waiting workers + + +class Worker(SyftBaseModel): + address: bytes + identity: bytes + service: Service | None = None + syft_worker_id: UID | None = None + expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) + + @field_validator("syft_worker_id", mode="before") + @classmethod + def set_syft_worker_id(cls, v: Any) -> Any: + if isinstance(v, str): + return UID(v) + return v + + def has_expired(self) -> bool: + return self.expiry_t.has_expired() + + def get_expiry(self) -> float: + return self.expiry_t.next_ts + + def reset_expiry(self) -> None: + 
self.expiry_t.reset() + + def _syft_worker( + self, stash: WorkerStash, credentials: SyftVerifyKey + ) -> Result[SyftWorker | None, str]: + return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) + + def __str__(self) -> str: + svc = self.service.name if self.service else None + return ( + f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " + f"syft_worker_id={self.syft_worker_id!r})" + ) + + +@serializable() +class ZMQProducer(QueueProducer): + INTERNAL_SERVICE_PREFIX = b"mmi." + + def __init__( + self, + queue_name: str, + queue_stash: QueueStash, + worker_stash: WorkerStash, + port: int, + context: AuthedServiceContext, + ) -> None: + self.id = UID().short() + self.port = port + self.queue_stash = queue_stash + self.worker_stash = worker_stash + self.queue_name = queue_name + self.auth_context = context + self._stop = Event() + self.post_init() + + @property + def address(self) -> str: + return get_queue_address(self.port) + + def post_init(self) -> None: + """Initialize producer state.""" + + self.services: dict[str, Service] = {} + self.workers: dict[bytes, Worker] = {} + self.waiting: list[Worker] = [] + self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) + self.context = zmq.Context(1) + self.socket = self.context.socket(zmq.ROUTER) + self.socket.setsockopt(LINGER, 1) + self.socket.setsockopt_string(zmq.IDENTITY, self.id) + self.poll_workers = zmq.Poller() + self.poll_workers.register(self.socket, zmq.POLLIN) + self.bind(f"tcp://*:{self.port}") + self.thread: threading.Thread | None = None + self.producer_thread: threading.Thread | None = None + + def close(self) -> None: + self._stop.set() + try: + if self.thread: + self.thread.join(THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQProducer message sending thread join timed out during closing. 
" + f"Queue name {self.queue_name}, " + ) + self.thread = None + + if self.producer_thread: + self.producer_thread.join(THREAD_TIMEOUT_SEC) + if self.producer_thread.is_alive(): + logger.error( + f"ZMQProducer queue thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) + self.producer_thread = None + + self.poll_workers.unregister(self.socket) + except Exception as e: + logger.exception("Failed to unregister poller.", exc_info=e) + finally: + self.socket.close() + self.context.destroy() + + # self._stop.clear() + + @property + def action_service(self) -> AbstractService: + if self.auth_context.node is not None: + return self.auth_context.node.get_service("ActionService") + else: + raise Exception(f"{self.auth_context} does not have a node.") + + def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: + """recursively check collections for unresolved action objects""" + if isinstance(arg, UID): + arg = self.action_service.get(self.auth_context, arg).ok() + return self.contains_unresolved_action_objects(arg, recursion=recursion + 1) + if isinstance(arg, ActionObject): + if not arg.syft_resolved: + res = self.action_service.get(self.auth_context, arg) + if res.is_err(): + return True + arg = res.ok() + if not arg.syft_resolved: + return True + arg = arg.syft_action_data + + try: + value = False + if isinstance(arg, list): + for elem in arg: + value = self.contains_unresolved_action_objects( + elem, recursion=recursion + 1 + ) + if value: + return True + if isinstance(arg, dict): + for elem in arg.values(): + value = self.contains_unresolved_action_objects( + elem, recursion=recursion + 1 + ) + if value: + return True + return value + except Exception as e: + logger.exception("Failed to resolve action objects.", exc_info=e) + return True + + def unwrap_nested_actionobjects(self, data: Any) -> Any: + """recursively unwraps nested action objects""" + + if isinstance(data, list): + return 
[self.unwrap_nested_actionobjects(obj) for obj in data] + if isinstance(data, dict): + return { + key: self.unwrap_nested_actionobjects(obj) for key, obj in data.items() + } + if isinstance(data, ActionObject): + res = self.action_service.get(self.auth_context, data.id) + res = res.ok() if res.is_ok() else res.err() + if not isinstance(res, ActionObject): + return SyftError(message=f"{res}") + else: + nested_res = res.syft_action_data + if isinstance(nested_res, ActionObject): + raise ValueError( + "More than double nesting of ActionObjects is currently not supported" + ) + return nested_res + return data + + def contains_nested_actionobjects(self, data: Any) -> bool: + """ + returns if this is a list/set/dict that contains ActionObjects + """ + + def unwrap_collection(col: set | dict | list) -> [Any]: # type: ignore + return_values = [] + if isinstance(col, dict): + values = list(col.values()) + list(col.keys()) + else: + values = list(col) + for v in values: + if isinstance(v, list | dict | set): + return_values += unwrap_collection(v) + else: + return_values.append(v) + return return_values + + if isinstance(data, list | dict | set): + values = unwrap_collection(data) + has_action_object = any(isinstance(x, ActionObject) for x in values) + return has_action_object + elif isinstance(data, ActionObject): + return True + return False + + def preprocess_action_arg(self, arg: UID) -> UID | None: + """ "If the argument is a collection (of collections) of ActionObjects, + We want to flatten the collection and upload a new ActionObject that contains + its values. E.g. 
[[ActionObject1, ActionObject2],[ActionObject3, ActionObject4]] + -> [[value1, value2],[value3, value4]] + """ + res = self.action_service.get(context=self.auth_context, uid=arg) + if res.is_err(): + return arg + action_object = res.ok() + data = action_object.syft_action_data + if self.contains_nested_actionobjects(data): + new_data = self.unwrap_nested_actionobjects(data) + + new_action_object = ActionObject.from_obj( + new_data, + id=action_object.id, + syft_blob_storage_entry_id=action_object.syft_blob_storage_entry_id, + ) + res = self.action_service._set( + context=self.auth_context, action_object=new_action_object + ) + return None + + def read_items(self) -> None: + while True: + if self._stop.is_set(): + break + try: + sleep(1) + + # Items to be queued + items_to_queue = self.queue_stash.get_by_status( + self.queue_stash.partition.root_verify_key, + status=Status.CREATED, + ).ok() + + items_to_queue = [] if items_to_queue is None else items_to_queue + + # Queue Items that are in the processing state + items_processing = self.queue_stash.get_by_status( + self.queue_stash.partition.root_verify_key, + status=Status.PROCESSING, + ).ok() + + items_processing = [] if items_processing is None else items_processing + + for item in itertools.chain(items_to_queue, items_processing): + # TODO: if resolving fails, set queueitem to errored, and jobitem as well + if item.status == Status.CREATED: + if isinstance(item, ActionQueueItem): + action = item.kwargs["action"] + if self.contains_unresolved_action_objects( + action.args + ) or self.contains_unresolved_action_objects(action.kwargs): + continue + for arg in action.args: + self.preprocess_action_arg(arg) + for _, arg in action.kwargs.items(): + self.preprocess_action_arg(arg) + + msg_bytes = serialize(item, to_bytes=True) + worker_pool = item.worker_pool.resolve_with_context( + self.auth_context + ) + worker_pool = worker_pool.ok() + service_name = worker_pool.name + service: Service | None = 
self.services.get(service_name) + + # Skip adding message if corresponding service/pool + # is not registered. + if service is None: + continue + + # append request message to the corresponding service + # This list is processed in dispatch method. + + # TODO: Logic to evaluate the CAN RUN Condition + service.requests.append(msg_bytes) + item.status = Status.PROCESSING + res = self.queue_stash.update(item.syft_client_verify_key, item) + if res.is_err(): + logger.error( + f"Failed to update queue item={item} error={res.err()}" + ) + elif item.status == Status.PROCESSING: + # Evaluate Retry condition here + # If job running and timeout or job status is KILL + # or heartbeat fails + # or container id doesn't exists, kill process or container + # else decrease retry count and mark status as CREATED. + pass + except Exception as e: + print(e, file=sys.stderr) + item.status = Status.ERRORED + res = self.queue_stash.update(item.syft_client_verify_key, item) + if res.is_err(): + logger.error( + f"Failed to update queue item={item} error={res.err()}" + ) + + def run(self) -> None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + self.producer_thread = threading.Thread(target=self.read_items) + self.producer_thread.start() + + def send(self, worker: bytes, message: bytes | list[bytes]) -> None: + worker_obj = self.require_worker(worker) + self.send_to_worker(worker_obj, QueueMsgProtocol.W_REQUEST, message) + + def bind(self, endpoint: str) -> None: + """Bind producer to endpoint.""" + self.socket.bind(endpoint) + logger.info(f"ZMQProducer endpoint: {endpoint}") + + def send_heartbeats(self) -> None: + """Send heartbeats to idle workers if it's time""" + if self.heartbeat_t.has_expired(): + for worker in self.waiting: + self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT) + self.heartbeat_t.reset() + + def purge_workers(self) -> None: + """Look for & kill expired workers. 
+ + Workers are oldest to most recent, so we stop at the first alive worker. + """ + # work on a copy of the iterator + for worker in self.waiting: + res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) + if res.is_err() or (syft_worker := res.ok()) is None: + logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}") + continue + + if worker.has_expired() or syft_worker.to_be_deleted: + logger.info(f"Deleting expired worker id={worker}") + self.delete_worker(worker, syft_worker.to_be_deleted) + + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.node.get_service(WorkerService) + ) + worker_service._delete(self.auth_context, syft_worker) + + def update_consumer_state_for_worker( + self, syft_worker_id: UID, consumer_state: ConsumerState + ) -> None: + if self.worker_stash is None: + logger.error( # type: ignore[unreachable] + f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" + ) + return + + try: + # Check if worker is present in the database + worker = self.worker_stash.get_by_uid( + credentials=self.worker_stash.partition.root_verify_key, + uid=syft_worker_id, + ) + if worker.is_ok() and worker.ok() is None: + return + + res = self.worker_stash.update_consumer_state( + credentials=self.worker_stash.partition.root_verify_key, + worker_uid=syft_worker_id, + consumer_state=consumer_state, + ) + if res.is_err(): + logger.error( + f"Failed to update consumer state for worker id={syft_worker_id} " + f"to state: {consumer_state} error={res.err()}", + ) + except Exception as e: + logger.error( + f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}", + exc_info=e, + ) + + def worker_waiting(self, worker: Worker) -> None: + """This worker is now waiting for work.""" + # Queue to broker and service waiting lists + if worker not in self.waiting: + self.waiting.append(worker) + if worker.service is 
not None and worker not in worker.service.waiting: + worker.service.waiting.append(worker) + worker.reset_expiry() + self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) + self.dispatch(worker.service, None) + + def dispatch(self, service: Service, msg: bytes) -> None: + """Dispatch requests to waiting workers as possible""" + if msg is not None: # Queue message if any + service.requests.append(msg) + + self.purge_workers() + while service.waiting and service.requests: + # One worker consuming only one message at a time. + msg = service.requests.pop(0) + worker = service.waiting.pop(0) + self.waiting.remove(worker) + self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, msg) + + def send_to_worker( + self, + worker: Worker, + command: bytes, + msg: bytes | list | None = None, + ) -> None: + """Send message to worker. + + If message is provided, sends that message. + """ + + if self.socket.closed: + logger.warning("Socket is closed. Cannot send message.") + return + + if msg is None: + msg = [] + elif not isinstance(msg, list): + msg = [msg] + + # ZMQProducer send frames: [address, empty, header, command, ...data] + core = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] + msg = core + msg + + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer send: {core}") + + with ZMQ_SOCKET_LOCK: + try: + self.socket.send_multipart(msg) + except zmq.ZMQError as e: + logger.error("ZMQProducer send error", exc_info=e) + + def _run(self) -> None: + try: + while True: + if self._stop.is_set(): + logger.info("ZMQProducer thread stopped") + return + + for service in self.services.values(): + self.dispatch(service, None) + + items = None + + try: + items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except Exception as e: + logger.exception("ZMQProducer poll error", exc_info=e) + + if items: + msg = self.socket.recv_multipart() + + if len(msg) < 3: + 
logger.error(f"ZMQProducer invalid recv: {msg}") + continue + + # ZMQProducer recv frames: [address, empty, header, command, ...data] + (address, _, header, command, *data) = msg + + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer recv: {msg[:4]}") + + if header == QueueMsgProtocol.W_WORKER: + self.process_worker(address, command, data) + else: + logger.error(f"Invalid message header: {header}") + + self.send_heartbeats() + self.purge_workers() + except Exception as e: + logger.exception("ZMQProducer thread exception", exc_info=e) + + def require_worker(self, address: bytes) -> Worker: + """Finds the worker (creates if necessary).""" + identity = hexlify(address) + worker = self.workers.get(identity) + if worker is None: + worker = Worker(identity=identity, address=address) + self.workers[identity] = worker + return worker + + def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> None: + worker_ready = hexlify(address) in self.workers + worker = self.require_worker(address) + + if QueueMsgProtocol.W_READY == command: + service_name = data.pop(0).decode() + syft_worker_id = data.pop(0).decode() + if worker_ready: + # Not first command in session or Reserved service name + # If worker was already present, then we disconnect it first + # and wait for it to re-register itself to the producer. This ensures that + # we always have a healthy worker in place that can talk to the producer. 
+ self.delete_worker(worker, True) + else: + # Attach worker to service and mark as idle + if service_name in self.services: + service: Service | None = self.services.get(service_name) + else: + service = Service(service_name) + self.services[service_name] = service + if service is not None: + worker.service = service + logger.info(f"New worker: {worker}") + worker.syft_worker_id = UID(syft_worker_id) + self.worker_waiting(worker) + + elif QueueMsgProtocol.W_HEARTBEAT == command: + if worker_ready: + # If worker is ready then reset expiry + # and add it to worker waiting list + # if not already present + self.worker_waiting(worker) + else: + logger.info(f"Got heartbeat, but worker not ready. {worker}") + self.delete_worker(worker, True) + elif QueueMsgProtocol.W_DISCONNECT == command: + logger.info(f"Removing disconnected worker: {worker}") + self.delete_worker(worker, False) + else: + logger.error(f"Invalid command: {command!r}") + + def delete_worker(self, worker: Worker, disconnect: bool) -> None: + """Deletes worker from all data structures, and deletes worker.""" + if disconnect: + self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT) + + if worker.service and worker in worker.service.waiting: + worker.service.waiting.remove(worker) + + if worker in self.waiting: + self.waiting.remove(worker) + + self.workers.pop(worker.identity, None) + + if worker.syft_worker_id is not None: + self.update_consumer_state_for_worker( + worker.syft_worker_id, ConsumerState.DETACHED + ) + + @property + def alive(self) -> bool: + return not self.socket.closed + + +@serializable(attrs=["_subscriber"]) +class ZMQConsumer(QueueConsumer): + def __init__( + self, + message_handler: AbstractMessageHandler, + address: str, + queue_name: str, + service_name: str, + syft_worker_id: UID | None = None, + worker_stash: WorkerStash | None = None, + verbose: bool = False, + ) -> None: + self.address = address + self.message_handler = message_handler + self.service_name = service_name + 
self.queue_name = queue_name + self.context = zmq.Context() + self.poller = zmq.Poller() + self.socket = None + self.verbose = verbose + self.id = UID().short() + self._stop = Event() + self.syft_worker_id = syft_worker_id + self.worker_stash = worker_stash + self.post_init() + + def reconnect_to_producer(self) -> None: + """Connect or reconnect to producer""" + if self.socket: + self.poller.unregister(self.socket) # type: ignore[unreachable] + self.socket.close() + self.socket = self.context.socket(zmq.DEALER) + self.socket.linger = 0 + self.socket.setsockopt_string(zmq.IDENTITY, self.id) + self.socket.connect(self.address) + self.poller.register(self.socket, zmq.POLLIN) + + logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") + + # Register queue with the producer + self.send_to_producer( + QueueMsgProtocol.W_READY, + [self.service_name.encode(), str(self.syft_worker_id).encode()], + ) + + def post_init(self) -> None: + self.thread: threading.Thread | None = None + self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) + self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) + self.reconnect_to_producer() + + def disconnect_from_producer(self) -> None: + self.send_to_producer(QueueMsgProtocol.W_DISCONNECT) + + def close(self) -> None: + self.disconnect_from_producer() + self._stop.set() + try: + if self.thread is not None: + self.thread.join(timeout=THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQConsumer thread join timed out during closing. " + f"SyftWorker id {self.syft_worker_id}, " + f"service name {self.service_name}." + ) + self.thread = None + self.poller.unregister(self.socket) + except Exception as e: + logger.error("Failed to unregister worker.", exc_info=e) + finally: + self.socket.close() + self.context.destroy() + # self._stop.clear() + + def send_to_producer( + self, + command: bytes, + msg: bytes | list | None = None, + ) -> None: + """Send message to producer. 
+ + If no msg is provided, creates one internally + """ + if self.socket.closed: + logger.warning("Socket is closed. Cannot send message.") + return + + if msg is None: + msg = [] + elif not isinstance(msg, list): + msg = [msg] + + # ZMQConsumer send frames: [empty, header, command, ...data] + core = [b"", QueueMsgProtocol.W_WORKER, command] + msg = core + msg + + if command != QueueMsgProtocol.W_HEARTBEAT: + logger.info(f"ZMQ Consumer send: {core}") + + with ZMQ_SOCKET_LOCK: + try: + self.socket.send_multipart(msg) + except zmq.ZMQError as e: + logger.error("ZMQConsumer send error", exc_info=e) + + def _run(self) -> None: + """Send reply, if any, to producer and wait for next request.""" + try: + while True: + if self._stop.is_set(): + logger.info("ZMQConsumer thread stopped") + return + + try: + items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except ContextTerminated: + logger.info("Context terminated") + return + except Exception as e: + logger.error("ZMQ poll error", exc_info=e) + continue + + if items: + msg = self.socket.recv_multipart() + + # mark as alive + self.set_producer_alive() + + if len(msg) < 3: + logger.error(f"ZMQConsumer invalid recv: {msg}") + continue + + # Message frames recieved by consumer: + # [empty, header, command, ...data] + (_, _, command, *data) = msg + + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQConsumer recv: {msg[:-4]}") + + if command == QueueMsgProtocol.W_REQUEST: + # Call Message Handler + try: + message = data.pop() + self.associate_job(message) + self.message_handler.handle_message( + message=message, + syft_worker_id=self.syft_worker_id, + ) + except Exception as e: + logger.exception("Couldn't handle message", exc_info=e) + finally: + self.clear_job() + elif command == QueueMsgProtocol.W_HEARTBEAT: + self.set_producer_alive() + elif command == QueueMsgProtocol.W_DISCONNECT: + self.reconnect_to_producer() + else: + 
logger.error(f"ZMQConsumer invalid command: {command}") + else: + if not self.is_producer_alive(): + logger.info("Producer check-alive timed out. Reconnecting.") + self.reconnect_to_producer() + self.set_producer_alive() + + if not self._stop.is_set(): + self.send_heartbeat() + + except zmq.ZMQError as e: + if e.errno == zmq.ETERM: + logger.info("zmq.ETERM") + else: + logger.exception("zmq.ZMQError", exc_info=e) + except Exception as e: + logger.exception("ZMQConsumer thread exception", exc_info=e) + + def set_producer_alive(self) -> None: + self.producer_ping_t.reset() + + def is_producer_alive(self) -> bool: + # producer timer is within timeout + return not self.producer_ping_t.has_expired() + + def send_heartbeat(self) -> None: + if self.heartbeat_t.has_expired() and self.is_producer_alive(): + self.send_to_producer(QueueMsgProtocol.W_HEARTBEAT) + self.heartbeat_t.reset() + + def run(self) -> None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + def associate_job(self, message: Frame) -> None: + try: + queue_item = _deserialize(message, from_bytes=True) + self._set_worker_job(queue_item.job_id) + except Exception as e: + logger.exception("Could not associate job", exc_info=e) + + def clear_job(self) -> None: + self._set_worker_job(None) + + def _set_worker_job(self, job_id: UID | None) -> None: + if self.worker_stash is not None: + consumer_state = ( + ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING + ) + res = self.worker_stash.update_consumer_state( + credentials=self.worker_stash.partition.root_verify_key, + worker_uid=self.syft_worker_id, + consumer_state=consumer_state, + ) + if res.is_err(): + logger.error( + f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" + ) + + @property + def alive(self) -> bool: + return not self.socket.closed and self.is_producer_alive() + + +@serializable() +class ZMQClientConfig(SyftObject, QueueClientConfig): + __canonical_name__ = 
"ZMQClientConfig" + __version__ = SYFT_OBJECT_VERSION_4 + + id: UID | None = None # type: ignore[assignment] + hostname: str = "127.0.0.1" + queue_port: int | None = None + # TODO: setting this to false until we can fix the ZMQ + # port issue causing tests to randomly fail + create_producer: bool = False + n_consumers: int = 0 + consumer_service: str | None = None + + +@serializable(attrs=["host"]) +class ZMQClient(QueueClient): + """ZMQ Client for creating producers and consumers.""" + + producers: dict[str, ZMQProducer] + consumers: defaultdict[str, list[ZMQConsumer]] + + def __init__(self, config: ZMQClientConfig) -> None: + self.host = config.hostname + self.producers = {} + self.consumers = defaultdict(list) + self.config = config + + @staticmethod + def _get_free_tcp_port(host: str) -> int: + with socketserver.TCPServer((host, 0), None) as s: + free_port = s.server_address[1] + + return free_port + + def add_producer( + self, + queue_name: str, + port: int | None = None, + queue_stash: QueueStash | None = None, + worker_stash: WorkerStash | None = None, + context: AuthedServiceContext | None = None, + ) -> ZMQProducer: + """Add a producer of a queue. + + A queue can have at most one producer attached to it. + """ + + if port is None: + if self.config.queue_port is None: + self.config.queue_port = self._get_free_tcp_port(self.host) + port = self.config.queue_port + else: + port = self.config.queue_port + + producer = ZMQProducer( + queue_name=queue_name, + queue_stash=queue_stash, + port=port, + context=context, + worker_stash=worker_stash, + ) + self.producers[queue_name] = producer + return producer + + def add_consumer( + self, + queue_name: str, + message_handler: AbstractMessageHandler, + service_name: str, + address: str | None = None, + worker_stash: WorkerStash | None = None, + syft_worker_id: UID | None = None, + ) -> ZMQConsumer: + """Add a consumer to a queue + + A queue should have at least one producer attached to the group. 
+ + """ + + if address is None: + address = get_queue_address(self.config.queue_port) + + consumer = ZMQConsumer( + queue_name=queue_name, + message_handler=message_handler, + address=address, + service_name=service_name, + syft_worker_id=syft_worker_id, + worker_stash=worker_stash, + ) + self.consumers[queue_name].append(consumer) + + return consumer + + def send_message( + self, + message: bytes, + queue_name: str, + worker: bytes | None = None, + ) -> SyftSuccess | SyftError: + producer = self.producers.get(queue_name) + if producer is None: + return SyftError( + message=f"No producer attached for queue: {queue_name}. Please add a producer for it." + ) + try: + producer.send(message=message, worker=worker) + except Exception as e: + # stdlib + return SyftError( + message=f"Failed to send message to: {queue_name} with error: {e}" + ) + return SyftSuccess( + message=f"Successfully queued message to : {queue_name}", + ) + + def close(self) -> SyftError | SyftSuccess: + try: + for consumers in self.consumers.values(): + for consumer in consumers: + # make sure look is stopped + consumer.close() + + for producer in self.producers.values(): + # make sure loop is stopped + producer.close() + # close existing connection. + except Exception as e: + return SyftError(message=f"Failed to close connection: {e}") + + return SyftSuccess(message="All connections closed.") + + def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: + if queue_name not in self.producers: + return SyftError(message=f"No producer running for : {queue_name}") + + producer = self.producers[queue_name] + + # close existing connection. 
+ producer.close() + + # add a new connection + self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore + + return SyftSuccess(message=f"Queue: {queue_name} successfully purged") + + def purge_all(self) -> SyftError | SyftSuccess: + for queue_name in self.producers: + self.purge_queue(queue_name=queue_name) + + return SyftSuccess(message="Successfully purged all queues.") + + +@serializable() +class ZMQQueueConfig(QueueConfig): + def __init__( + self, + client_type: type[ZMQClient] | None = None, + client_config: ZMQClientConfig | None = None, + thread_workers: bool = False, + ): + self.client_type = client_type or ZMQClient + self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() + self.thread_workers = thread_workers From c5e56b477b268932ca784895f4d740709ded1f7c Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 9 Jul 2024 19:51:16 +0530 Subject: [PATCH 10/20] refactor zmq code to separate files - rebased --- .../syft/src/syft/service/queue/zmq_common.py | 11 +- .../src/syft/service/queue/zmq_consumer.py | 22 +- .../src/syft/service/queue/zmq_producer.py | 46 +- .../syft/src/syft/service/queue/zmq_queue.py | 1054 ----------------- 4 files changed, 58 insertions(+), 1075 deletions(-) delete mode 100644 packages/syft/src/syft/service/queue/zmq_queue.py diff --git a/packages/syft/src/syft/service/queue/zmq_common.py b/packages/syft/src/syft/service/queue/zmq_common.py index 1b47e3edcfc..abe7947a806 100644 --- a/packages/syft/src/syft/service/queue/zmq_common.py +++ b/packages/syft/src/syft/service/queue/zmq_common.py @@ -5,10 +5,14 @@ # third party from pydantic import field_validator +from result import Result # relative +from ...node.credentials import SyftVerifyKey from ...types.base import SyftBaseModel from ...types.uid import UID +from ..worker.worker_pool import SyftWorker +from ..worker.worker_stash import WorkerStash # Producer/Consumer heartbeat interval (in seconds) HEARTBEAT_INTERVAL_SEC = 2 @@ -83,8 +87,6 
@@ class Worker(SyftBaseModel): syft_worker_id: UID | None = None expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @field_validator("syft_worker_id", mode="before") @classmethod def set_syft_worker_id(cls, v: Any) -> Any: @@ -101,6 +103,11 @@ def get_expiry(self) -> float: def reset_expiry(self) -> None: self.expiry_t.reset() + def _syft_worker( + self, stash: WorkerStash, credentials: SyftVerifyKey + ) -> Result[SyftWorker | None, str]: + return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) + def __str__(self) -> str: svc = self.service.name if self.service else None return ( diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index 82acf6cdb4a..aafae40a635 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -1,6 +1,7 @@ # stdlib import logging import threading +from threading import Event # third party import zmq @@ -48,7 +49,7 @@ def __init__( self.socket = None self.verbose = verbose self.id = UID().short() - self._stop = threading.Event() + self._stop = Event() self.syft_worker_id = syft_worker_id self.worker_stash = worker_stash self.post_init() @@ -85,16 +86,22 @@ def close(self) -> None: self.disconnect_from_producer() self._stop.set() try: - self.poller.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister worker.", exc_info=e) - finally: if self.thread is not None: self.thread.join(timeout=THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQConsumer thread join timed out during closing. " + f"SyftWorker id {self.syft_worker_id}, " + f"service name {self.service_name}." 
+ ) self.thread = None + self.poller.unregister(self.socket) + except Exception as e: + logger.error("Failed to unregister worker.", exc_info=e) + finally: self.socket.close() self.context.destroy() - self._stop.clear() + # self._stop.clear() def send_to_producer( self, @@ -187,7 +194,8 @@ def _run(self) -> None: self.reconnect_to_producer() self.set_producer_alive() - self.send_heartbeat() + if not self._stop.is_set(): + self.send_heartbeat() except zmq.ZMQError as e: if e.errno == zmq.ETERM: diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py index 84efaf74493..002ceb4281b 100644 --- a/packages/syft/src/syft/service/queue/zmq_producer.py +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -4,8 +4,10 @@ import logging import sys import threading +from threading import Event from time import sleep from typing import Any +from typing import cast # third party import zmq @@ -57,7 +59,7 @@ def __init__( self.worker_stash = worker_stash self.queue_name = queue_name self.auth_context = context - self._stop = threading.Event() + self._stop = Event() self.post_init() @property @@ -83,25 +85,32 @@ def post_init(self) -> None: def close(self) -> None: self._stop.set() - try: - self.poll_workers.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister poller.", exc_info=e) - finally: if self.thread: self.thread.join(THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQProducer message sending thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) self.thread = None if self.producer_thread: self.producer_thread.join(THREAD_TIMEOUT_SEC) + if self.producer_thread.is_alive(): + logger.error( + f"ZMQProducer queue thread join timed out during closing. 
" + f"Queue name {self.queue_name}, " + ) self.producer_thread = None + self.poll_workers.unregister(self.socket) + except Exception as e: + logger.exception("Failed to unregister poller.", exc_info=e) + finally: self.socket.close() self.context.destroy() - self._stop.clear() - @property def action_service(self) -> AbstractService: if self.auth_context.node is not None: @@ -324,10 +333,23 @@ def purge_workers(self) -> None: Workers are oldest to most recent, so we stop at the first alive worker. """ # work on a copy of the iterator - for worker in list(self.waiting): - if worker.has_expired(): + for worker in self.waiting: + res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) + if res.is_err() or (syft_worker := res.ok()) is None: + logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}") + continue + + if worker.has_expired() or syft_worker.to_be_deleted: logger.info(f"Deleting expired worker id={worker}") - self.delete_worker(worker, False) + self.delete_worker(worker, syft_worker.to_be_deleted) + + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.node.get_service(WorkerService) + ) + worker_service._delete(self.auth_context, syft_worker) def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState @@ -515,7 +537,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N def delete_worker(self, worker: Worker, disconnect: bool) -> None: """Deletes worker from all data structures, and deletes worker.""" if disconnect: - self.send_to_worker(worker, ZMQHeader.W_DISCONNECT) + self.send_to_worker(worker, ZMQCommand.W_DISCONNECT) if worker.service and worker in worker.service.waiting: worker.service.waiting.remove(worker) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py deleted file mode 100644 index de409dda676..00000000000 
--- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ /dev/null @@ -1,1054 +0,0 @@ -# stdlib -from binascii import hexlify -from collections import defaultdict -import itertools -import logging -import socketserver -import sys -import threading -from threading import Event -import time -from time import sleep -from typing import Any -from typing import cast - -# third party -from pydantic import field_validator -from result import Result -import zmq -from zmq import Frame -from zmq import LINGER -from zmq.error import ContextTerminated - -# relative -from ...node.credentials import SyftVerifyKey -from ...serde.deserialize import _deserialize -from ...serde.serializable import serializable -from ...serde.serialize import _serialize as serialize -from ...service.action.action_object import ActionObject -from ...service.context import AuthedServiceContext -from ...types.base import SyftBaseModel -from ...types.syft_object import SYFT_OBJECT_VERSION_4 -from ...types.syft_object import SyftObject -from ...types.uid import UID -from ...util.util import get_queue_address -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..worker.worker_pool import ConsumerState -from ..worker.worker_pool import SyftWorker -from ..worker.worker_stash import WorkerStash -from .base_queue import AbstractMessageHandler -from .base_queue import QueueClient -from .base_queue import QueueClientConfig -from .base_queue import QueueConfig -from .base_queue import QueueConsumer -from .base_queue import QueueProducer -from .queue_stash import ActionQueueItem -from .queue_stash import QueueStash -from .queue_stash import Status - -# Producer/Consumer heartbeat interval (in seconds) -HEARTBEAT_INTERVAL_SEC = 2 - -# Thread join timeout (in seconds) -THREAD_TIMEOUT_SEC = 30 - -# Max duration (in ms) to wait for ZMQ poller to return -ZMQ_POLLER_TIMEOUT_MSEC = 1000 - -# Duration (in seconds) after which a worker without a heartbeat 
will be marked as expired -WORKER_TIMEOUT_SEC = 60 - -# Duration (in seconds) after which producer without a heartbeat will be marked as expired -PRODUCER_TIMEOUT_SEC = 60 - -# Lock for working on ZMQ socket -ZMQ_SOCKET_LOCK = threading.Lock() - -logger = logging.getLogger(__name__) - - -class QueueMsgProtocol: - W_WORKER = b"MDPW01" - W_READY = b"0x01" - W_REQUEST = b"0x02" - W_REPLY = b"0x03" - W_HEARTBEAT = b"0x04" - W_DISCONNECT = b"0x05" - - -MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 - - -class Timeout: - def __init__(self, offset_sec: float): - self.__offset = float(offset_sec) - self.__next_ts: float = 0.0 - - self.reset() - - @property - def next_ts(self) -> float: - return self.__next_ts - - def reset(self) -> None: - self.__next_ts = self.now() + self.__offset - - def has_expired(self) -> bool: - return self.now() >= self.__next_ts - - @staticmethod - def now() -> float: - return time.time() - - -class Service: - def __init__(self, name: str) -> None: - self.name = name - self.requests: list[bytes] = [] - self.waiting: list[Worker] = [] # List of waiting workers - - -class Worker(SyftBaseModel): - address: bytes - identity: bytes - service: Service | None = None - syft_worker_id: UID | None = None - expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - - @field_validator("syft_worker_id", mode="before") - @classmethod - def set_syft_worker_id(cls, v: Any) -> Any: - if isinstance(v, str): - return UID(v) - return v - - def has_expired(self) -> bool: - return self.expiry_t.has_expired() - - def get_expiry(self) -> float: - return self.expiry_t.next_ts - - def reset_expiry(self) -> None: - self.expiry_t.reset() - - def _syft_worker( - self, stash: WorkerStash, credentials: SyftVerifyKey - ) -> Result[SyftWorker | None, str]: - return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) - - def __str__(self) -> str: - svc = self.service.name if self.service else None - return ( - f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " - 
f"syft_worker_id={self.syft_worker_id!r})" - ) - - -@serializable() -class ZMQProducer(QueueProducer): - INTERNAL_SERVICE_PREFIX = b"mmi." - - def __init__( - self, - queue_name: str, - queue_stash: QueueStash, - worker_stash: WorkerStash, - port: int, - context: AuthedServiceContext, - ) -> None: - self.id = UID().short() - self.port = port - self.queue_stash = queue_stash - self.worker_stash = worker_stash - self.queue_name = queue_name - self.auth_context = context - self._stop = Event() - self.post_init() - - @property - def address(self) -> str: - return get_queue_address(self.port) - - def post_init(self) -> None: - """Initialize producer state.""" - - self.services: dict[str, Service] = {} - self.workers: dict[bytes, Worker] = {} - self.waiting: list[Worker] = [] - self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) - self.context = zmq.Context(1) - self.socket = self.context.socket(zmq.ROUTER) - self.socket.setsockopt(LINGER, 1) - self.socket.setsockopt_string(zmq.IDENTITY, self.id) - self.poll_workers = zmq.Poller() - self.poll_workers.register(self.socket, zmq.POLLIN) - self.bind(f"tcp://*:{self.port}") - self.thread: threading.Thread | None = None - self.producer_thread: threading.Thread | None = None - - def close(self) -> None: - self._stop.set() - try: - if self.thread: - self.thread.join(THREAD_TIMEOUT_SEC) - if self.thread.is_alive(): - logger.error( - f"ZMQProducer message sending thread join timed out during closing. " - f"Queue name {self.queue_name}, " - ) - self.thread = None - - if self.producer_thread: - self.producer_thread.join(THREAD_TIMEOUT_SEC) - if self.producer_thread.is_alive(): - logger.error( - f"ZMQProducer queue thread join timed out during closing. 
" - f"Queue name {self.queue_name}, " - ) - self.producer_thread = None - - self.poll_workers.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister poller.", exc_info=e) - finally: - self.socket.close() - self.context.destroy() - - # self._stop.clear() - - @property - def action_service(self) -> AbstractService: - if self.auth_context.node is not None: - return self.auth_context.node.get_service("ActionService") - else: - raise Exception(f"{self.auth_context} does not have a node.") - - def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: - """recursively check collections for unresolved action objects""" - if isinstance(arg, UID): - arg = self.action_service.get(self.auth_context, arg).ok() - return self.contains_unresolved_action_objects(arg, recursion=recursion + 1) - if isinstance(arg, ActionObject): - if not arg.syft_resolved: - res = self.action_service.get(self.auth_context, arg) - if res.is_err(): - return True - arg = res.ok() - if not arg.syft_resolved: - return True - arg = arg.syft_action_data - - try: - value = False - if isinstance(arg, list): - for elem in arg: - value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 - ) - if value: - return True - if isinstance(arg, dict): - for elem in arg.values(): - value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 - ) - if value: - return True - return value - except Exception as e: - logger.exception("Failed to resolve action objects.", exc_info=e) - return True - - def unwrap_nested_actionobjects(self, data: Any) -> Any: - """recursively unwraps nested action objects""" - - if isinstance(data, list): - return [self.unwrap_nested_actionobjects(obj) for obj in data] - if isinstance(data, dict): - return { - key: self.unwrap_nested_actionobjects(obj) for key, obj in data.items() - } - if isinstance(data, ActionObject): - res = self.action_service.get(self.auth_context, data.id) - res = 
res.ok() if res.is_ok() else res.err() - if not isinstance(res, ActionObject): - return SyftError(message=f"{res}") - else: - nested_res = res.syft_action_data - if isinstance(nested_res, ActionObject): - raise ValueError( - "More than double nesting of ActionObjects is currently not supported" - ) - return nested_res - return data - - def contains_nested_actionobjects(self, data: Any) -> bool: - """ - returns if this is a list/set/dict that contains ActionObjects - """ - - def unwrap_collection(col: set | dict | list) -> [Any]: # type: ignore - return_values = [] - if isinstance(col, dict): - values = list(col.values()) + list(col.keys()) - else: - values = list(col) - for v in values: - if isinstance(v, list | dict | set): - return_values += unwrap_collection(v) - else: - return_values.append(v) - return return_values - - if isinstance(data, list | dict | set): - values = unwrap_collection(data) - has_action_object = any(isinstance(x, ActionObject) for x in values) - return has_action_object - elif isinstance(data, ActionObject): - return True - return False - - def preprocess_action_arg(self, arg: UID) -> UID | None: - """ "If the argument is a collection (of collections) of ActionObjects, - We want to flatten the collection and upload a new ActionObject that contains - its values. E.g. 
[[ActionObject1, ActionObject2],[ActionObject3, ActionObject4]] - -> [[value1, value2],[value3, value4]] - """ - res = self.action_service.get(context=self.auth_context, uid=arg) - if res.is_err(): - return arg - action_object = res.ok() - data = action_object.syft_action_data - if self.contains_nested_actionobjects(data): - new_data = self.unwrap_nested_actionobjects(data) - - new_action_object = ActionObject.from_obj( - new_data, - id=action_object.id, - syft_blob_storage_entry_id=action_object.syft_blob_storage_entry_id, - ) - res = self.action_service._set( - context=self.auth_context, action_object=new_action_object - ) - return None - - def read_items(self) -> None: - while True: - if self._stop.is_set(): - break - try: - sleep(1) - - # Items to be queued - items_to_queue = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, - status=Status.CREATED, - ).ok() - - items_to_queue = [] if items_to_queue is None else items_to_queue - - # Queue Items that are in the processing state - items_processing = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, - status=Status.PROCESSING, - ).ok() - - items_processing = [] if items_processing is None else items_processing - - for item in itertools.chain(items_to_queue, items_processing): - # TODO: if resolving fails, set queueitem to errored, and jobitem as well - if item.status == Status.CREATED: - if isinstance(item, ActionQueueItem): - action = item.kwargs["action"] - if self.contains_unresolved_action_objects( - action.args - ) or self.contains_unresolved_action_objects(action.kwargs): - continue - for arg in action.args: - self.preprocess_action_arg(arg) - for _, arg in action.kwargs.items(): - self.preprocess_action_arg(arg) - - msg_bytes = serialize(item, to_bytes=True) - worker_pool = item.worker_pool.resolve_with_context( - self.auth_context - ) - worker_pool = worker_pool.ok() - service_name = worker_pool.name - service: Service | None = 
self.services.get(service_name) - - # Skip adding message if corresponding service/pool - # is not registered. - if service is None: - continue - - # append request message to the corresponding service - # This list is processed in dispatch method. - - # TODO: Logic to evaluate the CAN RUN Condition - service.requests.append(msg_bytes) - item.status = Status.PROCESSING - res = self.queue_stash.update(item.syft_client_verify_key, item) - if res.is_err(): - logger.error( - f"Failed to update queue item={item} error={res.err()}" - ) - elif item.status == Status.PROCESSING: - # Evaluate Retry condition here - # If job running and timeout or job status is KILL - # or heartbeat fails - # or container id doesn't exists, kill process or container - # else decrease retry count and mark status as CREATED. - pass - except Exception as e: - print(e, file=sys.stderr) - item.status = Status.ERRORED - res = self.queue_stash.update(item.syft_client_verify_key, item) - if res.is_err(): - logger.error( - f"Failed to update queue item={item} error={res.err()}" - ) - - def run(self) -> None: - self.thread = threading.Thread(target=self._run) - self.thread.start() - - self.producer_thread = threading.Thread(target=self.read_items) - self.producer_thread.start() - - def send(self, worker: bytes, message: bytes | list[bytes]) -> None: - worker_obj = self.require_worker(worker) - self.send_to_worker(worker_obj, QueueMsgProtocol.W_REQUEST, message) - - def bind(self, endpoint: str) -> None: - """Bind producer to endpoint.""" - self.socket.bind(endpoint) - logger.info(f"ZMQProducer endpoint: {endpoint}") - - def send_heartbeats(self) -> None: - """Send heartbeats to idle workers if it's time""" - if self.heartbeat_t.has_expired(): - for worker in self.waiting: - self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT) - self.heartbeat_t.reset() - - def purge_workers(self) -> None: - """Look for & kill expired workers. 
- - Workers are oldest to most recent, so we stop at the first alive worker. - """ - # work on a copy of the iterator - for worker in self.waiting: - res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) - if res.is_err() or (syft_worker := res.ok()) is None: - logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}") - continue - - if worker.has_expired() or syft_worker.to_be_deleted: - logger.info(f"Deleting expired worker id={worker}") - self.delete_worker(worker, syft_worker.to_be_deleted) - - # relative - from ...service.worker.worker_service import WorkerService - - worker_service = cast( - WorkerService, self.auth_context.node.get_service(WorkerService) - ) - worker_service._delete(self.auth_context, syft_worker) - - def update_consumer_state_for_worker( - self, syft_worker_id: UID, consumer_state: ConsumerState - ) -> None: - if self.worker_stash is None: - logger.error( # type: ignore[unreachable] - f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" - ) - return - - try: - # Check if worker is present in the database - worker = self.worker_stash.get_by_uid( - credentials=self.worker_stash.partition.root_verify_key, - uid=syft_worker_id, - ) - if worker.is_ok() and worker.ok() is None: - return - - res = self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, - worker_uid=syft_worker_id, - consumer_state=consumer_state, - ) - if res.is_err(): - logger.error( - f"Failed to update consumer state for worker id={syft_worker_id} " - f"to state: {consumer_state} error={res.err()}", - ) - except Exception as e: - logger.error( - f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}", - exc_info=e, - ) - - def worker_waiting(self, worker: Worker) -> None: - """This worker is now waiting for work.""" - # Queue to broker and service waiting lists - if worker not in self.waiting: - self.waiting.append(worker) - if worker.service is 
not None and worker not in worker.service.waiting: - worker.service.waiting.append(worker) - worker.reset_expiry() - self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) - self.dispatch(worker.service, None) - - def dispatch(self, service: Service, msg: bytes) -> None: - """Dispatch requests to waiting workers as possible""" - if msg is not None: # Queue message if any - service.requests.append(msg) - - self.purge_workers() - while service.waiting and service.requests: - # One worker consuming only one message at a time. - msg = service.requests.pop(0) - worker = service.waiting.pop(0) - self.waiting.remove(worker) - self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, msg) - - def send_to_worker( - self, - worker: Worker, - command: bytes, - msg: bytes | list | None = None, - ) -> None: - """Send message to worker. - - If message is provided, sends that message. - """ - - if self.socket.closed: - logger.warning("Socket is closed. Cannot send message.") - return - - if msg is None: - msg = [] - elif not isinstance(msg, list): - msg = [msg] - - # ZMQProducer send frames: [address, empty, header, command, ...data] - core = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] - msg = core + msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQProducer send: {core}") - - with ZMQ_SOCKET_LOCK: - try: - self.socket.send_multipart(msg) - except zmq.ZMQError as e: - logger.error("ZMQProducer send error", exc_info=e) - - def _run(self) -> None: - try: - while True: - if self._stop.is_set(): - logger.info("ZMQProducer thread stopped") - return - - for service in self.services.values(): - self.dispatch(service, None) - - items = None - - try: - items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except Exception as e: - logger.exception("ZMQProducer poll error", exc_info=e) - - if items: - msg = self.socket.recv_multipart() - - if len(msg) < 3: - 
logger.error(f"ZMQProducer invalid recv: {msg}") - continue - - # ZMQProducer recv frames: [address, empty, header, command, ...data] - (address, _, header, command, *data) = msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQProducer recv: {msg[:4]}") - - if header == QueueMsgProtocol.W_WORKER: - self.process_worker(address, command, data) - else: - logger.error(f"Invalid message header: {header}") - - self.send_heartbeats() - self.purge_workers() - except Exception as e: - logger.exception("ZMQProducer thread exception", exc_info=e) - - def require_worker(self, address: bytes) -> Worker: - """Finds the worker (creates if necessary).""" - identity = hexlify(address) - worker = self.workers.get(identity) - if worker is None: - worker = Worker(identity=identity, address=address) - self.workers[identity] = worker - return worker - - def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> None: - worker_ready = hexlify(address) in self.workers - worker = self.require_worker(address) - - if QueueMsgProtocol.W_READY == command: - service_name = data.pop(0).decode() - syft_worker_id = data.pop(0).decode() - if worker_ready: - # Not first command in session or Reserved service name - # If worker was already present, then we disconnect it first - # and wait for it to re-register itself to the producer. This ensures that - # we always have a healthy worker in place that can talk to the producer. 
- self.delete_worker(worker, True) - else: - # Attach worker to service and mark as idle - if service_name in self.services: - service: Service | None = self.services.get(service_name) - else: - service = Service(service_name) - self.services[service_name] = service - if service is not None: - worker.service = service - logger.info(f"New worker: {worker}") - worker.syft_worker_id = UID(syft_worker_id) - self.worker_waiting(worker) - - elif QueueMsgProtocol.W_HEARTBEAT == command: - if worker_ready: - # If worker is ready then reset expiry - # and add it to worker waiting list - # if not already present - self.worker_waiting(worker) - else: - logger.info(f"Got heartbeat, but worker not ready. {worker}") - self.delete_worker(worker, True) - elif QueueMsgProtocol.W_DISCONNECT == command: - logger.info(f"Removing disconnected worker: {worker}") - self.delete_worker(worker, False) - else: - logger.error(f"Invalid command: {command!r}") - - def delete_worker(self, worker: Worker, disconnect: bool) -> None: - """Deletes worker from all data structures, and deletes worker.""" - if disconnect: - self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT) - - if worker.service and worker in worker.service.waiting: - worker.service.waiting.remove(worker) - - if worker in self.waiting: - self.waiting.remove(worker) - - self.workers.pop(worker.identity, None) - - if worker.syft_worker_id is not None: - self.update_consumer_state_for_worker( - worker.syft_worker_id, ConsumerState.DETACHED - ) - - @property - def alive(self) -> bool: - return not self.socket.closed - - -@serializable(attrs=["_subscriber"]) -class ZMQConsumer(QueueConsumer): - def __init__( - self, - message_handler: AbstractMessageHandler, - address: str, - queue_name: str, - service_name: str, - syft_worker_id: UID | None = None, - worker_stash: WorkerStash | None = None, - verbose: bool = False, - ) -> None: - self.address = address - self.message_handler = message_handler - self.service_name = service_name - 
self.queue_name = queue_name - self.context = zmq.Context() - self.poller = zmq.Poller() - self.socket = None - self.verbose = verbose - self.id = UID().short() - self._stop = Event() - self.syft_worker_id = syft_worker_id - self.worker_stash = worker_stash - self.post_init() - - def reconnect_to_producer(self) -> None: - """Connect or reconnect to producer""" - if self.socket: - self.poller.unregister(self.socket) # type: ignore[unreachable] - self.socket.close() - self.socket = self.context.socket(zmq.DEALER) - self.socket.linger = 0 - self.socket.setsockopt_string(zmq.IDENTITY, self.id) - self.socket.connect(self.address) - self.poller.register(self.socket, zmq.POLLIN) - - logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") - - # Register queue with the producer - self.send_to_producer( - QueueMsgProtocol.W_READY, - [self.service_name.encode(), str(self.syft_worker_id).encode()], - ) - - def post_init(self) -> None: - self.thread: threading.Thread | None = None - self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) - self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) - self.reconnect_to_producer() - - def disconnect_from_producer(self) -> None: - self.send_to_producer(QueueMsgProtocol.W_DISCONNECT) - - def close(self) -> None: - self.disconnect_from_producer() - self._stop.set() - try: - if self.thread is not None: - self.thread.join(timeout=THREAD_TIMEOUT_SEC) - if self.thread.is_alive(): - logger.error( - f"ZMQConsumer thread join timed out during closing. " - f"SyftWorker id {self.syft_worker_id}, " - f"service name {self.service_name}." - ) - self.thread = None - self.poller.unregister(self.socket) - except Exception as e: - logger.error("Failed to unregister worker.", exc_info=e) - finally: - self.socket.close() - self.context.destroy() - # self._stop.clear() - - def send_to_producer( - self, - command: bytes, - msg: bytes | list | None = None, - ) -> None: - """Send message to producer. 
- - If no msg is provided, creates one internally - """ - if self.socket.closed: - logger.warning("Socket is closed. Cannot send message.") - return - - if msg is None: - msg = [] - elif not isinstance(msg, list): - msg = [msg] - - # ZMQConsumer send frames: [empty, header, command, ...data] - core = [b"", QueueMsgProtocol.W_WORKER, command] - msg = core + msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - logger.info(f"ZMQ Consumer send: {core}") - - with ZMQ_SOCKET_LOCK: - try: - self.socket.send_multipart(msg) - except zmq.ZMQError as e: - logger.error("ZMQConsumer send error", exc_info=e) - - def _run(self) -> None: - """Send reply, if any, to producer and wait for next request.""" - try: - while True: - if self._stop.is_set(): - logger.info("ZMQConsumer thread stopped") - return - - try: - items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except ContextTerminated: - logger.info("Context terminated") - return - except Exception as e: - logger.error("ZMQ poll error", exc_info=e) - continue - - if items: - msg = self.socket.recv_multipart() - - # mark as alive - self.set_producer_alive() - - if len(msg) < 3: - logger.error(f"ZMQConsumer invalid recv: {msg}") - continue - - # Message frames recieved by consumer: - # [empty, header, command, ...data] - (_, _, command, *data) = msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQConsumer recv: {msg[:-4]}") - - if command == QueueMsgProtocol.W_REQUEST: - # Call Message Handler - try: - message = data.pop() - self.associate_job(message) - self.message_handler.handle_message( - message=message, - syft_worker_id=self.syft_worker_id, - ) - except Exception as e: - logger.exception("Couldn't handle message", exc_info=e) - finally: - self.clear_job() - elif command == QueueMsgProtocol.W_HEARTBEAT: - self.set_producer_alive() - elif command == QueueMsgProtocol.W_DISCONNECT: - self.reconnect_to_producer() - else: - 
logger.error(f"ZMQConsumer invalid command: {command}") - else: - if not self.is_producer_alive(): - logger.info("Producer check-alive timed out. Reconnecting.") - self.reconnect_to_producer() - self.set_producer_alive() - - if not self._stop.is_set(): - self.send_heartbeat() - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - logger.info("zmq.ETERM") - else: - logger.exception("zmq.ZMQError", exc_info=e) - except Exception as e: - logger.exception("ZMQConsumer thread exception", exc_info=e) - - def set_producer_alive(self) -> None: - self.producer_ping_t.reset() - - def is_producer_alive(self) -> bool: - # producer timer is within timeout - return not self.producer_ping_t.has_expired() - - def send_heartbeat(self) -> None: - if self.heartbeat_t.has_expired() and self.is_producer_alive(): - self.send_to_producer(QueueMsgProtocol.W_HEARTBEAT) - self.heartbeat_t.reset() - - def run(self) -> None: - self.thread = threading.Thread(target=self._run) - self.thread.start() - - def associate_job(self, message: Frame) -> None: - try: - queue_item = _deserialize(message, from_bytes=True) - self._set_worker_job(queue_item.job_id) - except Exception as e: - logger.exception("Could not associate job", exc_info=e) - - def clear_job(self) -> None: - self._set_worker_job(None) - - def _set_worker_job(self, job_id: UID | None) -> None: - if self.worker_stash is not None: - consumer_state = ( - ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING - ) - res = self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, - worker_uid=self.syft_worker_id, - consumer_state=consumer_state, - ) - if res.is_err(): - logger.error( - f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" - ) - - @property - def alive(self) -> bool: - return not self.socket.closed and self.is_producer_alive() - - -@serializable() -class ZMQClientConfig(SyftObject, QueueClientConfig): - __canonical_name__ = 
"ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_4 - - id: UID | None = None # type: ignore[assignment] - hostname: str = "127.0.0.1" - queue_port: int | None = None - # TODO: setting this to false until we can fix the ZMQ - # port issue causing tests to randomly fail - create_producer: bool = False - n_consumers: int = 0 - consumer_service: str | None = None - - -@serializable(attrs=["host"]) -class ZMQClient(QueueClient): - """ZMQ Client for creating producers and consumers.""" - - producers: dict[str, ZMQProducer] - consumers: defaultdict[str, list[ZMQConsumer]] - - def __init__(self, config: ZMQClientConfig) -> None: - self.host = config.hostname - self.producers = {} - self.consumers = defaultdict(list) - self.config = config - - @staticmethod - def _get_free_tcp_port(host: str) -> int: - with socketserver.TCPServer((host, 0), None) as s: - free_port = s.server_address[1] - - return free_port - - def add_producer( - self, - queue_name: str, - port: int | None = None, - queue_stash: QueueStash | None = None, - worker_stash: WorkerStash | None = None, - context: AuthedServiceContext | None = None, - ) -> ZMQProducer: - """Add a producer of a queue. - - A queue can have at most one producer attached to it. - """ - - if port is None: - if self.config.queue_port is None: - self.config.queue_port = self._get_free_tcp_port(self.host) - port = self.config.queue_port - else: - port = self.config.queue_port - - producer = ZMQProducer( - queue_name=queue_name, - queue_stash=queue_stash, - port=port, - context=context, - worker_stash=worker_stash, - ) - self.producers[queue_name] = producer - return producer - - def add_consumer( - self, - queue_name: str, - message_handler: AbstractMessageHandler, - service_name: str, - address: str | None = None, - worker_stash: WorkerStash | None = None, - syft_worker_id: UID | None = None, - ) -> ZMQConsumer: - """Add a consumer to a queue - - A queue should have at least one producer attached to the group. 
- - """ - - if address is None: - address = get_queue_address(self.config.queue_port) - - consumer = ZMQConsumer( - queue_name=queue_name, - message_handler=message_handler, - address=address, - service_name=service_name, - syft_worker_id=syft_worker_id, - worker_stash=worker_stash, - ) - self.consumers[queue_name].append(consumer) - - return consumer - - def send_message( - self, - message: bytes, - queue_name: str, - worker: bytes | None = None, - ) -> SyftSuccess | SyftError: - producer = self.producers.get(queue_name) - if producer is None: - return SyftError( - message=f"No producer attached for queue: {queue_name}. Please add a producer for it." - ) - try: - producer.send(message=message, worker=worker) - except Exception as e: - # stdlib - return SyftError( - message=f"Failed to send message to: {queue_name} with error: {e}" - ) - return SyftSuccess( - message=f"Successfully queued message to : {queue_name}", - ) - - def close(self) -> SyftError | SyftSuccess: - try: - for consumers in self.consumers.values(): - for consumer in consumers: - # make sure look is stopped - consumer.close() - - for producer in self.producers.values(): - # make sure loop is stopped - producer.close() - # close existing connection. - except Exception as e: - return SyftError(message=f"Failed to close connection: {e}") - - return SyftSuccess(message="All connections closed.") - - def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: - if queue_name not in self.producers: - return SyftError(message=f"No producer running for : {queue_name}") - - producer = self.producers[queue_name] - - # close existing connection. 
- producer.close() - - # add a new connection - self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore - - return SyftSuccess(message=f"Queue: {queue_name} successfully purged") - - def purge_all(self) -> SyftError | SyftSuccess: - for queue_name in self.producers: - self.purge_queue(queue_name=queue_name) - - return SyftSuccess(message="Successfully purged all queues.") - - -@serializable() -class ZMQQueueConfig(QueueConfig): - def __init__( - self, - client_type: type[ZMQClient] | None = None, - client_config: ZMQClientConfig | None = None, - thread_workers: bool = False, - ): - self.client_type = client_type or ZMQClient - self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.thread_workers = thread_workers From 7ccb3250413597c30a82d53295fef4c681dbbfc3 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 23 Jul 2024 20:43:27 +0530 Subject: [PATCH 11/20] fix class version and serde in zmq consumer/producer/config --- packages/syft/src/syft/service/queue/zmq_client.py | 8 ++++---- packages/syft/src/syft/service/queue/zmq_common.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py index 932056002de..2aca090b5c6 100644 --- a/packages/syft/src/syft/service/queue/zmq_client.py +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext -from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.uid import UID from ...util.util import get_queue_address @@ -24,7 +24,7 @@ @serializable() class ZMQClientConfig(SyftObject, QueueClientConfig): __canonical_name__ = "ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_4 + __version__ = SYFT_OBJECT_VERSION_1 id: UID | 
None = None # type: ignore[assignment] hostname: str = "127.0.0.1" @@ -36,7 +36,7 @@ class ZMQClientConfig(SyftObject, QueueClientConfig): consumer_service: str | None = None -@serializable(attrs=["host"]) +@serializable(attrs=["host"], canonical_name="ZMQClient", version=1) class ZMQClient(QueueClient): """ZMQ Client for creating producers and consumers.""" @@ -175,7 +175,7 @@ def purge_all(self) -> SyftError | SyftSuccess: return SyftSuccess(message="Successfully purged all queues.") -@serializable() +@serializable(canonical_name="ZMQQueueConfig", version=1) class ZMQQueueConfig(QueueConfig): def __init__( self, diff --git a/packages/syft/src/syft/service/queue/zmq_common.py b/packages/syft/src/syft/service/queue/zmq_common.py index abe7947a806..5f72d476cd6 100644 --- a/packages/syft/src/syft/service/queue/zmq_common.py +++ b/packages/syft/src/syft/service/queue/zmq_common.py @@ -8,7 +8,7 @@ from result import Result # relative -from ...node.credentials import SyftVerifyKey +from ...server.credentials import SyftVerifyKey from ...types.base import SyftBaseModel from ...types.uid import UID from ..worker.worker_pool import SyftWorker From 17d766323d2a6b5a96a24c5d832b08f75f55186d Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 23 Jul 2024 20:58:28 +0530 Subject: [PATCH 12/20] fix reference of Values.node to Values.server --- .../grid/helm/syft/templates/backend/backend-statefulset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 0e14b9a9953..de55435f98b 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -140,7 +140,7 @@ spec: - name: MIN_SIZE_BLOB_STORAGE_MB value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }} # Tracing - {{- if .Values.node.tracing }} + {{- if .Values.server.tracing }} 
- name: TRACING value: "True" {{- end }} From 373f5feba5ed5cc7193b8d9d5875bf5ebe83ada8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 24 Jul 2024 15:41:51 +0530 Subject: [PATCH 13/20] pass OTEL_SERVICE_NAME and OTEL_EXPORTER_* as env to workers in workerpools --- packages/syft/src/syft/service/worker/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 55d608c2964..29bdd4d0b1a 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -350,6 +350,16 @@ def create_kubernetes_pool( "N_CONSUMERS": "1", "CREATE_PRODUCER": "False", "INMEMORY_WORKERS": "False", + "OTEL_SERVICE_NAME": f"{pool_name}", + "OTEL_PYTHON_LOG_CORRELATION": os.environ.get( + "OTEL_PYTHON_LOG_CORRELATION" + ), + "OTEL_EXPORTER_OTLP_ENDPOINT": os.environ.get( + "OTEL_EXPORTER_OTLP_ENDPOINT" + ), + "OTEL_EXPORTER_OTLP_PROTOCOL": os.environ.get( + "OTEL_EXPORTER_OTLP_PROTOCOL" + ), }, ) From a7c26cf3ed7075c925f482578975538e7bfd5bb0 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 2 Sep 2024 10:35:20 +1000 Subject: [PATCH 14/20] Removed old Response return type --- packages/syft/src/syft/service/queue/zmq_common.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_common.py b/packages/syft/src/syft/service/queue/zmq_common.py index 5f72d476cd6..35331541c86 100644 --- a/packages/syft/src/syft/service/queue/zmq_common.py +++ b/packages/syft/src/syft/service/queue/zmq_common.py @@ -5,11 +5,12 @@ # third party from pydantic import field_validator -from result import Result # relative from ...server.credentials import SyftVerifyKey from ...types.base import SyftBaseModel +from ...types.errors import SyftException +from ...types.result import as_result from ...types.uid import UID from ..worker.worker_pool import SyftWorker from ..worker.worker_stash 
import WorkerStash @@ -103,10 +104,13 @@ def get_expiry(self) -> float: def reset_expiry(self) -> None: self.expiry_t.reset() + @as_result(SyftException) def _syft_worker( self, stash: WorkerStash, credentials: SyftVerifyKey - ) -> Result[SyftWorker | None, str]: - return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) + ) -> SyftWorker | None: + return stash.get_by_uid( + credentials=credentials, uid=self.syft_worker_id + ).unwrap() def __str__(self) -> str: svc = self.service.name if self.service else None From 6fc5a004e93f686ce072d9c1c0aa534a52950635 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 2 Sep 2024 10:42:33 +1000 Subject: [PATCH 15/20] Added nosec to instrumentation import try catch blocks --- packages/syft/src/syft/store/blob_storage/seaweedfs.py | 2 +- packages/syft/src/syft/store/mongo_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 8603c34e68d..96996d7d9a9 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -57,7 +57,7 @@ from opentelemetry.instrumentation.botocore import BotocoreInstrumentor BotocoreInstrumentor().instrument() - except Exception: + except Exception: # nosec pass diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index 61012aa7c2b..c4591f6aa5e 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -24,7 +24,7 @@ from opentelemetry.instrumentation.pymongo import PymongoInstrumentor PymongoInstrumentor().instrument() - except Exception: + except Exception: # nosec pass From 0a903473531753abef6f3c956206a854bd7d4b92 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 2 Sep 2024 10:55:54 +1000 Subject: [PATCH 16/20] Fixed ZMQClient interface changes are file split --- 
.../syft/src/syft/service/queue/zmq_client.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py index 2aca090b5c6..deeeb97a32b 100644 --- a/packages/syft/src/syft/service/queue/zmq_client.py +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -5,11 +5,11 @@ # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext +from ...types.errors import SyftException from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.uid import UID from ...util.util import get_queue_address -from ..response import SyftError from ..response import SyftSuccess from ..worker.worker_stash import WorkerStash from .base_queue import AbstractMessageHandler @@ -121,24 +121,24 @@ def send_message( message: bytes, queue_name: str, worker: bytes | None = None, - ) -> SyftSuccess | SyftError: + ) -> SyftSuccess: producer = self.producers.get(queue_name) if producer is None: - return SyftError( - message=f"No producer attached for queue: {queue_name}. Please add a producer for it." + raise SyftException( + public_message=f"No producer attached for queue: {queue_name}. Please add a producer for it." ) try: producer.send(message=message, worker=worker) except Exception as e: # stdlib - return SyftError( - message=f"Failed to send message to: {queue_name} with error: {e}" + raise SyftException( + public_message=f"Failed to send message to: {queue_name} with error: {e}" ) return SyftSuccess( message=f"Successfully queued message to : {queue_name}", ) - def close(self) -> SyftError | SyftSuccess: + def close(self) -> SyftSuccess: try: for consumers in self.consumers.values(): for consumer in consumers: @@ -150,13 +150,15 @@ def close(self) -> SyftError | SyftSuccess: producer.close() # close existing connection. 
except Exception as e: - return SyftError(message=f"Failed to close connection: {e}") + raise SyftException(public_message=f"Failed to close connection: {e}") return SyftSuccess(message="All connections closed.") - def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: + def purge_queue(self, queue_name: str) -> SyftSuccess: if queue_name not in self.producers: - return SyftError(message=f"No producer running for : {queue_name}") + raise SyftException( + public_message=f"No producer running for : {queue_name}" + ) producer = self.producers[queue_name] @@ -168,7 +170,7 @@ def purge_queue(self, queue_name: str) -> SyftError | SyftSuccess: return SyftSuccess(message=f"Queue: {queue_name} successfully purged") - def purge_all(self) -> SyftError | SyftSuccess: + def purge_all(self) -> SyftSuccess: for queue_name in self.producers: self.purge_queue(queue_name=queue_name) From 7f1063a70bb345a5ca85f1d5c125747ba33e864c Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 2 Sep 2024 14:51:38 +1000 Subject: [PATCH 17/20] Upgrading OTEL libraries - Adding BatchLogRecordProcessor --- packages/grid/backend/grid/main.py | 54 +++++++++++++++++-- packages/grid/backend/grid/start.sh | 5 ++ packages/grid/helm/syft/values.yaml | 3 +- packages/syft/setup.cfg | 28 +++++----- packages/syft/src/syft/server/uvicorn.py | 1 + .../src/syft/service/notifier/notifier.py | 8 ++- .../src/syft/service/notifier/smtp_client.py | 11 +++- .../src/syft/store/blob_storage/seaweedfs.py | 3 ++ packages/syft/src/syft/store/mongo_client.py | 5 ++ tox.ini | 27 ++++++---- 10 files changed, 113 insertions(+), 32 deletions(-) diff --git a/packages/grid/backend/grid/main.py b/packages/grid/backend/grid/main.py index 2af898707e3..12af4165179 100644 --- a/packages/grid/backend/grid/main.py +++ b/packages/grid/backend/grid/main.py @@ -80,7 +80,55 @@ def healthcheck() -> dict[str, str]: if settings.TRACING_ENABLED: - # third party - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + 
try: + # stdlib + import os + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", None) + # third party + from opentelemetry._logs import set_logger_provider + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk._logs import LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + from opentelemetry.sdk.resources import Resource + + logger_provider = LoggerProvider( + resource=Resource.create( + { + "service.name": "backend-container", + } + ), + ) + set_logger_provider(logger_provider) + + exporter = OTLPLogExporter(insecure=True, endpoint=endpoint) + + logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) + handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) + + # Attach OTLP handler to root logger + logging.getLogger().addHandler(handler) + logger = logging.getLogger(__name__) + message = "> Added OTEL BatchLogRecordProcessor" + print(message) + logger.info(message) + + except Exception as e: + print(f"Failed to load OTLPLogExporter. {e}") - FastAPIInstrumentor.instrument_app(app) + # third party + try: + # third party + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) + message = "> Added OTEL FastAPIInstrumentor" + print(message) + logger = logging.getLogger(__name__) + logger.info(message) + except Exception as e: + error = f"Failed to load FastAPIInstrumentor. 
{e}" + print(error) + logger = logging.getLogger(__name__) + logger.error(message) diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index e584df16b7a..daee022cc9e 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -11,6 +11,9 @@ SERVER_TYPE=${SERVER_TYPE:-datasite} APPDIR=${APPDIR:-$HOME/app} RELOAD="" ROOT_PROC="" +TRACING=${TRACING:"False"} + +echo "Starting with TRACING=${TRACING}" if [[ ${DEV_MODE} == "True" ]]; then @@ -26,6 +29,8 @@ then ROOT_PROC="python -m debugpy --listen 0.0.0.0:5678 -m" fi + + if [[ ${TRACING} == "True" ]]; then echo "OpenTelemetry Enabled" diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 81c823fcc8a..b7d4bbcb882 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -10,7 +10,6 @@ global: kaniko: version: "v1.23.2" - # ================================================================================= mongo: @@ -186,7 +185,7 @@ server: smtp: # Existing secret for SMTP with key 'smtpPassword' existingSecret: null - host: smtp.sendgrid.net + host: hostname port: 587 from: noreply@openmined.org username: apikey diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index ca8f084c3bf..6de365022e5 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -104,20 +104,20 @@ dev = aiosmtpd==1.4.6 telemetry = - opentelemetry-api==1.25.0 - opentelemetry-sdk==1.25.0 - opentelemetry-exporter-otlp==1.25.0 - opentelemetry-instrumentation==0.46b0 - opentelemetry-instrumentation-requests==0.46b0 - opentelemetry-instrumentation-fastapi==0.46b0 - opentelemetry-instrumentation-pymongo==0.46b0 - opentelemetry-instrumentation-botocore==0.46b0 - opentelemetry-instrumentation-logging==0.46b0 - ; opentelemetry-instrumentation-asyncio==0.46b0 - ; opentelemetry-instrumentation-sqlite3==0.46b0 - ; opentelemetry-instrumentation-threading==0.46b0 - ; 
opentelemetry-instrumentation-jinja2==0.46b0 - ; opentelemetry-instrumentation-system-metrics==0.46b0 + opentelemetry-api==1.27.0 + opentelemetry-sdk==1.27.0 + opentelemetry-exporter-otlp==1.27.0 + opentelemetry-instrumentation==0.48b0 + opentelemetry-instrumentation-requests==0.48b0 + opentelemetry-instrumentation-fastapi==0.48b0 + opentelemetry-instrumentation-pymongo==0.48b0 + opentelemetry-instrumentation-botocore==0.48b0 + opentelemetry-instrumentation-logging==0.48b0 + ; opentelemetry-instrumentation-asyncio==0.48b0 + ; opentelemetry-instrumentation-sqlite3==0.48b0 + ; opentelemetry-instrumentation-threading==0.48b0 + ; opentelemetry-instrumentation-jinja2==0.48b0 + ; opentelemetry-instrumentation-system-metrics==0.48b0 # pytest>=8.0 broke pytest-lazy-fixture which doesn't seem to be actively maintained # temporarily pin to pytest<8 diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 088b22f44c9..0be6fb5ea9a 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -119,6 +119,7 @@ def app_factory() -> FastAPI: from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor FastAPIInstrumentor().instrument_app(app) + print("> Added OTEL FastAPIInstrumentor") return app diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 67974e4cd99..31ec3362170 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -128,10 +128,14 @@ def send( self.smtp_client.send( # type: ignore sender=self.sender, receiver=receiver_email, subject=subject, body=body ) - print(f"> Sent email: {subject} to {receiver_email}") + message = f"> Sent email: {subject} to {receiver_email}" + print(message) + logging.info(message) return SyftSuccess(message="Email sent successfully!") except Exception: - print(f"> Error sending email: {subject} to 
{receiver_email}") + message = f"> Error sending email: {subject} to {receiver_email}" + print(message) + logging.info(message) return SyftError(message="Failed to send an email.") # raise SyftException.from_exception( # exc, diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py index 0f8bfce81da..eef25440af8 100644 --- a/packages/syft/src/syft/service/notifier/smtp_client.py +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -1,6 +1,7 @@ # stdlib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +import logging import smtplib # third party @@ -12,6 +13,8 @@ SOCKET_TIMEOUT = 5 # seconds +logger = logging.getLogger(__name__) + class SMTPClient(BaseModel): server: str @@ -44,9 +47,10 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non text = msg.as_string() server.sendmail(sender, ", ".join(receiver), text) return None - except Exception: + except Exception as e: + logger.error(f"Unable to send email. {e}") raise SyftException( - public_message="Ops! Something went wrong while trying to send an email." + public_message="Oops! Something went wrong while trying to send an email." ) @classmethod @@ -73,4 +77,7 @@ def check_credentials( smtp_server.login(username, password) return True except Exception as e: + message = f"SMTP check_credentials failed. 
{e}" + print(message) + logger.error(message) raise SyftException(public_message=str(e)) diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 96996d7d9a9..3409035906f 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -57,6 +57,9 @@ from opentelemetry.instrumentation.botocore import BotocoreInstrumentor BotocoreInstrumentor().instrument() + message = "> Added OTEL BotocoreInstrumentor" + print(message) + logger.info(message) except Exception: # nosec pass diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index c4591f6aa5e..9767059a1cb 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -1,4 +1,5 @@ # stdlib +import logging from threading import Lock from typing import Any @@ -24,6 +25,10 @@ from opentelemetry.instrumentation.pymongo import PymongoInstrumentor PymongoInstrumentor().instrument() + message = "> Added OTEL PymongoInstrumentor" + print(message) + logger = logging.getLogger(__name__) + logger.info(message) except Exception: # nosec pass diff --git a/tox.ini b/tox.ini index 821656bd76c..3b71838f2ee 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = dev.k8s.cleanup dev.k8s.destroy dev.k8s.destroyall + dev.k8s.install.signoz lint stack.test.integration syft.docs @@ -422,8 +423,6 @@ setenv = TEST_QUERY_LIMIT_SIZE={env:TEST_QUERY_LIMIT_SIZE:500000} commands = bash -c "python --version || true" - bash -c "python3 --version || true" - bash -c "python3.12 --version || true" bash -c "echo Running with GITHUB_CI=$GITHUB_CI; date" bash -c "echo Running with TEST_EXTERNAL_REGISTRY=$TEST_EXTERNAL_REGISTRY; date" python -c 'import syft as sy; sy.stage_protocol_changes()' @@ -966,7 +965,7 @@ commands = description = Patch CoreDNS to resolve k3d-registry.localhost changedir = {toxinidir} 
passenv=HOME,USER,CLUSTER_NAME -setenv = +setenv= CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = bash @@ -980,12 +979,19 @@ commands = [testenv:dev.k8s.install.signoz] description = Install Signoz on local Kubernetes cluster changedir = {toxinidir} -passenv=HOME,USER -allowlist_externals = +passenv=HOME,USER,CLUSTER_NAME +allowlist_externals= bash -commands = +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:test-datasite-1} + SIGNOZ_PORT = {env:SIGNOZ_PORT:3301} +commands= + bash -c 'if [ "{posargs}" ]; then kubectl config use-context {posargs}; fi' bash -c 'helm repo add signoz https://charts.signoz.io && helm repo update' bash -c 'helm install syft signoz/signoz --namespace platform --create-namespace || true' + # bash -c 'k3d cluster edit ${CLUSTER_NAME} --port-add "${SIGNOZ_PORT}:3301@loadbalancer"' + # bash packages/grid/scripts/wait_for.sh service syft-signoz-frontend --context k3d-{env:CLUSTER_NAME} --namespace platform + [testenv:dev.k8s.start] description = Start local Kubernetes registry & cluster with k3d @@ -994,6 +1000,7 @@ passenv = HOME, USER setenv = CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} CLUSTER_HTTP_PORT = {env:CLUSTER_HTTP_PORT:8080} + # SIGNOZ_PORT = {env:SIGNOZ_PORT:3301} allowlist_externals = bash sleep @@ -1003,9 +1010,11 @@ commands = tox -e dev.k8s.registry ; for NodePort to work add the following --> -p "NodePort:NodePort@loadbalancer" - bash -c 'k3d cluster create ${CLUSTER_NAME} -p "${CLUSTER_HTTP_PORT}:80@loadbalancer" --registry-use k3d-registry.localhost:5800 {posargs} && \ - kubectl --context k3d-${CLUSTER_NAME} create namespace syft || true' - + bash -c 'k3d cluster create ${CLUSTER_NAME} \ + -p "${CLUSTER_HTTP_PORT}:80@loadbalancer" \ + --registry-use k3d-registry.localhost:5800 {posargs} && \ + kubectl --context k3d-${CLUSTER_NAME} create namespace syft || true' + # -p "${SIGNOZ_PORT}:3301@loadbalancer" \ ; patch coredns tox -e dev.k8s.patch.coredns From 2df2261c915fae85d8a50186a426582cdc870557 Mon Sep 17 00:00:00 
2001 From: Madhava Jay Date: Mon, 2 Sep 2024 16:39:37 +1000 Subject: [PATCH 18/20] Changing TRACING to default to False - WIP: its not working, my yaml is wrong I think --- packages/grid/backend/grid/start.sh | 4 +++- packages/grid/devspace.yaml | 7 +++++++ .../syft/templates/backend/backend-statefulset.yaml | 6 ++++-- tox.ini | 12 ++++++++++-- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index daee022cc9e..9795f99a1dc 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -31,7 +31,7 @@ fi -if [[ ${TRACING} == "True" ]]; +if [[ "${TRACING,,}" == "true" ]]; then echo "OpenTelemetry Enabled" @@ -45,6 +45,8 @@ then # TODO: uvicorn postfork is not stable with OpenTelemetry # ROOT_PROC="opentelemetry-instrument" +else + echo "OpenTelemetry Disabled" fi export CREDENTIALS_PATH=${CREDENTIALS_PATH:-$HOME/data/creds/credentials.json} diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 05247fcba12..034067d4316 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -124,6 +124,13 @@ profiles: value: side: low + - name: tracing + description: "Enable Tracing" + patches: + - op: replace + path: deployments.syft.helm.values.server.tracing + value: "True" + - name: bigquery-scenario-tests description: "Deploy a datasite for bigquery scenario testing" patches: diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 42e5689f8b6..15a4bafc3ac 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -142,9 +142,11 @@ spec: - name: MIN_SIZE_BLOB_STORAGE_MB value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }} # Tracing - {{- if .Values.server.tracing }} - name: TRACING - value: "True" + {{- if 
.Values.server.tracing }} + value: {{ .Values.server.tracing | quote }} + {{- else }} + value: "False" {{- end }} # Enclave attestation {{- if .Values.attestation.enabled }} diff --git a/tox.ini b/tox.ini index 3b71838f2ee..49eb1f3e47e 100644 --- a/tox.ini +++ b/tox.ini @@ -421,6 +421,7 @@ setenv = SERVER_PORT = {env:SERVER_PORT:8080} TEST_EXTERNAL_REGISTRY = {env:TEST_EXTERNAL_REGISTRY:k3d-registry.localhost:5800} TEST_QUERY_LIMIT_SIZE={env:TEST_QUERY_LIMIT_SIZE:500000} + TRACING={env:TRACING:False} commands = bash -c "python --version || true" bash -c "echo Running with GITHUB_CI=$GITHUB_CI; date" @@ -466,7 +467,7 @@ commands = bash -c "pytest -s -x --nbmake notebooks/scenarios/bigquery -p no:randomly --ignore=notebooks/scenarios/bigquery/sync -vvvv --nbmake-timeout=1000 --log-cli-level=DEBUG --capture=no;" - ; # deleting clusters created + # deleting clusters created bash -c "CLUSTER_NAME=${DATASITE_CLUSTER_NAME} tox -e dev.k8s.destroy || true" bash -c "k3d registry delete k3d-registry.localhost || true" bash -c "docker volume rm k3d-${DATASITE_CLUSTER_NAME}-images --force || true" @@ -990,7 +991,7 @@ commands= bash -c 'helm repo add signoz https://charts.signoz.io && helm repo update' bash -c 'helm install syft signoz/signoz --namespace platform --create-namespace || true' # bash -c 'k3d cluster edit ${CLUSTER_NAME} --port-add "${SIGNOZ_PORT}:3301@loadbalancer"' - # bash packages/grid/scripts/wait_for.sh service syft-signoz-frontend --context k3d-{env:CLUSTER_NAME} --namespace platform + ; bash packages/grid/scripts/wait_for.sh service syft-signoz-frontend --context k3d-{env:CLUSTER_NAME} --namespace platform [testenv:dev.k8s.start] @@ -1027,15 +1028,22 @@ changedir = {toxinidir}/packages/grid passenv = HOME, USER, DEVSPACE_PROFILE setenv= CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} + TRACING = {env:TRACING:False} allowlist_externals = bash commands = ; deploy syft helm charts bash -c 'echo "profile=$DEVSPACE_PROFILE"' + bash -c "echo Running with 
TRACING=$TRACING; date" bash -c '\ if [[ -n "${DEVSPACE_PROFILE}" ]]; then export DEVSPACE_PROFILE="-p ${DEVSPACE_PROFILE}"; fi && \ + if [[ "${TRACING}" == "True" ]]; then DEVSPACE_PROFILE="${DEVSPACE_PROFILE} -p tracing"; fi && \ + if [[ "${TRACING}" == "True" ]]; then echo "TRACING PROFILE ENABLED"; fi && \ devspace deploy -b --kube-context k3d-${CLUSTER_NAME} --no-warn ${DEVSPACE_PROFILE} --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' + # if TRACING is enabled start signoz + ; bash -c 'if [[ "${TRACING}" == "True" ]]; then tox -e dev.k8s.install.signoz; fi' + [testenv:dev.k8s.hotreload] description = Start development with hot-reload in Kubernetes changedir = {toxinidir}/packages/grid From f73ecf4791918d59e0bbd6569504cc6907aad869 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 2 Sep 2024 20:01:52 +1000 Subject: [PATCH 19/20] Changed syntax - Added default value server.tracing=False --- packages/grid/devspace.yaml | 7 ++++--- packages/grid/helm/examples/dev/base.yaml | 1 + .../helm/syft/templates/backend/backend-statefulset.yaml | 6 +----- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 034067d4316..1afb7a7ce86 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -127,9 +127,10 @@ profiles: - name: tracing description: "Enable Tracing" patches: - - op: replace - path: deployments.syft.helm.values.server.tracing - value: "True" + - op: add + path: deployments.syft.helm.values.server + value: + tracing: "True" - name: bigquery-scenario-tests description: "Deploy a datasite for bigquery scenario testing" diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 9b14fbe29ed..7ca60412b57 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -7,6 +7,7 @@ global: server: rootEmail: info@openmined.org associationRequestAutoApproval: true + 
tracing: false resourcesPreset: null resources: null diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 15a4bafc3ac..43b67e5557f 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -143,11 +143,7 @@ spec: value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }} # Tracing - name: TRACING - {{- if .Values.server.tracing }} - value: {{ .Values.server.tracing | quote }} - {{- else }} - value: "False" - {{- end }} + value: {{ .Values.server.tracing | default "False" | quote }} # Enclave attestation {{- if .Values.attestation.enabled }} - name: ENCLAVE_ATTESTATION_ENABLED From cee25b0eca59ca6b826b5231e20d30ddc454b68d Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Mon, 2 Sep 2024 15:41:06 +0530 Subject: [PATCH 20/20] [project] set tracing off by default --- packages/grid/backend/backend.dockerfile | 1 + packages/grid/backend/grid/start.sh | 5 +---- packages/grid/devspace.yaml | 2 +- packages/grid/helm/syft/values.yaml | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index 65e06695310..c51ba31c8fd 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -79,6 +79,7 @@ ENV \ RELEASE="production" \ DEV_MODE="False" \ DEBUGGER_ENABLED="False" \ + TRACING="False" \ CONTAINER_HOST="docker" \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 9795f99a1dc..383e455e408 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -11,7 +11,6 @@ SERVER_TYPE=${SERVER_TYPE:-datasite} APPDIR=${APPDIR:-$HOME/app} RELOAD="" ROOT_PROC="" -TRACING=${TRACING:"False"} echo "Starting 
with TRACING=${TRACING}" @@ -29,9 +28,7 @@ then ROOT_PROC="python -m debugpy --listen 0.0.0.0:5678 -m" fi - - -if [[ "${TRACING,,}" == "true" ]]; +if [[ ${TRACING} == "true" ]]; then echo "OpenTelemetry Enabled" diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 1afb7a7ce86..6b11e475068 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -130,7 +130,7 @@ profiles: - op: add path: deployments.syft.helm.values.server value: - tracing: "True" + tracing: true - name: bigquery-scenario-tests description: "Deploy a datasite for bigquery scenario testing" diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index b7d4bbcb882..6fb3bb8fe87 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -173,7 +173,7 @@ server: debuggerEnabled: false associationRequestAutoApproval: false useInternalRegistry: true - tracing: true + tracing: false # Default Worker pool settings defaultWorkerPool: