diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index 256fdfd447b..c51ba31c8fd 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -41,7 +41,7 @@ COPY syft/src/syft/VERSION ./syft/src/syft/ RUN --mount=type=cache,target=/root/.cache,sharing=locked \ # remove torch because we already have the cpu version pre-installed sed --in-place /torch==/d ./syft/setup.cfg && \ - uv pip install -e ./syft[data_science] + uv pip install -e ./syft[data_science,telemetry] # ==================== [Final] Setup Syft Server ==================== # @@ -79,6 +79,7 @@ ENV \ RELEASE="production" \ DEV_MODE="False" \ DEBUGGER_ENABLED="False" \ + TRACING="False" \ CONTAINER_HOST="docker" \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 0cd6e026e03..63bda939c29 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -159,6 +159,8 @@ def get_emails_enabled(self) -> Self: REVERSE_TUNNEL_ENABLED: bool = str_to_bool( os.getenv("REVERSE_TUNNEL_ENABLED", "false") ) + TRACING_ENABLED: bool = str_to_bool(os.getenv("TRACING", "False")) + model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 9802eece8e8..3f401d7e349 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -10,8 +10,8 @@ from syft.server.server import get_server_side_type from syft.server.server import get_server_type from syft.server.server import get_server_uid_env -from syft.service.queue.zmq_queue import ZMQClientConfig -from syft.service.queue.zmq_queue import ZMQQueueConfig +from syft.service.queue.zmq_client import ZMQClientConfig +from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig from syft.store.mongo_client import MongoStoreClientConfig diff --git a/packages/grid/backend/grid/main.py b/packages/grid/backend/grid/main.py index 497a2dd7a90..12af4165179 100644 --- a/packages/grid/backend/grid/main.py +++ b/packages/grid/backend/grid/main.py @@ -77,3 +77,58 @@ def healthcheck() -> dict[str, str]: probe on the pods backing the Service. """ return {"status": "ok"} + + +if settings.TRACING_ENABLED: + try: + # stdlib + import os + + endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", None) + # third party + from opentelemetry._logs import set_logger_provider + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk._logs import LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + from opentelemetry.sdk.resources import Resource + + logger_provider = LoggerProvider( + resource=Resource.create( + { + "service.name": "backend-container", + } + ), + ) + set_logger_provider(logger_provider) + + exporter = OTLPLogExporter(insecure=True, endpoint=endpoint) + + logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) + handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) + + # Attach OTLP handler to root logger + logging.getLogger().addHandler(handler) + logger = logging.getLogger(__name__) + message = "> Added OTEL BatchLogRecordProcessor" + print(message) + logger.info(message) + + except Exception as e: + print(f"Failed to load OTLPLogExporter. {e}") + + # third party + try: + # third party + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor.instrument_app(app) + message = "> Added OTEL FastAPIInstrumentor" + print(message) + logger = logging.getLogger(__name__) + logger.info(message) + except Exception as e: + error = f"Failed to load FastAPIInstrumentor. {e}" + print(error) + logger = logging.getLogger(__name__) + logger.error(message) diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 02194eb840b..383e455e408 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -1,7 +1,7 @@ #! /usr/bin/env bash set -e -echo "Running Syft with RELEASE=${RELEASE} and $(id)" +echo "Running Syft with RELEASE=${RELEASE}" APP_MODULE=grid.main:app LOG_LEVEL=${LOG_LEVEL:-info} @@ -10,19 +10,40 @@ PORT=${PORT:-80} SERVER_TYPE=${SERVER_TYPE:-datasite} APPDIR=${APPDIR:-$HOME/app} RELOAD="" -DEBUG_CMD="" +ROOT_PROC="" + +echo "Starting with TRACING=${TRACING}" if [[ ${DEV_MODE} == "True" ]]; then - echo "DEV_MODE Enabled" + echo "Hot-reload Enabled" RELOAD="--reload" fi # only set by kubernetes to avoid conflict with docker tests if [[ ${DEBUGGER_ENABLED} == "True" ]]; then + echo "Debugger Enabled" uv pip install debugpy - DEBUG_CMD="python -m debugpy --listen 0.0.0.0:5678 -m" + ROOT_PROC="python -m debugpy --listen 0.0.0.0:5678 -m" +fi + +if [[ ${TRACING} == "true" ]]; +then + echo "OpenTelemetry Enabled" + + # TODOs: + # ! Handle case when OTEL_EXPORTER_OTLP_ENDPOINT is not set. + # ! syft-signoz-otel-collector.platform:4317 should be plumbed through helm charts + # ? Kubernetes OTel operator is recommended by signoz + export OTEL_PYTHON_LOG_CORRELATION=${OTEL_PYTHON_LOG_CORRELATION:-true} + export OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_EXPORTER_OTLP_ENDPOINT:-"http://syft-signoz-otel-collector.platform:4317"} + export OTEL_EXPORTER_OTLP_PROTOCOL=${OTEL_EXPORTER_OTLP_PROTOCOL:-grpc} + + # TODO: uvicorn postfork is not stable with OpenTelemetry + # ROOT_PROC="opentelemetry-instrument" +else + echo "OpenTelemetry Disabled" fi export CREDENTIALS_PATH=${CREDENTIALS_PATH:-$HOME/data/creds/credentials.json} @@ -33,4 +54,4 @@ export SERVER_TYPE=$SERVER_TYPE echo "SERVER_UID=$SERVER_UID" echo "SERVER_TYPE=$SERVER_TYPE" -exec $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-config=$APPDIR/grid/logging.yaml --log-level $LOG_LEVEL "$APP_MODULE" +exec $ROOT_PROC uvicorn $RELOAD --host $HOST --port $PORT --log-config=$APPDIR/grid/logging.yaml "$APP_MODULE" diff --git a/packages/grid/default.env b/packages/grid/default.env index 3018a4c2ce2..e1bc5c42557 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -95,11 +95,6 @@ REDIS_HOST=redis CONTAINER_HOST=docker RELATIVE_PATH="" -# Jaeger -TRACE=False -JAEGER_HOST=localhost -JAEGER_PORT=14268 - # Syft SYFT_TUTORIAL_MODE=False ENABLE_WARNINGS=True diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 05247fcba12..6b11e475068 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -124,6 +124,14 @@ profiles: value: side: low + - name: tracing + description: "Enable Tracing" + patches: + - op: add + path: deployments.syft.helm.values.server + value: + tracing: true + - name: bigquery-scenario-tests description: "Deploy a datasite for bigquery scenario testing" patches: diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 9b14fbe29ed..7ca60412b57 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -7,6 +7,7 @@ global: server: rootEmail: info@openmined.org associationRequestAutoApproval: true + tracing: false resourcesPreset: null resources: null diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 3dcefcd0f6b..43b67e5557f 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -87,7 +87,7 @@ spec: {{- end }} {{- if .Values.server.debuggerEnabled }} - name: DEBUGGER_ENABLED - value: "true" + value: "True" {{- end }} {{- if eq .Values.server.type "gateway" }} - name: ASSOCIATION_REQUEST_AUTO_APPROVAL @@ -142,14 +142,8 @@ spec: - name: MIN_SIZE_BLOB_STORAGE_MB value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }} # Tracing - - name: TRACE - value: "false" - - name: SERVICE_NAME - value: "backend" - - name: JAEGER_HOST - value: "localhost" - - name: JAEGER_PORT - value: "14268" + - name: TRACING + value: {{ .Values.server.tracing | default "False" | quote }} # Enclave attestation {{- if .Values.attestation.enabled }} - name: ENCLAVE_ATTESTATION_ENABLED diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index c5b4a2ff5f5..6fb3bb8fe87 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -10,7 +10,6 @@ global: kaniko: version: "v1.23.2" - # ================================================================================= mongo: @@ -174,6 +173,7 @@ server: debuggerEnabled: false associationRequestAutoApproval: false useInternalRegistry: true + tracing: false # Default Worker pool settings defaultWorkerPool: @@ -185,7 +185,7 @@ server: smtp: # Existing secret for SMTP with key 'smtpPassword' existingSecret: null - host: smtp.sendgrid.net + host: hostname port: 587 from: noreply@openmined.org username: apikey diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index bb3739acd1c..6de365022e5 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -104,12 +104,20 @@ dev = aiosmtpd==1.4.6 telemetry = - opentelemetry-api==1.14.0 - opentelemetry-sdk==1.14.0 - opentelemetry-exporter-jaeger==1.14.0 - opentelemetry-instrumentation==0.35b0 - opentelemetry-instrumentation-requests==0.35b0 - ; opentelemetry-instrumentation-digma==0.9.1-beta.7 + opentelemetry-api==1.27.0 + opentelemetry-sdk==1.27.0 + opentelemetry-exporter-otlp==1.27.0 + opentelemetry-instrumentation==0.48b0 + opentelemetry-instrumentation-requests==0.48b0 + opentelemetry-instrumentation-fastapi==0.48b0 + opentelemetry-instrumentation-pymongo==0.48b0 + opentelemetry-instrumentation-botocore==0.48b0 + opentelemetry-instrumentation-logging==0.48b0 + ; opentelemetry-instrumentation-asyncio==0.48b0 + ; opentelemetry-instrumentation-sqlite3==0.48b0 + ; opentelemetry-instrumentation-threading==0.48b0 + ; opentelemetry-instrumentation-jinja2==0.48b0 + ; opentelemetry-instrumentation-system-metrics==0.48b0 # pytest>=8.0 broke pytest-lazy-fixture which doesn't seem to be actively maintained # temporarily pin to pytest<8 diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 12bf44b8d9b..a5aa64ca736 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -59,7 +59,6 @@ from ..util.autoreload import autoreload_enabled from ..util.markdown import as_markdown_python_code from ..util.notebook_ui.components.tabulator_template import build_tabulator_table -from ..util.telemetry import instrument from ..util.util import index_syft_by_module_name from ..util.util import prompt_warning_message from .connection import ServerConnection @@ -237,7 +236,6 @@ def is_valid(self) -> bool: return True -@instrument @serializable() class SyftAPICall(SyftObject): # version @@ -264,7 +262,6 @@ def __repr__(self) -> str: return f"SyftAPICall(path={self.path}, args={self.args}, kwargs={self.kwargs}, blocking={self.blocking})" -@instrument @serializable() class SyftAPIData(SyftBaseObject): # version @@ -874,7 +871,6 @@ def result_needs_api_update(api_call_result: Any) -> bool: return False -@instrument @serializable( attrs=[ "endpoints", diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 30586c4b796..093dee625db 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -55,7 +55,6 @@ from ..types.server_url import ServerURL from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.uid import UID -from ..util.telemetry import instrument from ..util.util import prompt_warning_message from ..util.util import thread_ident from ..util.util import verify_tls @@ -648,7 +647,6 @@ def get_client_type(self) -> type[SyftClient]: raise SyftException(message=f"Unknown server type {metadata.server_type}") -@instrument @serializable(canonical_name="SyftClient", version=1) class SyftClient: connection: ServerConnection @@ -1117,7 +1115,6 @@ def refresh_callback() -> SyftAPI: return _api -@instrument def connect( url: str | ServerURL = DEFAULT_SYFT_UI_ADDRESS, server: AbstractServer | None = None, @@ -1135,7 +1132,6 @@ def connect( return client_type(connection=connection) -@instrument def register( url: str | ServerURL, port: int, @@ -1155,7 +1151,6 @@ def register( ) -@instrument def login_as_guest( # HTTPConnection url: str | ServerURL = DEFAULT_SYFT_UI_ADDRESS, @@ -1179,7 +1174,6 @@ def login_as_guest( return _client.guest() -@instrument def login( email: str, # HTTPConnection diff --git a/packages/syft/src/syft/server/routes.py b/packages/syft/src/syft/server/routes.py index 6e3d138bf59..ec5bedd6d01 100644 --- a/packages/syft/src/syft/server/routes.py +++ b/packages/syft/src/syft/server/routes.py @@ -32,7 +32,6 @@ from ..service.user.user_service import UserService from ..types.errors import SyftException from ..types.uid import UID -from ..util.telemetry import TRACE_MODE from .credentials import SyftVerifyKey from .credentials import UserLoginCredentials from .worker import Worker @@ -41,15 +40,6 @@ def make_routes(worker: Worker) -> APIRouter: - if TRACE_MODE: - # third party - try: - # third party - from opentelemetry import trace - from opentelemetry.propagate import extract - except Exception as e: - logger.error("Failed to import opentelemetry", exc_info=e) - router = APIRouter() async def get_body(request: Request) -> bytes: @@ -159,15 +149,7 @@ def syft_new_api( request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE ) -> Response: user_verify_key: SyftVerifyKey = SyftVerifyKey.from_string(verify_key) - if TRACE_MODE: - with trace.get_tracer(syft_new_api.__module__).start_as_current_span( - syft_new_api.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_syft_new_api(user_verify_key, communication_protocol) - else: - return handle_syft_new_api(user_verify_key, communication_protocol) + return handle_syft_new_api(user_verify_key, communication_protocol) def handle_new_api_call(data: bytes) -> Response: obj_msg = deserialize(blob=data, from_bytes=True) @@ -182,15 +164,7 @@ def handle_new_api_call(data: bytes) -> Response: def syft_new_api_call( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: - if TRACE_MODE: - with trace.get_tracer(syft_new_api_call.__module__).start_as_current_span( - syft_new_api_call.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_new_api_call(data) - else: - return handle_new_api_call(data) + return handle_new_api_call(data) def handle_forgot_password(email: str, server: AbstractServer) -> Response: try: @@ -278,15 +252,7 @@ def login( email: Annotated[str, Body(example="info@openmined.org")], password: Annotated[str, Body(example="changethis")], ) -> Response: - if TRACE_MODE: - with trace.get_tracer(login.__module__).start_as_current_span( - login.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_login(email, password, worker) - else: - return handle_login(email, password, worker) + return handle_login(email, password, worker) @router.post("/reset_password", name="reset_password", status_code=200) def reset_password( @@ -294,42 +260,18 @@ def reset_password( token: Annotated[str, Body(...)], new_password: Annotated[str, Body(...)], ) -> Response: - if TRACE_MODE: - with trace.get_tracer(reset_password.__module__).start_as_current_span( - reset_password.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_reset_password(token, new_password, worker) - else: - return handle_reset_password(token, new_password, worker) + return handle_reset_password(token, new_password, worker) @router.post("/forgot_password", name="forgot_password", status_code=200) def forgot_password( request: Request, email: str = Body(..., embed=True) ) -> Response: - if TRACE_MODE: - with trace.get_tracer(forgot_password.__module__).start_as_current_span( - forgot_password.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_forgot_password(email, worker) - else: - return handle_forgot_password(email, worker) + return handle_forgot_password(email, worker) @router.post("/register", name="register", status_code=200) def register( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: - if TRACE_MODE: - with trace.get_tracer(register.__module__).start_as_current_span( - register.__qualname__, - context=extract(request.headers), - kind=trace.SpanKind.SERVER, - ): - return handle_register(data, worker) - else: - return handle_register(data, worker) + return handle_register(data, worker) return router diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 5c4ab1bea31..2815bcdd318 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -65,9 +65,9 @@ from ..service.queue.queue_stash import ActionQueueItem from ..service.queue.queue_stash import QueueItem from ..service.queue.queue_stash import QueueStash -from ..service.queue.zmq_queue import QueueConfig -from ..service.queue.zmq_queue import ZMQClientConfig -from ..service.queue.zmq_queue import ZMQQueueConfig +from ..service.queue.zmq_client import QueueConfig +from ..service.queue.zmq_client import ZMQClientConfig +from ..service.queue.zmq_client import ZMQQueueConfig from ..service.response import SyftError from ..service.response import SyftSuccess from ..service.service import AbstractService @@ -296,7 +296,6 @@ def auth_context_for_user( return cls.__server_context_registry__.get(key) -@instrument class Server(AbstractServer): signing_key: SyftSigningKey | None required_signed_calls: bool = True @@ -1050,6 +1049,7 @@ def await_future(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: return result sleep(0.1) + @instrument def resolve_future(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: queue_obj = self.queue_stash.pop_on_complete(credentials, uid).unwrap() queue_obj._set_obj_location_( @@ -1058,6 +1058,7 @@ def resolve_future(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: ) return queue_obj + @instrument def forward_message( self, api_call: SyftAPICall | SignedSyftAPICall ) -> Result | QueueItem | SyftObject | Any: @@ -1123,6 +1124,7 @@ def get_role_for_credentials(self, credentials: SyftVerifyKey) -> ServiceRole: .unwrap() ) + @instrument def handle_api_call( self, api_call: SyftAPICall | SignedSyftAPICall, @@ -1330,6 +1332,7 @@ def get_worker_pool_ref_by_name( ) return worker_pool_ref + @instrument @as_result(SyftException) def add_action_to_queue( self, @@ -1379,6 +1382,7 @@ def add_action_to_queue( user_id=user_id, ).unwrap() + @instrument @as_result(SyftException) def add_queueitem_to_queue( self, @@ -1480,6 +1484,7 @@ def _is_usercode_call_on_owned_kwargs( context, user_code_id, api_call.kwargs ) + @instrument def add_api_call_to_queue( self, api_call: SyftAPICall, parent_job_id: UID | None = None ) -> SyftSuccess: @@ -1597,6 +1602,7 @@ def get_worker_pool_by_name(self, name: str) -> WorkerPool: credentials=self.verify_key, pool_name=name ).unwrap() + @instrument def get_api( self, for_user: SyftVerifyKey | None = None, diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 3549d7a5987..0be6fb5ea9a 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -27,6 +27,7 @@ from ..client.client import API_PATH from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT +from ..util.telemetry import TRACING_ENABLED from ..util.util import os_name from .datasite import Datasite from .enclave import Enclave @@ -112,6 +113,14 @@ def app_factory() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) + + if TRACING_ENABLED: + # third party + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + FastAPIInstrumentor().instrument_app(app) + print("> Added OTEL FastAPIInstrumentor") + return app diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index cc9b34d940b..2f530b207bd 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -16,7 +16,6 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_service import ActionService from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -37,7 +36,6 @@ from .api_stash import TwinAPIEndpointStash -@instrument @serializable(canonical_name="APIService", version=1) class APIService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 1936dbe9047..1ffb70ebb6f 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -13,7 +13,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftSuccess from ..service import AbstractService @@ -24,7 +23,6 @@ from .user_code import UserCodeStatusCollection -@instrument @serializable(canonical_name="StatusStash", version=1) class StatusStash(NewBaseUIDStoreStash): object_type = UserCodeStatusCollection @@ -47,7 +45,6 @@ def get_by_uid( return self.query_one(credentials=credentials, qks=qks).unwrap() -@instrument @serializable(canonical_name="UserCodeStatusService", version=1) class UserCodeStatusService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index cdf26d5aed2..696f961932c 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -15,7 +15,6 @@ from ...types.syft_metaclass import Empty from ...types.twin_object import TwinObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission @@ -60,7 +59,6 @@ class IsExecutionAllowedEnum(str, Enum): OUTPUT_POLICY_NOT_APPROVED = "Execution denied: Output policy not approved" -@instrument @serializable(canonical_name="UserCodeService", version=1) class UserCodeService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 8fbd4206235..308de4d28bf 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -1,5 +1,3 @@ -# stdlib - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey @@ -10,7 +8,6 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...util.telemetry import instrument from .user_code import CodeHashPartitionKey from .user_code import ServiceFuncNamePartitionKey from .user_code import SubmitTimePartitionKey @@ -18,7 +15,6 @@ from .user_code import UserVerifyKeyPartitionKey -@instrument @serializable(canonical_name="UserCodeStash", version=1) class UserCodeStash(NewBaseUIDStoreStash): object_type = UserCode diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 9e328bc6124..5383e8c9dcb 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -6,7 +6,6 @@ from ...store.document_store import DocumentStore from ...store.document_store_errors import NotFoundException from ...types.uid import UID -from ...util.telemetry import instrument from ..code.user_code import SubmitUserCode from ..code.user_code import UserCode from ..code.user_code_service import UserCodeService @@ -24,7 +23,6 @@ from .code_history_stash import CodeHistoryStash -@instrument @serializable(canonical_name="CodeHistoryService", version=1) class CodeHistoryService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index e52435e994c..c54e8bc67ea 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -9,7 +9,6 @@ from ...store.document_store import QueryKeys from ...store.document_store_errors import StashException from ...types.result import as_result -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftSuccess from ..service import AbstractService @@ -20,7 +19,6 @@ from .data_subject_member import ParentPartitionKey -@instrument @serializable(canonical_name="DataSubjectMemberStash", version=1) class DataSubjectMemberStash(NewBaseUIDStoreStash): object_type = DataSubjectMemberRelationship @@ -47,7 +45,6 @@ def get_all_for_child( return self.query_all(credentials=credentials, qks=qks).unwrap() -@instrument @serializable(canonical_name="DataSubjectMemberService", version=1) class DataSubjectMemberService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index fd248792842..b386fd1ec8d 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -11,7 +11,6 @@ from ...store.document_store import QueryKeys from ...store.document_store_errors import StashException from ...types.result import as_result -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftSuccess from ..service import AbstractService @@ -24,7 +23,6 @@ from .data_subject_member_service import DataSubjectMemberService -@instrument @serializable(canonical_name="DataSubjectStash", version=1) class DataSubjectStash(NewBaseUIDStoreStash): object_type = DataSubject @@ -52,7 +50,6 @@ def update( return super().update(credentials=credentials, obj=res).unwrap() -@instrument @serializable(canonical_name="DataSubjectService", version=1) class DataSubjectService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index d57be61905c..cd347c11b35 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -9,7 +9,6 @@ from ...store.document_store import DocumentStore from ...types.dicttuple import DictTuple from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from ..action.action_service import ActionService @@ -70,7 +69,6 @@ def _paginate_dataset_collection( ) -@instrument @serializable(canonical_name="DatasetService", version=1) class DatasetService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 4ad1c506a98..19fc33c5906 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -10,7 +10,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from .dataset import Dataset from .dataset import DatasetUpdate @@ -18,7 +17,6 @@ ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable(canonical_name="DatasetStash", version=1) class DatasetStash(NewBaseUIDStoreStash): object_type = Dataset diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 78179ce7c90..0ad153e4bda 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -10,7 +10,6 @@ from ...store.document_store import DocumentStore from ...types.errors import SyftException from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission @@ -41,7 +40,6 @@ def wait_until(predicate: Callable[[], bool], timeout: int = 10) -> SyftSuccess: raise SyftException(public_message=f"Timeout reached for predicate {code_string}") -@instrument @serializable(canonical_name="JobService", version=1) class JobService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 6e75b4bc62f..413d36ba753 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -42,7 +42,6 @@ from ...types.transforms import make_set_default from ...types.uid import UID from ...util.markdown import as_markdown_code -from ...util.telemetry import instrument from ...util.util import prompt_warning_message from ..action.action_object import Action from ..action.action_object import ActionObject @@ -742,7 +741,6 @@ def from_job( return info -@instrument @serializable(canonical_name="JobStash", version=1) class JobStash(NewBaseUIDStoreStash): object_type = Job diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index 7aeb0eebe18..d3529b0906f 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -2,7 +2,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -15,7 +14,6 @@ from .log_stash import LogStash -@instrument @serializable(canonical_name="LogService", version=1) class LogService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index b3acc6cee11..c4072bfcfa5 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -3,11 +3,9 @@ from ...store.document_store import DocumentStore from ...store.document_store import NewBaseUIDStoreStash from ...store.document_store import PartitionSettings -from ...util.telemetry import instrument from .log import SyftLog -@instrument @serializable(canonical_name="LogStash", version=1) class LogStash(NewBaseUIDStoreStash): object_type = SyftLog diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index ccdb0c0a8ec..4e4e84d1364 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -3,7 +3,6 @@ # relative from ...serde.serializable import serializable from ...store.document_store import DocumentStore -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method @@ -11,7 +10,6 @@ from .server_metadata import ServerMetadata -@instrument @serializable(canonical_name="MetadataService", version=1) class MetadataService(AbstractService): def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4c04e87448d..5e4dfdbadf4 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -31,7 +31,6 @@ from ...types.transforms import transform from ...types.transforms import transform_method from ...types.uid import UID -from ...util.telemetry import instrument from ...util.util import generate_token from ...util.util import get_env from ...util.util import prompt_warning_message @@ -81,7 +80,6 @@ class ServerPeerAssociationStatus(Enum): PEER_NOT_FOUND = "PEER_NOT_FOUND" -@instrument @serializable(canonical_name="NetworkStash", version=1) class NetworkStash(NewBaseUIDStoreStash): object_type = ServerPeer @@ -166,7 +164,6 @@ def get_by_server_type( ).unwrap() -@instrument @serializable(canonical_name="NetworkService", version=1) class NetworkService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 410b764b8d4..89e15738376 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -8,7 +8,6 @@ from ...types.result import OkErr from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext from ..notifier.notifier import NotifierSettings @@ -28,7 +27,6 @@ from .notifications import ReplyNotification -@instrument @serializable(canonical_name="NotificationService", version=1) class NotificationService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index 3f1b603cd6b..fd41ad0dda6 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -15,7 +15,6 @@ from ...types.datetime import DateTime from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from .notifications import Notification from .notifications import NotificationStatus @@ -32,7 +31,6 @@ LinkedObjectPartitionKey = PartitionKey(key="linked_obj", type_=LinkedObject) -@instrument @serializable(canonical_name="NotificationStash", version=1) class NotificationStash(NewBaseUIDStoreStash): object_type = Notification diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 1084932a801..cb000d823f5 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -130,7 +130,9 @@ def send( self.smtp_client.send( # type: ignore sender=sender, receiver=receiver_email, subject=subject, body=body ) - print(f"> Sent email: {subject} to {receiver_email} from: {sender}") + message = f"> Sent email: {subject} to {receiver_email}" + print(message) + logging.info(message) return SyftSuccess(message="Email sent successfully!") except Exception as e: message = f"> Error sending email: {subject} to {receiver_email} from: {sender}. {e}" diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 6a42be39bdc..8dbe8e31e8f 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -13,7 +13,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from .notifier import NotifierSettings @@ -21,7 +20,6 @@ ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable(canonical_name="NotifierStash", version=1) class NotifierStash(NewBaseStash): object_type = NotifierSettings diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py index a901623c904..eef25440af8 100644 --- a/packages/syft/src/syft/service/notifier/smtp_client.py +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -50,7 +50,7 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non except Exception as e: logger.error(f"Unable to send email. {e}") raise SyftException( - public_message="Ops! Something went wrong while trying to send an email." + public_message="Oops! Something went wrong while trying to send an email." ) @classmethod @@ -77,4 +77,7 @@ def check_credentials( smtp_server.login(username, password) return True except Exception as e: + message = f"SMTP check_credentials failed. {e}" + print(message) + logger.error(message) raise SyftException(public_message=str(e)) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 16f97a5e0e5..e62a4baafc7 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -21,7 +21,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext @@ -184,7 +183,6 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return res -@instrument @serializable(canonical_name="OutputStash", version=1) class OutputStash(NewBaseUIDStoreStash): object_type = ExecutionOutput @@ -237,7 +235,6 @@ def get_by_output_policy_id( ).unwrap() -@instrument @serializable(canonical_name="OutputService", version=1) class OutputService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 6bad114070f..30af109ee09 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -12,7 +12,6 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..network.network_service import NetworkService from ..notification.notification_service import NotificationService @@ -35,7 +34,6 @@ from .project_stash import ProjectStash -@instrument @serializable(canonical_name="ProjectService", version=1) class ProjectService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index 5eddaf58b4d..bf81bd5b9b1 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -14,7 +14,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..request.request import Request from .project import Project @@ -23,7 +22,6 @@ NamePartitionKey = PartitionKey(key="name", type_=str) -@instrument @serializable(canonical_name="ProjectStash", version=1) class ProjectStash(NewBaseUIDStoreStash): object_type = Project diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index 8504484e7d5..c898893ee35 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -4,7 +4,6 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method @@ -13,7 +12,6 @@ from .queue_stash import QueueStash -@instrument @serializable(canonical_name="QueueService", version=1) class QueueService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 7b2dc3a48fb..251f4a9fb63 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -20,7 +20,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission @@ -94,7 +93,6 @@ class APIEndpointQueueItem(QueueItem): service: str = "apiservice" -@instrument @serializable(canonical_name="QueueStash", version=1) class QueueStash(NewBaseStash): object_type = QueueItem diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py new file mode 100644 index 00000000000..deeeb97a32b --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -0,0 +1,190 @@ +# stdlib +from collections import defaultdict +import socketserver + +# relative +from ...serde.serializable import serializable +from ...service.context import AuthedServiceContext +from ...types.errors import SyftException +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.util import get_queue_address +from ..response import SyftSuccess +from ..worker.worker_stash import WorkerStash +from .base_queue import AbstractMessageHandler +from .base_queue import QueueClient +from .base_queue import QueueClientConfig +from .base_queue import QueueConfig +from .queue_stash import QueueStash +from .zmq_consumer import ZMQConsumer +from .zmq_producer import ZMQProducer + + +@serializable() +class ZMQClientConfig(SyftObject, QueueClientConfig): + __canonical_name__ = "ZMQClientConfig" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID | None = None # type: ignore[assignment] + hostname: str = "127.0.0.1" + queue_port: int | None = None + # TODO: setting this to false until we can fix the ZMQ + # port issue causing tests to randomly fail + create_producer: bool = False + n_consumers: int = 0 + consumer_service: str | None = None + + +@serializable(attrs=["host"], canonical_name="ZMQClient", version=1) +class ZMQClient(QueueClient): + """ZMQ Client for creating producers and consumers.""" + + producers: dict[str, ZMQProducer] + consumers: defaultdict[str, list[ZMQConsumer]] + + def __init__(self, config: ZMQClientConfig) -> None: + self.host = config.hostname + self.producers = {} + self.consumers = defaultdict(list) + self.config = config + + @staticmethod + def _get_free_tcp_port(host: str) -> int: + with socketserver.TCPServer((host, 0), None) as s: + free_port = s.server_address[1] + + return free_port + + def add_producer( + self, + queue_name: str, + port: int | None = None, + queue_stash: QueueStash | None = None, + worker_stash: WorkerStash | None = None, + context: AuthedServiceContext | None = None, + ) -> ZMQProducer: + """Add a producer of a queue. + + A queue can have at most one producer attached to it. + """ + + if port is None: + if self.config.queue_port is None: + self.config.queue_port = self._get_free_tcp_port(self.host) + port = self.config.queue_port + else: + port = self.config.queue_port + + producer = ZMQProducer( + queue_name=queue_name, + queue_stash=queue_stash, + port=port, + context=context, + worker_stash=worker_stash, + ) + self.producers[queue_name] = producer + return producer + + def add_consumer( + self, + queue_name: str, + message_handler: AbstractMessageHandler, + service_name: str, + address: str | None = None, + worker_stash: WorkerStash | None = None, + syft_worker_id: UID | None = None, + ) -> ZMQConsumer: + """Add a consumer to a queue + + A queue should have at least one producer attached to the group. + + """ + + if address is None: + address = get_queue_address(self.config.queue_port) + + consumer = ZMQConsumer( + queue_name=queue_name, + message_handler=message_handler, + address=address, + service_name=service_name, + syft_worker_id=syft_worker_id, + worker_stash=worker_stash, + ) + self.consumers[queue_name].append(consumer) + + return consumer + + def send_message( + self, + message: bytes, + queue_name: str, + worker: bytes | None = None, + ) -> SyftSuccess: + producer = self.producers.get(queue_name) + if producer is None: + raise SyftException( + public_message=f"No producer attached for queue: {queue_name}. Please add a producer for it." + ) + try: + producer.send(message=message, worker=worker) + except Exception as e: + # stdlib + raise SyftException( + public_message=f"Failed to send message to: {queue_name} with error: {e}" + ) + return SyftSuccess( + message=f"Successfully queued message to : {queue_name}", + ) + + def close(self) -> SyftSuccess: + try: + for consumers in self.consumers.values(): + for consumer in consumers: + # make sure look is stopped + consumer.close() + + for producer in self.producers.values(): + # make sure loop is stopped + producer.close() + # close existing connection. + except Exception as e: + raise SyftException(public_message=f"Failed to close connection: {e}") + + return SyftSuccess(message="All connections closed.") + + def purge_queue(self, queue_name: str) -> SyftSuccess: + if queue_name not in self.producers: + raise SyftException( + public_message=f"No producer running for : {queue_name}" + ) + + producer = self.producers[queue_name] + + # close existing connection. + producer.close() + + # add a new connection + self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore + + return SyftSuccess(message=f"Queue: {queue_name} successfully purged") + + def purge_all(self) -> SyftSuccess: + for queue_name in self.producers: + self.purge_queue(queue_name=queue_name) + + return SyftSuccess(message="Successfully purged all queues.") + + +@serializable(canonical_name="ZMQQueueConfig", version=1) +class ZMQQueueConfig(QueueConfig): + def __init__( + self, + client_type: type[ZMQClient] | None = None, + client_config: ZMQClientConfig | None = None, + thread_workers: bool = False, + ): + self.client_type = client_type or ZMQClient + self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() + self.thread_workers = thread_workers diff --git a/packages/syft/src/syft/service/queue/zmq_common.py b/packages/syft/src/syft/service/queue/zmq_common.py new file mode 100644 index 00000000000..35331541c86 --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_common.py @@ -0,0 +1,120 @@ +# stdlib +import threading +import time +from typing import Any + +# third party +from pydantic import field_validator + +# relative +from ...server.credentials import SyftVerifyKey +from ...types.base import SyftBaseModel +from ...types.errors import SyftException +from ...types.result import as_result +from ...types.uid import UID +from ..worker.worker_pool import SyftWorker +from ..worker.worker_stash import WorkerStash + +# Producer/Consumer heartbeat interval (in seconds) +HEARTBEAT_INTERVAL_SEC = 2 + +# Thread join timeout (in seconds) +THREAD_TIMEOUT_SEC = 5 + +# Max duration (in ms) to wait for ZMQ poller to return +ZMQ_POLLER_TIMEOUT_MSEC = 1000 + +# Duration (in seconds) after which a worker without a heartbeat will be marked as expired +WORKER_TIMEOUT_SEC = 60 + +# Duration (in seconds) after which producer without a heartbeat will be marked as expired +PRODUCER_TIMEOUT_SEC = 60 + +# Lock for working on ZMQ socket +ZMQ_SOCKET_LOCK = threading.Lock() + +MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 + + +class ZMQHeader: + """Enum for ZMQ headers""" + + W_WORKER = b"MDPW01" + + +class ZMQCommand: + """Enum for ZMQ commands""" + + W_READY = b"0x01" + W_REQUEST = b"0x02" + W_REPLY = b"0x03" + W_HEARTBEAT = b"0x04" + W_DISCONNECT = b"0x05" + + +class Timeout: + def __init__(self, offset_sec: float): + self.__offset = float(offset_sec) + self.__next_ts: float = 0.0 + + self.reset() + + @property + def next_ts(self) -> float: + return self.__next_ts + + def reset(self) -> None: + self.__next_ts = self.now() + self.__offset + + def has_expired(self) -> bool: + return self.now() >= self.__next_ts + + @staticmethod + def now() -> float: + return time.time() + + +class Service: + def __init__(self, name: str) -> None: + self.name = name + self.requests: list[bytes] = [] + self.waiting: list[Worker] = [] # List of waiting workers + + +class Worker(SyftBaseModel): + address: bytes + identity: bytes + service: Service | None = None + syft_worker_id: UID | None = None + expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) + + @field_validator("syft_worker_id", mode="before") + @classmethod + def set_syft_worker_id(cls, v: Any) -> Any: + if isinstance(v, str): + return UID(v) + return v + + def has_expired(self) -> bool: + return self.expiry_t.has_expired() + + def get_expiry(self) -> float: + return self.expiry_t.next_ts + + def reset_expiry(self) -> None: + self.expiry_t.reset() + + @as_result(SyftException) + def _syft_worker( + self, stash: WorkerStash, credentials: SyftVerifyKey + ) -> SyftWorker | None: + return stash.get_by_uid( + credentials=credentials, uid=self.syft_worker_id + ).unwrap() + + def __str__(self) -> str: + svc = self.service.name if self.service else None + return ( + f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " + f"syft_worker_id={self.syft_worker_id!r})" + ) diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py new file mode 100644 index 00000000000..1d327b814de --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -0,0 +1,251 @@ +# stdlib +import logging +import threading +from threading import Event + +# third party +import zmq +from zmq import Frame +from zmq.error import ContextTerminated + +# relative +from ...serde.deserialize import _deserialize +from ...serde.serializable import serializable +from ...types.uid import UID +from ..worker.worker_pool import ConsumerState +from ..worker.worker_stash import WorkerStash +from .base_queue import AbstractMessageHandler +from .base_queue import QueueConsumer +from .zmq_common import HEARTBEAT_INTERVAL_SEC +from .zmq_common import PRODUCER_TIMEOUT_SEC +from .zmq_common import THREAD_TIMEOUT_SEC +from .zmq_common import Timeout +from .zmq_common import ZMQCommand +from .zmq_common import ZMQHeader +from .zmq_common import ZMQ_POLLER_TIMEOUT_MSEC +from .zmq_common import ZMQ_SOCKET_LOCK + +logger = logging.getLogger(__name__) + + +@serializable(attrs=["_subscriber"], canonical_name="ZMQConsumer", version=1) +class ZMQConsumer(QueueConsumer): + def __init__( + self, + message_handler: AbstractMessageHandler, + address: str, + queue_name: str, + service_name: str, + syft_worker_id: UID | None = None, + worker_stash: WorkerStash | None = None, + verbose: bool = False, + ) -> None: + self.address = address + self.message_handler = message_handler + self.service_name = service_name + self.queue_name = queue_name + self.context = zmq.Context() + self.poller = zmq.Poller() + self.socket = None + self.verbose = verbose + self.id = UID().short() + self._stop = Event() + self.syft_worker_id = syft_worker_id + self.worker_stash = worker_stash + self.post_init() + + def reconnect_to_producer(self) -> None: + """Connect or reconnect to producer""" + if self.socket: + self.poller.unregister(self.socket) # type: ignore[unreachable] + self.socket.close() + self.socket = self.context.socket(zmq.DEALER) + self.socket.linger = 0 + self.socket.setsockopt_string(zmq.IDENTITY, self.id) + self.socket.connect(self.address) + self.poller.register(self.socket, zmq.POLLIN) + + logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") + + # Register queue with the producer + self.send_to_producer( + ZMQCommand.W_READY, + [self.service_name.encode(), str(self.syft_worker_id).encode()], + ) + + def post_init(self) -> None: + self.thread: threading.Thread | None = None + self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) + self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) + self.reconnect_to_producer() + + def disconnect_from_producer(self) -> None: + self.send_to_producer(ZMQCommand.W_DISCONNECT) + + def close(self) -> None: + self.disconnect_from_producer() + self._stop.set() + try: + if self.thread is not None: + self.thread.join(timeout=THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQConsumer thread join timed out during closing. " + f"SyftWorker id {self.syft_worker_id}, " + f"service name {self.service_name}." + ) + self.thread = None + self.poller.unregister(self.socket) + except Exception as e: + logger.error("Failed to unregister worker.", exc_info=e) + finally: + self.socket.close() + self.context.destroy() + # self._stop.clear() + + def send_to_producer( + self, + command: bytes, + msg: bytes | list | None = None, + ) -> None: + """Send message to producer. + + If no msg is provided, creates one internally + """ + if self.socket.closed: + logger.warning("Socket is closed. Cannot send message.") + return + + if msg is None: + msg = [] + elif not isinstance(msg, list): + msg = [msg] + + # ZMQConsumer send frames: [empty, header, command, ...data] + core = [b"", ZMQHeader.W_WORKER, command] + msg = core + msg + + if command != ZMQCommand.W_HEARTBEAT: + logger.info(f"ZMQ Consumer send: {core}") + + with ZMQ_SOCKET_LOCK: + try: + self.socket.send_multipart(msg) + except zmq.ZMQError as e: + logger.error("ZMQConsumer send error", exc_info=e) + + def _run(self) -> None: + """Send reply, if any, to producer and wait for next request.""" + try: + while True: + if self._stop.is_set(): + logger.info("ZMQConsumer thread stopped") + return + + try: + items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except ContextTerminated: + logger.info("Context terminated") + return + except Exception as e: + logger.error("ZMQ poll error", exc_info=e) + continue + + if items: + msg = self.socket.recv_multipart() + + # mark as alive + self.set_producer_alive() + + if len(msg) < 3: + logger.error(f"ZMQConsumer invalid recv: {msg}") + continue + + # Message frames recieved by consumer: + # [empty, header, command, ...data] + (_, _, command, *data) = msg + + if command != ZMQCommand.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQConsumer recv: {msg[:-4]}") + + if command == ZMQCommand.W_REQUEST: + # Call Message Handler + try: + message = data.pop() + self.associate_job(message) + self.message_handler.handle_message( + message=message, + syft_worker_id=self.syft_worker_id, + ) + except Exception as e: + logger.exception("Couldn't handle message", exc_info=e) + finally: + self.clear_job() + elif command == ZMQCommand.W_HEARTBEAT: + self.set_producer_alive() + elif command == ZMQCommand.W_DISCONNECT: + self.reconnect_to_producer() + else: + logger.error(f"ZMQConsumer invalid command: {command}") + else: + if not self.is_producer_alive(): + logger.info("Producer check-alive timed out. Reconnecting.") + self.reconnect_to_producer() + self.set_producer_alive() + + if not self._stop.is_set(): + self.send_heartbeat() + + except zmq.ZMQError as e: + if e.errno == zmq.ETERM: + logger.info("zmq.ETERM") + else: + logger.exception("zmq.ZMQError", exc_info=e) + except Exception as e: + logger.exception("ZMQConsumer thread exception", exc_info=e) + + def set_producer_alive(self) -> None: + self.producer_ping_t.reset() + + def is_producer_alive(self) -> bool: + # producer timer is within timeout + return not self.producer_ping_t.has_expired() + + def send_heartbeat(self) -> None: + if self.heartbeat_t.has_expired() and self.is_producer_alive(): + self.send_to_producer(ZMQCommand.W_HEARTBEAT) + self.heartbeat_t.reset() + + def run(self) -> None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + def associate_job(self, message: Frame) -> None: + try: + queue_item = _deserialize(message, from_bytes=True) + self._set_worker_job(queue_item.job_id) + except Exception as e: + logger.exception("Could not associate job", exc_info=e) + + def clear_job(self) -> None: + self._set_worker_job(None) + + def _set_worker_job(self, job_id: UID | None) -> None: + if self.worker_stash is not None: + consumer_state = ( + ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING + ) + res = self.worker_stash.update_consumer_state( + credentials=self.worker_stash.partition.root_verify_key, + worker_uid=self.syft_worker_id, + consumer_state=consumer_state, + ) + if res.is_err(): + logger.error( + f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" + ) + + @property + def alive(self) -> bool: + return not self.socket.closed and self.is_producer_alive() diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py new file mode 100644 index 00000000000..85dbb0edbf0 --- /dev/null +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -0,0 +1,472 @@ +# stdlib +from binascii import hexlify +import itertools +import logging +import sys +import threading +from threading import Event +from time import sleep +from typing import Any +from typing import cast + +# third party +import zmq +from zmq import LINGER + +# relative +from ...serde.serializable import serializable +from ...serde.serialize import _serialize as serialize +from ...service.action.action_object import ActionObject +from ...service.context import AuthedServiceContext +from ...types.errors import SyftException +from ...types.result import as_result +from ...types.uid import UID +from ...util.util import get_queue_address +from ..service import AbstractService +from ..worker.worker_pool import ConsumerState +from ..worker.worker_stash import WorkerStash +from .base_queue import QueueProducer +from .queue_stash import ActionQueueItem +from .queue_stash import QueueStash +from .queue_stash import Status +from .zmq_common import HEARTBEAT_INTERVAL_SEC +from .zmq_common import Service +from .zmq_common import THREAD_TIMEOUT_SEC +from .zmq_common import Timeout +from .zmq_common import Worker +from .zmq_common import ZMQCommand +from .zmq_common import ZMQHeader +from .zmq_common import ZMQ_POLLER_TIMEOUT_MSEC +from .zmq_common import ZMQ_SOCKET_LOCK + +logger = logging.getLogger(__name__) + + +@serializable(canonical_name="ZMQProducer", version=1) +class ZMQProducer(QueueProducer): + INTERNAL_SERVICE_PREFIX = b"mmi." + + def __init__( + self, + queue_name: str, + queue_stash: QueueStash, + worker_stash: WorkerStash, + port: int, + context: AuthedServiceContext, + ) -> None: + self.id = UID().short() + self.port = port + self.queue_stash = queue_stash + self.worker_stash = worker_stash + self.queue_name = queue_name + self.auth_context = context + self._stop = Event() + self.post_init() + + @property + def address(self) -> str: + return get_queue_address(self.port) + + def post_init(self) -> None: + """Initialize producer state.""" + + self.services: dict[str, Service] = {} + self.workers: dict[bytes, Worker] = {} + self.waiting: list[Worker] = [] + self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) + self.context = zmq.Context(1) + self.socket = self.context.socket(zmq.ROUTER) + self.socket.setsockopt(LINGER, 1) + self.socket.setsockopt_string(zmq.IDENTITY, self.id) + self.poll_workers = zmq.Poller() + self.poll_workers.register(self.socket, zmq.POLLIN) + self.bind(f"tcp://*:{self.port}") + self.thread: threading.Thread | None = None + self.producer_thread: threading.Thread | None = None + + def close(self) -> None: + self._stop.set() + try: + if self.thread: + self.thread.join(THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQProducer message sending thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) + self.thread = None + + if self.producer_thread: + self.producer_thread.join(THREAD_TIMEOUT_SEC) + if self.producer_thread.is_alive(): + logger.error( + f"ZMQProducer queue thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) + self.producer_thread = None + + self.poll_workers.unregister(self.socket) + except Exception as e: + logger.exception("Failed to unregister poller.", exc_info=e) + finally: + self.socket.close() + self.context.destroy() + + @property + def action_service(self) -> AbstractService: + if self.auth_context.server is not None: + return self.auth_context.server.get_service("ActionService") + else: + raise Exception(f"{self.auth_context} does not have a server.") + + @as_result(SyftException) + def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: + """recursively check collections for unresolved action objects""" + if isinstance(arg, UID): + arg = self.action_service.get(self.auth_context, arg) + return self.contains_unresolved_action_objects( + arg, recursion=recursion + 1 + ).unwrap() + if isinstance(arg, ActionObject): + if not arg.syft_resolved: + arg = self.action_service.get(self.auth_context, arg) + if not arg.syft_resolved: + return True + arg = arg.syft_action_data + + value = False + if isinstance(arg, list): + for elem in arg: + value = self.contains_unresolved_action_objects( + elem, recursion=recursion + 1 + ).unwrap() + if value: + return True + if isinstance(arg, dict): + for elem in arg.values(): + value = self.contains_unresolved_action_objects( + elem, recursion=recursion + 1 + ).unwrap() + if value: + return True + return value + + def read_items(self) -> None: + while True: + if self._stop.is_set(): + break + try: + sleep(1) + + # Items to be queued + items_to_queue = self.queue_stash.get_by_status( + self.queue_stash.partition.root_verify_key, + status=Status.CREATED, + ).unwrap() + + items_to_queue = [] if items_to_queue is None else items_to_queue + + # Queue Items that are in the processing state + items_processing = self.queue_stash.get_by_status( + self.queue_stash.partition.root_verify_key, + status=Status.PROCESSING, + ).unwrap() + + items_processing = [] if items_processing is None else items_processing + + for item in itertools.chain(items_to_queue, items_processing): + # TODO: if resolving fails, set queueitem to errored, and jobitem as well + if item.status == Status.CREATED: + if isinstance(item, ActionQueueItem): + action = item.kwargs["action"] + if ( + self.contains_unresolved_action_objects( + action.args + ).unwrap() + or self.contains_unresolved_action_objects( + action.kwargs + ).unwrap() + ): + continue + + msg_bytes = serialize(item, to_bytes=True) + worker_pool = item.worker_pool.resolve_with_context( + self.auth_context + ).unwrap() + service_name = worker_pool.name + service: Service | None = self.services.get(service_name) + + # Skip adding message if corresponding service/pool + # is not registered. + if service is None: + continue + + # append request message to the corresponding service + # This list is processed in dispatch method. + + # TODO: Logic to evaluate the CAN RUN Condition + item.status = Status.PROCESSING + self.queue_stash.update( + item.syft_client_verify_key, item + ).unwrap(public_message=f"failed to update queue item {item}") + service.requests.append(msg_bytes) + elif item.status == Status.PROCESSING: + # Evaluate Retry condition here + # If job running and timeout or job status is KILL + # or heartbeat fails + # or container id doesn't exists, kill process or container + # else decrease retry count and mark status as CREATED. + pass + except Exception as e: + # stdlib + import traceback + + print(e, traceback.format_exc(), file=sys.stderr) + item.status = Status.ERRORED + self.queue_stash.update(item.syft_client_verify_key, item).unwrap() + + def run(self) -> None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + self.producer_thread = threading.Thread(target=self.read_items) + self.producer_thread.start() + + def send(self, worker: bytes, message: bytes | list[bytes]) -> None: + worker_obj = self.require_worker(worker) + self.send_to_worker(worker_obj, ZMQCommand.W_REQUEST, message) + + def bind(self, endpoint: str) -> None: + """Bind producer to endpoint.""" + self.socket.bind(endpoint) + logger.info(f"ZMQProducer endpoint: {endpoint}") + + def send_heartbeats(self) -> None: + """Send heartbeats to idle workers if it's time""" + if self.heartbeat_t.has_expired(): + for worker in self.waiting: + self.send_to_worker(worker, ZMQCommand.W_HEARTBEAT) + self.heartbeat_t.reset() + + def purge_workers(self) -> None: + """Look for & kill expired workers. + + Workers are oldest to most recent, so we stop at the first alive worker. + """ + # work on a copy of the iterator + for worker in self.waiting: + res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) + if res.is_err() or (syft_worker := res.ok()) is None: + logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}") + continue + + if worker.has_expired() or syft_worker.to_be_deleted: + logger.info(f"Deleting expired worker id={worker}") + self.delete_worker(worker, syft_worker.to_be_deleted) + + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.server.get_service(WorkerService) + ) + worker_service._delete(self.auth_context, syft_worker) + + def update_consumer_state_for_worker( + self, syft_worker_id: UID, consumer_state: ConsumerState + ) -> None: + if self.worker_stash is None: + logger.error( # type: ignore[unreachable] + f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" + ) + return + + try: + try: + self.worker_stash.get_by_uid( + credentials=self.worker_stash.partition.root_verify_key, + uid=syft_worker_id, + ).unwrap() + except Exception: + return None + + self.worker_stash.update_consumer_state( + credentials=self.worker_stash.partition.root_verify_key, + worker_uid=syft_worker_id, + consumer_state=consumer_state, + ).unwrap() + except Exception: + logger.exception( + f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}", + ) + + def worker_waiting(self, worker: Worker) -> None: + """This worker is now waiting for work.""" + # Queue to broker and service waiting lists + if worker not in self.waiting: + self.waiting.append(worker) + if worker.service is not None and worker not in worker.service.waiting: + worker.service.waiting.append(worker) + worker.reset_expiry() + self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) + self.dispatch(worker.service, None) + + def dispatch(self, service: Service, msg: bytes) -> None: + """Dispatch requests to waiting workers as possible""" + if msg is not None: # Queue message if any + service.requests.append(msg) + + self.purge_workers() + while service.waiting and service.requests: + # One worker consuming only one message at a time. + msg = service.requests.pop(0) + worker = service.waiting.pop(0) + self.waiting.remove(worker) + self.send_to_worker(worker, ZMQCommand.W_REQUEST, msg) + + def send_to_worker( + self, + worker: Worker, + command: bytes, + msg: bytes | list | None = None, + ) -> None: + """Send message to worker. + + If message is provided, sends that message. + """ + + if self.socket.closed: + logger.warning("Socket is closed. Cannot send message.") + return + + if msg is None: + msg = [] + elif not isinstance(msg, list): + msg = [msg] + + # ZMQProducer send frames: [address, empty, header, command, ...data] + core = [worker.address, b"", ZMQHeader.W_WORKER, command] + msg = core + msg + + if command != ZMQCommand.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer send: {core}") + + with ZMQ_SOCKET_LOCK: + try: + self.socket.send_multipart(msg) + except zmq.ZMQError: + logger.exception("ZMQProducer send error") + + def _run(self) -> None: + try: + while True: + if self._stop.is_set(): + logger.info("ZMQProducer thread stopped") + return + + for service in self.services.values(): + self.dispatch(service, None) + + items = None + + try: + items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except Exception as e: + logger.exception("ZMQProducer poll error", exc_info=e) + + if items: + msg = self.socket.recv_multipart() + + if len(msg) < 3: + logger.error(f"ZMQProducer invalid recv: {msg}") + continue + + # ZMQProducer recv frames: [address, empty, header, command, ...data] + (address, _, header, command, *data) = msg + + if command != ZMQCommand.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer recv: {msg[:4]}") + + if header == ZMQHeader.W_WORKER: + self.process_worker(address, command, data) + else: + logger.error(f"Invalid message header: {header}") + + self.send_heartbeats() + self.purge_workers() + except Exception as e: + logger.exception("ZMQProducer thread exception", exc_info=e) + + def require_worker(self, address: bytes) -> Worker: + """Finds the worker (creates if necessary).""" + identity = hexlify(address) + worker = self.workers.get(identity) + if worker is None: + worker = Worker(identity=identity, address=address) + self.workers[identity] = worker + return worker + + def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> None: + worker_ready = hexlify(address) in self.workers + worker = self.require_worker(address) + + if ZMQCommand.W_READY == command: + service_name = data.pop(0).decode() + syft_worker_id = data.pop(0).decode() + if worker_ready: + # Not first command in session or Reserved service name + # If worker was already present, then we disconnect it first + # and wait for it to re-register itself to the producer. This ensures that + # we always have a healthy worker in place that can talk to the producer. + self.delete_worker(worker, True) + else: + # Attach worker to service and mark as idle + if service_name in self.services: + service: Service | None = self.services.get(service_name) + else: + service = Service(service_name) + self.services[service_name] = service + if service is not None: + worker.service = service + logger.info(f"New worker: {worker}") + worker.syft_worker_id = UID(syft_worker_id) + self.worker_waiting(worker) + + elif ZMQCommand.W_HEARTBEAT == command: + if worker_ready: + # If worker is ready then reset expiry + # and add it to worker waiting list + # if not already present + self.worker_waiting(worker) + else: + logger.info(f"Got heartbeat, but worker not ready. {worker}") + self.delete_worker(worker, True) + elif ZMQCommand.W_DISCONNECT == command: + logger.info(f"Removing disconnected worker: {worker}") + self.delete_worker(worker, False) + else: + logger.error(f"Invalid command: {command!r}") + + def delete_worker(self, worker: Worker, disconnect: bool) -> None: + """Deletes worker from all data structures, and deletes worker.""" + if disconnect: + self.send_to_worker(worker, ZMQCommand.W_DISCONNECT) + + if worker.service and worker in worker.service.waiting: + worker.service.waiting.remove(worker) + + if worker in self.waiting: + self.waiting.remove(worker) + + self.workers.pop(worker.identity, None) + + if worker.syft_worker_id is not None: + self.update_consumer_state_for_worker( + worker.syft_worker_id, ConsumerState.DETACHED + ) + + @property + def alive(self) -> bool: + return not self.socket.closed diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py deleted file mode 100644 index fd53dc1c44a..00000000000 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ /dev/null @@ -1,974 +0,0 @@ -# stdlib -from binascii import hexlify -from collections import defaultdict -import itertools -import logging -import socketserver -import sys -import threading -from threading import Event -import time -from time import sleep -from typing import Any -from typing import cast - -# third party -from pydantic import field_validator -import zmq -from zmq import Frame -from zmq import LINGER -from zmq.error import ContextTerminated - -# relative -from ...serde.deserialize import _deserialize -from ...serde.serializable import serializable -from ...serde.serialize import _serialize as serialize -from ...server.credentials import SyftVerifyKey -from ...service.action.action_object import ActionObject -from ...service.context import AuthedServiceContext -from ...types.base import SyftBaseModel -from ...types.errors import SyftException -from ...types.result import as_result -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.uid import UID -from ...util.util import get_queue_address -from ..response import SyftSuccess -from ..service import AbstractService -from ..worker.worker_pool import ConsumerState -from ..worker.worker_pool import SyftWorker -from ..worker.worker_stash import WorkerStash -from .base_queue import AbstractMessageHandler -from .base_queue import QueueClient -from .base_queue import QueueClientConfig -from .base_queue import QueueConfig -from .base_queue import QueueConsumer -from .base_queue import QueueProducer -from .queue_stash import ActionQueueItem -from .queue_stash import QueueStash -from .queue_stash import Status - -# Producer/Consumer heartbeat interval (in seconds) -HEARTBEAT_INTERVAL_SEC = 2 - -# Thread join timeout (in seconds) -THREAD_TIMEOUT_SEC = 30 - -# Max duration (in ms) to wait for ZMQ poller to return -ZMQ_POLLER_TIMEOUT_MSEC = 1000 - -# Duration (in seconds) after which a worker without a heartbeat will be marked as expired -WORKER_TIMEOUT_SEC = 60 - -# Duration (in seconds) after which producer without a heartbeat will be marked as expired -PRODUCER_TIMEOUT_SEC = 60 - -# Lock for working on ZMQ socket -ZMQ_SOCKET_LOCK = threading.Lock() - -logger = logging.getLogger(__name__) - - -class QueueMsgProtocol: - W_WORKER = b"MDPW01" - W_READY = b"0x01" - W_REQUEST = b"0x02" - W_REPLY = b"0x03" - W_HEARTBEAT = b"0x04" - W_DISCONNECT = b"0x05" - - -MAX_RECURSION_NESTED_ACTIONOBJECTS = 5 - - -class Timeout: - def __init__(self, offset_sec: float): - self.__offset = float(offset_sec) - self.__next_ts: float = 0.0 - - self.reset() - - @property - def next_ts(self) -> float: - return self.__next_ts - - def reset(self) -> None: - self.__next_ts = self.now() + self.__offset - - def has_expired(self) -> bool: - return self.now() >= self.__next_ts - - @staticmethod - def now() -> float: - return time.time() - - -class Service: - def __init__(self, name: str) -> None: - self.name = name - self.requests: list[bytes] = [] - self.waiting: list[Worker] = [] # List of waiting workers - - -class Worker(SyftBaseModel): - address: bytes - identity: bytes - service: Service | None = None - syft_worker_id: UID | None = None - expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - - @field_validator("syft_worker_id", mode="before") - @classmethod - def set_syft_worker_id(cls, v: Any) -> Any: - if isinstance(v, str): - return UID(v) - return v - - def has_expired(self) -> bool: - return self.expiry_t.has_expired() - - def get_expiry(self) -> float: - return self.expiry_t.next_ts - - def reset_expiry(self) -> None: - self.expiry_t.reset() - - @as_result(SyftException) - def _syft_worker( - self, stash: WorkerStash, credentials: SyftVerifyKey - ) -> SyftWorker | None: - return stash.get_by_uid( - credentials=credentials, uid=self.syft_worker_id - ).unwrap() - - def __str__(self) -> str: - svc = self.service.name if self.service else None - return ( - f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " - f"syft_worker_id={self.syft_worker_id!r})" - ) - - -@serializable(canonical_name="ZMQProducer", version=1) -class ZMQProducer(QueueProducer): - INTERNAL_SERVICE_PREFIX = b"mmi." - - def __init__( - self, - queue_name: str, - queue_stash: QueueStash, - worker_stash: WorkerStash, - port: int, - context: AuthedServiceContext, - ) -> None: - self.id = UID().short() - self.port = port - self.queue_stash = queue_stash - self.worker_stash = worker_stash - self.queue_name = queue_name - self.auth_context = context - self._stop = Event() - self.post_init() - - @property - def address(self) -> str: - return get_queue_address(self.port) - - def post_init(self) -> None: - """Initialize producer state.""" - - self.services: dict[str, Service] = {} - self.workers: dict[bytes, Worker] = {} - self.waiting: list[Worker] = [] - self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) - self.context = zmq.Context(1) - self.socket = self.context.socket(zmq.ROUTER) - self.socket.setsockopt(LINGER, 1) - self.socket.setsockopt_string(zmq.IDENTITY, self.id) - self.poll_workers = zmq.Poller() - self.poll_workers.register(self.socket, zmq.POLLIN) - self.bind(f"tcp://*:{self.port}") - self.thread: threading.Thread | None = None - self.producer_thread: threading.Thread | None = None - - def close(self) -> None: - self._stop.set() - try: - if self.thread: - self.thread.join(THREAD_TIMEOUT_SEC) - if self.thread.is_alive(): - logger.error( - f"ZMQProducer message sending thread join timed out during closing. " - f"Queue name {self.queue_name}, " - ) - self.thread = None - - if self.producer_thread: - self.producer_thread.join(THREAD_TIMEOUT_SEC) - if self.producer_thread.is_alive(): - logger.error( - f"ZMQProducer queue thread join timed out during closing. " - f"Queue name {self.queue_name}, " - ) - self.producer_thread = None - - self.poll_workers.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister poller.", exc_info=e) - finally: - self.socket.close() - self.context.destroy() - - # self._stop.clear() - - @property - def action_service(self) -> AbstractService: - if self.auth_context.server is not None: - return self.auth_context.server.get_service("ActionService") - else: - raise Exception(f"{self.auth_context} does not have a server.") - - @as_result(SyftException) - def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: - """recursively check collections for unresolved action objects""" - if isinstance(arg, UID): - arg = self.action_service.get(self.auth_context, arg) - return self.contains_unresolved_action_objects( - arg, recursion=recursion + 1 - ).unwrap() - if isinstance(arg, ActionObject): - if not arg.syft_resolved: - arg = self.action_service.get(self.auth_context, arg) - if not arg.syft_resolved: - return True - arg = arg.syft_action_data - - value = False - if isinstance(arg, list): - for elem in arg: - value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 - ).unwrap() - if value: - return True - if isinstance(arg, dict): - for elem in arg.values(): - value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 - ).unwrap() - if value: - return True - return value - - def read_items(self) -> None: - while True: - if self._stop.is_set(): - break - try: - sleep(1) - - # Items to be queued - items_to_queue = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, - status=Status.CREATED, - ).unwrap() - - items_to_queue = [] if items_to_queue is None else items_to_queue - - # Queue Items that are in the processing state - items_processing = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, - status=Status.PROCESSING, - ).unwrap() - - items_processing = [] if items_processing is None else items_processing - - for item in itertools.chain(items_to_queue, items_processing): - # TODO: if resolving fails, set queueitem to errored, and jobitem as well - if item.status == Status.CREATED: - if isinstance(item, ActionQueueItem): - action = item.kwargs["action"] - if ( - self.contains_unresolved_action_objects( - action.args - ).unwrap() - or self.contains_unresolved_action_objects( - action.kwargs - ).unwrap() - ): - continue - - msg_bytes = serialize(item, to_bytes=True) - worker_pool = item.worker_pool.resolve_with_context( - self.auth_context - ).unwrap() - service_name = worker_pool.name - service: Service | None = self.services.get(service_name) - - # Skip adding message if corresponding service/pool - # is not registered. - if service is None: - continue - - # append request message to the corresponding service - # This list is processed in dispatch method. - - # TODO: Logic to evaluate the CAN RUN Condition - item.status = Status.PROCESSING - self.queue_stash.update( - item.syft_client_verify_key, item - ).unwrap(public_message=f"failed to update queue item {item}") - service.requests.append(msg_bytes) - elif item.status == Status.PROCESSING: - # Evaluate Retry condition here - # If job running and timeout or job status is KILL - # or heartbeat fails - # or container id doesn't exists, kill process or container - # else decrease retry count and mark status as CREATED. - pass - except Exception as e: - # stdlib - import traceback - - print(e, traceback.format_exc(), file=sys.stderr) - item.status = Status.ERRORED - self.queue_stash.update(item.syft_client_verify_key, item).unwrap() - - def run(self) -> None: - self.thread = threading.Thread(target=self._run) - self.thread.start() - - self.producer_thread = threading.Thread(target=self.read_items) - self.producer_thread.start() - - def send(self, worker: bytes, message: bytes | list[bytes]) -> None: - worker_obj = self.require_worker(worker) - self.send_to_worker(worker_obj, QueueMsgProtocol.W_REQUEST, message) - - def bind(self, endpoint: str) -> None: - """Bind producer to endpoint.""" - self.socket.bind(endpoint) - logger.info(f"ZMQProducer endpoint: {endpoint}") - - def send_heartbeats(self) -> None: - """Send heartbeats to idle workers if it's time""" - if self.heartbeat_t.has_expired(): - for worker in self.waiting: - self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT) - self.heartbeat_t.reset() - - def purge_workers(self) -> None: - """Look for & kill expired workers. - - Workers are oldest to most recent, so we stop at the first alive worker. - """ - # work on a copy of the iterator - for worker in self.waiting: - res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) - if res.is_err() or (syft_worker := res.ok()) is None: - logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}") - continue - - if worker.has_expired() or syft_worker.to_be_deleted: - logger.info(f"Deleting expired worker id={worker}") - self.delete_worker(worker, syft_worker.to_be_deleted) - - # relative - from ...service.worker.worker_service import WorkerService - - worker_service = cast( - WorkerService, self.auth_context.server.get_service(WorkerService) - ) - worker_service._delete(self.auth_context, syft_worker) - - def update_consumer_state_for_worker( - self, syft_worker_id: UID, consumer_state: ConsumerState - ) -> None: - if self.worker_stash is None: - logger.error( # type: ignore[unreachable] - f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" - ) - return - - try: - try: - self.worker_stash.get_by_uid( - credentials=self.worker_stash.partition.root_verify_key, - uid=syft_worker_id, - ).unwrap() - except Exception: - return None - - self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, - worker_uid=syft_worker_id, - consumer_state=consumer_state, - ).unwrap() - except Exception: - logger.exception( - f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}", - ) - - def worker_waiting(self, worker: Worker) -> None: - """This worker is now waiting for work.""" - # Queue to broker and service waiting lists - if worker not in self.waiting: - self.waiting.append(worker) - if worker.service is not None and worker not in worker.service.waiting: - worker.service.waiting.append(worker) - worker.reset_expiry() - self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) - self.dispatch(worker.service, None) - - def dispatch(self, service: Service, msg: bytes) -> None: - """Dispatch requests to waiting workers as possible""" - if msg is not None: # Queue message if any - service.requests.append(msg) - - self.purge_workers() - while service.waiting and service.requests: - # One worker consuming only one message at a time. - msg = service.requests.pop(0) - worker = service.waiting.pop(0) - self.waiting.remove(worker) - self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, msg) - - def send_to_worker( - self, - worker: Worker, - command: bytes, - msg: bytes | list | None = None, - ) -> None: - """Send message to worker. - - If message is provided, sends that message. - """ - - if self.socket.closed: - logger.warning("Socket is closed. Cannot send message.") - return - - if msg is None: - msg = [] - elif not isinstance(msg, list): - msg = [msg] - - # ZMQProducer send frames: [address, empty, header, command, ...data] - core = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] - msg = core + msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQProducer send: {core}") - - with ZMQ_SOCKET_LOCK: - try: - self.socket.send_multipart(msg) - except zmq.ZMQError: - logger.exception("ZMQProducer send error") - - def _run(self) -> None: - try: - while True: - if self._stop.is_set(): - logger.info("ZMQProducer thread stopped") - return - - for service in self.services.values(): - self.dispatch(service, None) - - items = None - - try: - items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except Exception as e: - logger.exception("ZMQProducer poll error", exc_info=e) - - if items: - msg = self.socket.recv_multipart() - - if len(msg) < 3: - logger.error(f"ZMQProducer invalid recv: {msg}") - continue - - # ZMQProducer recv frames: [address, empty, header, command, ...data] - (address, _, header, command, *data) = msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQProducer recv: {msg[:4]}") - - if header == QueueMsgProtocol.W_WORKER: - self.process_worker(address, command, data) - else: - logger.error(f"Invalid message header: {header}") - - self.send_heartbeats() - self.purge_workers() - except Exception as e: - logger.exception("ZMQProducer thread exception", exc_info=e) - - def require_worker(self, address: bytes) -> Worker: - """Finds the worker (creates if necessary).""" - identity = hexlify(address) - worker = self.workers.get(identity) - if worker is None: - worker = Worker(identity=identity, address=address) - self.workers[identity] = worker - return worker - - def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> None: - worker_ready = hexlify(address) in self.workers - worker = self.require_worker(address) - - if QueueMsgProtocol.W_READY == command: - service_name = data.pop(0).decode() - syft_worker_id = data.pop(0).decode() - if worker_ready: - # Not first command in session or Reserved service name - # If worker was already present, then we disconnect it first - # and wait for it to re-register itself to the producer. This ensures that - # we always have a healthy worker in place that can talk to the producer. - self.delete_worker(worker, True) - else: - # Attach worker to service and mark as idle - if service_name in self.services: - service: Service | None = self.services.get(service_name) - else: - service = Service(service_name) - self.services[service_name] = service - if service is not None: - worker.service = service - logger.info(f"New worker: {worker}") - worker.syft_worker_id = UID(syft_worker_id) - self.worker_waiting(worker) - - elif QueueMsgProtocol.W_HEARTBEAT == command: - if worker_ready: - # If worker is ready then reset expiry - # and add it to worker waiting list - # if not already present - self.worker_waiting(worker) - else: - logger.info(f"Got heartbeat, but worker not ready. {worker}") - self.delete_worker(worker, True) - elif QueueMsgProtocol.W_DISCONNECT == command: - logger.info(f"Removing disconnected worker: {worker}") - self.delete_worker(worker, False) - else: - logger.error(f"Invalid command: {command!r}") - - def delete_worker(self, worker: Worker, disconnect: bool) -> None: - """Deletes worker from all data structures, and deletes worker.""" - if disconnect: - self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT) - - if worker.service and worker in worker.service.waiting: - worker.service.waiting.remove(worker) - - if worker in self.waiting: - self.waiting.remove(worker) - - self.workers.pop(worker.identity, None) - - if worker.syft_worker_id is not None: - self.update_consumer_state_for_worker( - worker.syft_worker_id, ConsumerState.DETACHED - ) - - @property - def alive(self) -> bool: - return not self.socket.closed - - -@serializable(attrs=["_subscriber"], canonical_name="ZMQConsumer", version=1) -class ZMQConsumer(QueueConsumer): - def __init__( - self, - message_handler: AbstractMessageHandler, - address: str, - queue_name: str, - service_name: str, - syft_worker_id: UID | None = None, - worker_stash: WorkerStash | None = None, - verbose: bool = False, - ) -> None: - self.address = address - self.message_handler = message_handler - self.service_name = service_name - self.queue_name = queue_name - self.context = zmq.Context() - self.poller = zmq.Poller() - self.socket = None - self.verbose = verbose - self.id = UID().short() - self._stop = Event() - self.syft_worker_id = syft_worker_id - self.worker_stash = worker_stash - self.post_init() - - def reconnect_to_producer(self) -> None: - """Connect or reconnect to producer""" - if self.socket: - self.poller.unregister(self.socket) # type: ignore[unreachable] - self.socket.close() - self.socket = self.context.socket(zmq.DEALER) - self.socket.linger = 0 - self.socket.setsockopt_string(zmq.IDENTITY, self.id) - self.socket.connect(self.address) - self.poller.register(self.socket, zmq.POLLIN) - - logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") - - # Register queue with the producer - self.send_to_producer( - QueueMsgProtocol.W_READY, - [self.service_name.encode(), str(self.syft_worker_id).encode()], - ) - - def post_init(self) -> None: - self.thread: threading.Thread | None = None - self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) - self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) - self.reconnect_to_producer() - - def disconnect_from_producer(self) -> None: - self.send_to_producer(QueueMsgProtocol.W_DISCONNECT) - - def close(self) -> None: - self.disconnect_from_producer() - self._stop.set() - try: - if self.thread is not None: - self.thread.join(timeout=THREAD_TIMEOUT_SEC) - if self.thread.is_alive(): - logger.error( - f"ZMQConsumer thread join timed out during closing. " - f"SyftWorker id {self.syft_worker_id}, " - f"service name {self.service_name}." - ) - self.thread = None - self.poller.unregister(self.socket) - except Exception: - logger.exception("Failed to unregister worker.") - finally: - self.socket.close() - self.context.destroy() - # self._stop.clear() - - def send_to_producer( - self, - command: bytes, - msg: bytes | list | None = None, - ) -> None: - """Send message to producer. - - If no msg is provided, creates one internally - """ - if self.socket.closed: - logger.warning("Socket is closed. Cannot send message.") - return - - if msg is None: - msg = [] - elif not isinstance(msg, list): - msg = [msg] - - # ZMQConsumer send frames: [empty, header, command, ...data] - core = [b"", QueueMsgProtocol.W_WORKER, command] - msg = core + msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - logger.info(f"ZMQ Consumer send: {core}") - - with ZMQ_SOCKET_LOCK: - try: - self.socket.send_multipart(msg) - except zmq.ZMQError: - logger.exception("ZMQConsumer send error") - - def _run(self) -> None: - """Send reply, if any, to producer and wait for next request.""" - try: - while True: - if self._stop.is_set(): - logger.info("ZMQConsumer thread stopped") - return - - try: - items = self.poller.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except ContextTerminated: - logger.info("Context terminated") - return - except Exception as e: - logger.error("ZMQ poll error", exc_info=e) - continue - - if items: - msg = self.socket.recv_multipart() - - # mark as alive - self.set_producer_alive() - - if len(msg) < 3: - logger.error(f"ZMQConsumer invalid recv: {msg}") - continue - - # Message frames recieved by consumer: - # [empty, header, command, ...data] - (_, _, command, *data) = msg - - if command != QueueMsgProtocol.W_HEARTBEAT: - # log everything except the last frame which contains serialized data - logger.info(f"ZMQConsumer recv: {msg[:-4]}") - - if command == QueueMsgProtocol.W_REQUEST: - # Call Message Handler - try: - message = data.pop() - self.associate_job(message) - self.message_handler.handle_message( - message=message, - syft_worker_id=self.syft_worker_id, - ) - except Exception as e: - logger.exception("Couldn't handle message", exc_info=e) - finally: - self.clear_job() - elif command == QueueMsgProtocol.W_HEARTBEAT: - self.set_producer_alive() - elif command == QueueMsgProtocol.W_DISCONNECT: - self.reconnect_to_producer() - else: - logger.error(f"ZMQConsumer invalid command: {command}") - else: - if not self.is_producer_alive(): - logger.info("Producer check-alive timed out. Reconnecting.") - self.reconnect_to_producer() - self.set_producer_alive() - - if not self._stop.is_set(): - self.send_heartbeat() - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - logger.info("zmq.ETERM") - else: - logger.exception("zmq.ZMQError", exc_info=e) - except Exception as e: - logger.exception("ZMQConsumer thread exception", exc_info=e) - - def set_producer_alive(self) -> None: - self.producer_ping_t.reset() - - def is_producer_alive(self) -> bool: - # producer timer is within timeout - return not self.producer_ping_t.has_expired() - - def send_heartbeat(self) -> None: - if self.heartbeat_t.has_expired() and self.is_producer_alive(): - self.send_to_producer(QueueMsgProtocol.W_HEARTBEAT) - self.heartbeat_t.reset() - - def run(self) -> None: - self.thread = threading.Thread(target=self._run) - self.thread.start() - - def associate_job(self, message: Frame) -> None: - try: - queue_item = _deserialize(message, from_bytes=True) - self._set_worker_job(queue_item.job_id) - except Exception as e: - logger.exception("Could not associate job", exc_info=e) - - def clear_job(self) -> None: - self._set_worker_job(None) - - def _set_worker_job(self, job_id: UID | None) -> None: - if self.worker_stash is not None: - consumer_state = ( - ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING - ) - try: - self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, - worker_uid=self.syft_worker_id, - consumer_state=consumer_state, - ).unwrap() - except SyftException as exc: - logger.error( - f"Failed to update consumer state for {self.service_name}-{self.id}, error={exc.public}" - ) - - @property - def alive(self) -> bool: - return not self.socket.closed and self.is_producer_alive() - - -@serializable() -class ZMQClientConfig(SyftObject, QueueClientConfig): - __canonical_name__ = "ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID | None = None # type: ignore[assignment] - hostname: str = "127.0.0.1" - queue_port: int | None = None - # TODO: setting this to false until we can fix the ZMQ - # port issue causing tests to randomly fail - create_producer: bool = False - n_consumers: int = 0 - consumer_service: str | None = None - - -@serializable(attrs=["host"], canonical_name="ZMQClient", version=1) -class ZMQClient(QueueClient): - """ZMQ Client for creating producers and consumers.""" - - producers: dict[str, ZMQProducer] - consumers: defaultdict[str, list[ZMQConsumer]] - - def __init__(self, config: ZMQClientConfig) -> None: - self.host = config.hostname - self.producers = {} - self.consumers = defaultdict(list) - self.config = config - - @staticmethod - def _get_free_tcp_port(host: str) -> int: - with socketserver.TCPServer((host, 0), None) as s: - free_port = s.server_address[1] - - return free_port - - def add_producer( - self, - queue_name: str, - port: int | None = None, - queue_stash: QueueStash | None = None, - worker_stash: WorkerStash | None = None, - context: AuthedServiceContext | None = None, - ) -> ZMQProducer: - """Add a producer of a queue. - - A queue can have at most one producer attached to it. - """ - - if port is None: - if self.config.queue_port is None: - self.config.queue_port = self._get_free_tcp_port(self.host) - port = self.config.queue_port - else: - port = self.config.queue_port - - producer = ZMQProducer( - queue_name=queue_name, - queue_stash=queue_stash, - port=port, - context=context, - worker_stash=worker_stash, - ) - self.producers[queue_name] = producer - return producer - - def add_consumer( - self, - queue_name: str, - message_handler: AbstractMessageHandler, - service_name: str, - address: str | None = None, - worker_stash: WorkerStash | None = None, - syft_worker_id: UID | None = None, - ) -> ZMQConsumer: - """Add a consumer to a queue - - A queue should have at least one producer attached to the group. - - """ - - if address is None: - address = get_queue_address(self.config.queue_port) - - consumer = ZMQConsumer( - queue_name=queue_name, - message_handler=message_handler, - address=address, - service_name=service_name, - syft_worker_id=syft_worker_id, - worker_stash=worker_stash, - ) - self.consumers[queue_name].append(consumer) - - return consumer - - def send_message( - self, - message: bytes, - queue_name: str, - worker: bytes | None = None, - ) -> SyftSuccess: - producer = self.producers.get(queue_name) - if producer is None: - raise SyftException( - public_message=f"No producer attached for queue: {queue_name}. Please add a producer for it." - ) - try: - producer.send(message=message, worker=worker) - except Exception as e: - # stdlib - raise SyftException( - public_message=f"Failed to send message to: {queue_name} with error: {e}" - ) - return SyftSuccess( - message=f"Successfully queued message to : {queue_name}", - ) - - def close(self) -> SyftSuccess: - try: - for consumers in self.consumers.values(): - for consumer in consumers: - # make sure look is stopped - consumer.close() - - for producer in self.producers.values(): - # make sure loop is stopped - producer.close() - # close existing connection. - except Exception as e: - raise SyftException(public_message=f"Failed to close connection: {e}") - - return SyftSuccess(message="All connections closed.") - - def purge_queue(self, queue_name: str) -> SyftSuccess: - if queue_name not in self.producers: - raise SyftException( - public_message=f"No producer running for : {queue_name}" - ) - - producer = self.producers[queue_name] - - # close existing connection. - producer.close() - - # add a new connection - self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore - - return SyftSuccess(message=f"Queue: {queue_name} successfully purged") - - def purge_all(self) -> SyftSuccess: - for queue_name in self.producers: - self.purge_queue(queue_name=queue_name) - - return SyftSuccess(message="Successfully purged all queues.") - - -@serializable(canonical_name="ZMQQueueConfig", version=1) -class ZMQQueueConfig(QueueConfig): - def __init__( - self, - client_type: type[ZMQClient] | None = None, - client_config: ZMQClientConfig | None = None, - thread_workers: bool = False, - ): - self.client_type = client_type or ZMQClient - self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.thread_workers = thread_workers diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index c29afdb814c..e13fb988c56 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -3,7 +3,6 @@ from ...store.document_store import DocumentStore from ...store.linked_obj import LinkedObject from ...types.uid import UID -from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..notification.email_templates import RequestEmailTemplate from ..notification.email_templates import RequestUpdateEmailTemplate @@ -29,7 +28,6 @@ from .request_stash import RequestStash -@instrument @serializable(canonical_name="RequestService", version=1) class RequestService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index b41301c9ebe..19bb2f5720e 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -13,7 +13,6 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from .request import Request RequestingUserVerifyKeyPartitionKey = PartitionKey( @@ -23,7 +22,6 @@ OrderByRequestTimeStampPartitionKey = PartitionKey(key="request_time", type_=DateTime) -@instrument @serializable(canonical_name="RequestStash", version=1) class RequestStash(NewBaseUIDStoreStash): object_type = Request diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index ad3b0eb6e57..784eca2e340 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -50,6 +50,7 @@ from ..types.syft_object import SyftObject from ..types.syft_object import attach_attribute_to_syft_object from ..types.uid import UID +from ..util.telemetry import instrument from .context import AuthedServiceContext from .context import ChangeContext from .user.user_roles import DATA_OWNER_ROLE_LEVEL @@ -466,6 +467,10 @@ def wrapper(func: Any) -> Callable: if autosplat is not None and len(autosplat) > 0: signature = expand_signature(signature=input_signature, autosplat=autosplat) + @instrument( # type: ignore + span_name=f"service_method::{_path}", + attributes={"service.name": name, "service.path": path}, + ) @functools.wraps(func) def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable: communication_protocol = kwargs.pop("communication_protocol", None) diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index 038e710ca0b..aa02847504a 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -1,7 +1,3 @@ -# stdlib - -# third party - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey @@ -12,14 +8,12 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from .settings import ServerSettings NamePartitionKey = PartitionKey(key="name", type_=str) ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) -@instrument @serializable(canonical_name="SettingsStash", version=1) class SettingsStash(NewBaseUIDStoreStash): object_type = ServerSettings diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index ea3da2bdb3c..8633f31f130 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -1,7 +1,3 @@ -# stdlib - -# stdlib - # stdlib import threading @@ -14,14 +10,12 @@ from ...store.document_store_errors import StashException from ...types.datetime import DateTime from ...types.result import as_result -from ...util.telemetry import instrument from ..context import AuthedServiceContext from .sync_state import SyncState OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime) -@instrument @serializable(canonical_name="SyncStash", version=1) class SyncStash(NewBaseUIDStoreStash): object_type = SyncState diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index e2fcf3eeaf9..c276ea08cb0 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -20,7 +20,6 @@ from ...types.result import as_result from ...types.syft_metaclass import Empty from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext @@ -81,7 +80,6 @@ def _paginate( return list_objs -@instrument @serializable(canonical_name="UserService", version=1) class UserService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 2abf640869e..3272f9a946c 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -12,7 +12,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from .user import User from .user_roles import ServiceRole @@ -24,7 +23,6 @@ VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) -@instrument @serializable(canonical_name="UserStash", version=1) class UserStash(NewBaseUIDStoreStash): object_type = User diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 0532b98c6f6..64171c5f90b 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -352,6 +352,16 @@ def create_kubernetes_pool( "N_CONSUMERS": "1", "CREATE_PRODUCER": "False", "INMEMORY_WORKERS": "False", + "OTEL_SERVICE_NAME": f"{pool_name}", + "OTEL_PYTHON_LOG_CORRELATION": os.environ.get( + "OTEL_PYTHON_LOG_CORRELATION" + ), + "OTEL_EXPORTER_OTLP_ENDPOINT": os.environ.get( + "OTEL_EXPORTER_OTLP_ENDPOINT" + ), + "OTEL_EXPORTER_OTLP_PROTOCOL": os.environ.get( + "OTEL_EXPORTER_OTLP_PROTOCOL" + ), }, ) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index cb093c55d2f..a324035b2d2 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -19,7 +19,6 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..service import AbstractService from ..service import AuthedServiceContext from ..service import service_method @@ -38,7 +37,6 @@ from .worker_stash import WorkerStash -@instrument @serializable(canonical_name="WorkerService", version=1) class WorkerService(AbstractService): store: DocumentStore diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index ddcdfa733a2..b2b059ffec5 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -14,7 +14,6 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from .worker_pool import ConsumerState @@ -23,7 +22,6 @@ WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str) -@instrument @serializable(canonical_name="WorkerStash", version=1) class WorkerStash(NewBaseUIDStoreStash): object_type = SyftWorker diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 97d8484ebc0..3409035906f 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -42,6 +42,7 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.uid import UID from ...util.constants import DEFAULT_TIMEOUT +from ...util.telemetry import TRACING_ENABLED logger = logging.getLogger(__name__) @@ -50,6 +51,18 @@ DEFAULT_FILE_PART_SIZE = 1024**3 # 1GB DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 800 # 800KB +if TRACING_ENABLED: + try: + # third party + from opentelemetry.instrumentation.botocore import BotocoreInstrumentor + + BotocoreInstrumentor().instrument() + message = "> Added OTEL BotocoreInstrumentor" + print(message) + logger.info(message) + except Exception: # nosec + pass + @serializable() class SeaweedFSBlobDeposit(BlobDeposit): diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 45dd95d218c..cc97802a08b 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -307,7 +307,6 @@ def searchable_keys(self) -> PartitionKeys: return PartitionKeys.from_dict(self.object_type._syft_searchable_keys_dict()) -@instrument @serializable( attrs=["settings", "store_config", "unique_cks", "searchable_cks"], canonical_name="StorePartition", @@ -600,7 +599,6 @@ def _migrate_data( raise NotImplementedError -@instrument @serializable(canonical_name="DocumentStore", version=1) class DocumentStore: """Base Document Store diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index 30488017e7a..9767059a1cb 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -1,4 +1,5 @@ # stdlib +import logging from threading import Lock from typing import Any @@ -12,11 +13,25 @@ from ..serde.serializable import serializable from ..types.errors import SyftException from ..types.result import as_result +from ..util.telemetry import TRACING_ENABLED from .document_store import PartitionSettings from .document_store import StoreClientConfig from .document_store import StoreConfig from .mongo_codecs import SYFT_CODEC_OPTIONS +if TRACING_ENABLED: + try: + # third party + from opentelemetry.instrumentation.pymongo import PymongoInstrumentor + + PymongoInstrumentor().instrument() + message = "> Added OTEL PymongoInstrumentor" + print(message) + logger = logging.getLogger(__name__) + logger.info(message) + except Exception: # nosec + pass + @serializable(canonical_name="MongoStoreClientConfig", version=1) class MongoStoreClientConfig(StoreClientConfig): diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index d03f240a1de..ef376d78fc0 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -5,76 +5,58 @@ from typing import Any from typing import TypeVar -logger = logging.getLogger(__name__) - - -def str_to_bool(bool_str: str | None) -> bool: - result = False - bool_str = str(bool_str).lower() - if bool_str == "true" or bool_str == "1": - result = True - return result +# relative +from .util import str_to_bool +__all__ = ["TRACING_ENABLED", "instrument"] -TRACE_MODE = str_to_bool(os.environ.get("TRACE", "False")) +logger = logging.getLogger(__name__) +TRACING_ENABLED = str_to_bool(os.environ.get("TRACING", "False")) T = TypeVar("T", bound=Callable | type) -def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: - return __func_or_class +def noop(__func_or_class: T | None = None, /, *args: Any, **kwargs: Any) -> T: + def noop_wrapper(__func_or_class: T) -> T: + return __func_or_class + + if __func_or_class is None: + return noop_wrapper # type: ignore + else: + return __func_or_class -if not TRACE_MODE: +if not TRACING_ENABLED: instrument = noop else: try: - service_name = os.environ.get("SERVICE_NAME", "client") - jaeger_host = os.environ.get("JAEGER_HOST", "localhost") - jaeger_port = int(os.environ.get("JAEGER_PORT", "14268")) - # third party from opentelemetry import trace - from opentelemetry.exporter.jaeger.thrift import JaegerExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter, + ) from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.resources import SERVICE_NAME from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor - trace.set_tracer_provider( - TracerProvider(resource=Resource.create({SERVICE_NAME: service_name})) - ) - jaeger_exporter = JaegerExporter( - # agent_host_name=jaeger_host, - # agent_port=jaeger_port, - collector_endpoint=f"http://{jaeger_host}:{jaeger_port}/api/traces?format=jaeger.thrift", - # udp_split_oversized_batches=True, - ) - - trace.get_tracer_provider().add_span_processor( - BatchSpanProcessor(jaeger_exporter) - ) - - # from opentelemetry.sdk.trace.export import ConsoleSpanExporter - # console_exporter = ConsoleSpanExporter() - # span_processor = BatchSpanProcessor(console_exporter) - # trace.get_tracer_provider().add_span_processor(span_processor) + # relative + from .trace_decorator import instrument as _instrument - # third party - import opentelemetry.instrumentation.requests + # create a provider + service_name = os.environ.get("OTEL_SERVICE_NAME", "syft-backend") + resource = Resource.create({"service.name": service_name}) + provider = TracerProvider(resource=resource) - opentelemetry.instrumentation.requests.RequestsInstrumentor().instrument() + # create a span processor + otlp_exporter = OTLPSpanExporter() + span_processor = BatchSpanProcessor(otlp_exporter) + provider.add_span_processor(span_processor) - # relative - # from opentelemetry.instrumentation.digma.trace_decorator import ( - # instrument as _instrument, - # ) - # - # until this is merged: - # https://github.com/digma-ai/opentelemetry-instrumentation-digma/pull/41 - from .trace_decorator import instrument as _instrument + # set the global trace provider + trace.set_tracer_provider(provider) + # expose the instrument decorator instrument = _instrument except Exception as e: logger.error("Failed to import opentelemetry", exc_info=e) diff --git a/packages/syft/src/syft/util/trace_decorator.py b/packages/syft/src/syft/util/trace_decorator.py index 87486b0cda4..719ca822afa 100644 --- a/packages/syft/src/syft/util/trace_decorator.py +++ b/packages/syft/src/syft/util/trace_decorator.py @@ -9,7 +9,6 @@ from typing import Any from typing import ClassVar from typing import TypeVar -from typing import cast # third party from opentelemetry import trace @@ -44,7 +43,7 @@ def set_default_attributes(cls, attributes: dict[str, str] | None = None) -> Non def instrument( - _func_or_class: T, + _func_or_class: T | None = None, /, *, span_name: str = "", @@ -99,16 +98,12 @@ def decorate_class(cls: T) -> T: return cls - # Check if this is a span or class decorator - if inspect.isclass(_func_or_class): - return decorate_class(_func_or_class) - def span_decorator(func_or_class: T) -> T: - if inspect.isclass(func_or_class): + if ignore: + return func_or_class + elif inspect.isclass(func_or_class): return decorate_class(func_or_class) - # sig = inspect.signature(func_or_class) - # Check if already decorated (happens if both class and function # decorated). If so, we keep the function decorator settings only undecorated_func = getattr(func_or_class, "__tracing_unwrapped__", None) @@ -155,16 +150,20 @@ async def wrap_with_span_async(*args: Any, **kwargs: Any) -> Callable: _set_attributes(span, attributes) return await func_or_class(*args, **kwargs) - if ignore: - return func_or_class - - wrapper = ( + span_wrapper = ( wrap_with_span_async if asyncio.iscoroutinefunction(func_or_class) else wrap_with_span_sync ) - wrapper.__signature__ = inspect.signature(func_or_class) + span_wrapper.__signature__ = inspect.signature(func_or_class) - return cast(T, wrapper) + return span_wrapper # type: ignore - return span_decorator(_func_or_class) + # decorator factory on a class or func + # @instrument or @instrument(span_name="my_span", ...) + if _func_or_class and inspect.isclass(_func_or_class): + return decorate_class(_func_or_class) + elif _func_or_class: + return span_decorator(_func_or_class) + else: + return span_decorator # type: ignore diff --git a/packages/syft/tests/syft/zmq_queue_test.py b/packages/syft/tests/syft/zmq_queue_test.py index 64ffc84cdbd..a995d09dced 100644 --- a/packages/syft/tests/syft/zmq_queue_test.py +++ b/packages/syft/tests/syft/zmq_queue_test.py @@ -13,11 +13,11 @@ import syft from syft.service.queue.base_queue import AbstractMessageHandler from syft.service.queue.queue import QueueManager -from syft.service.queue.zmq_queue import ZMQClient -from syft.service.queue.zmq_queue import ZMQClientConfig -from syft.service.queue.zmq_queue import ZMQConsumer -from syft.service.queue.zmq_queue import ZMQProducer -from syft.service.queue.zmq_queue import ZMQQueueConfig +from syft.service.queue.zmq_client import ZMQClient +from syft.service.queue.zmq_client import ZMQClientConfig +from syft.service.queue.zmq_client import ZMQQueueConfig +from syft.service.queue.zmq_consumer import ZMQConsumer +from syft.service.queue.zmq_producer import ZMQProducer from syft.service.response import SyftSuccess from syft.types.errors import SyftException from syft.util.util import get_queue_address diff --git a/scripts/dev_tools.sh b/scripts/dev_tools.sh index 763e602cd28..20a74b597e0 100755 --- a/scripts/dev_tools.sh +++ b/scripts/dev_tools.sh @@ -23,7 +23,7 @@ function docker_list_exposed_ports() { if [[ -z "$1" ]]; then # list db, redis, rabbitmq, and seaweedfs ports - docker_list_exposed_ports "db\|redis\|queue\|seaweedfs\|jaeger\|mongo" + docker_list_exposed_ports "db\|seaweedfs\|mongo" else PORT=$1 if docker ps | grep ":${PORT}" | grep -q 'redis'; then diff --git a/tox.ini b/tox.ini index a3924b6d09a..3716fc9de44 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = dev.k8s.cleanup dev.k8s.destroy dev.k8s.destroyall + dev.k8s.install.signoz lint stack.test.integration syft.docs @@ -420,10 +421,9 @@ setenv = SERVER_PORT = {env:SERVER_PORT:8080} TEST_EXTERNAL_REGISTRY = {env:TEST_EXTERNAL_REGISTRY:k3d-registry.localhost:5800} TEST_QUERY_LIMIT_SIZE={env:TEST_QUERY_LIMIT_SIZE:500000} + TRACING={env:TRACING:False} commands = bash -c "python --version || true" - bash -c "python3 --version || true" - bash -c "python3.12 --version || true" bash -c "echo Running with GITHUB_CI=$GITHUB_CI; date" bash -c "echo Running with TEST_EXTERNAL_REGISTRY=$TEST_EXTERNAL_REGISTRY; date" python -c 'import syft as sy; sy.stage_protocol_changes()' @@ -467,7 +467,7 @@ commands = bash -c "pytest -s -x --nbmake notebooks/scenarios/bigquery -p no:randomly --ignore=notebooks/scenarios/bigquery/sync -vvvv --nbmake-timeout=1000 --log-cli-level=DEBUG --capture=no;" - ; # deleting clusters created + # deleting clusters created bash -c "CLUSTER_NAME=${DATASITE_CLUSTER_NAME} tox -e dev.k8s.destroy || true" bash -c "k3d registry delete k3d-registry.localhost || true" bash -c "docker volume rm k3d-${DATASITE_CLUSTER_NAME}-images --force || true" @@ -966,7 +966,7 @@ commands = description = Patch CoreDNS to resolve k3d-registry.localhost changedir = {toxinidir} passenv=HOME,USER,CLUSTER_NAME -setenv = +setenv= CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = bash @@ -977,6 +977,23 @@ commands = ; restarts coredns bash -c 'kubectl delete pod -n kube-system -l k8s-app=kube-dns --context k3d-${CLUSTER_NAME}' +[testenv:dev.k8s.install.signoz] +description = Install Signoz on local Kubernetes cluster +changedir = {toxinidir} +passenv=HOME,USER,CLUSTER_NAME +allowlist_externals= + bash +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:test-datasite-1} + SIGNOZ_PORT = {env:SIGNOZ_PORT:3301} +commands= + bash -c 'if [ "{posargs}" ]; then kubectl config use-context {posargs}; fi' + bash -c 'helm repo add signoz https://charts.signoz.io && helm repo update' + bash -c 'helm install syft signoz/signoz --namespace platform --create-namespace || true' + # bash -c 'k3d cluster edit ${CLUSTER_NAME} --port-add "${SIGNOZ_PORT}:3301@loadbalancer"' + ; bash packages/grid/scripts/wait_for.sh service syft-signoz-frontend --context k3d-{env:CLUSTER_NAME} --namespace platform + + [testenv:dev.k8s.start] description = Start local Kubernetes registry & cluster with k3d changedir = {toxinidir} @@ -984,6 +1001,7 @@ passenv = HOME, USER setenv = CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} CLUSTER_HTTP_PORT = {env:CLUSTER_HTTP_PORT:8080} + # SIGNOZ_PORT = {env:SIGNOZ_PORT:3301} allowlist_externals = bash sleep @@ -993,9 +1011,11 @@ commands = tox -e dev.k8s.registry ; for NodePort to work add the following --> -p "NodePort:NodePort@loadbalancer" - bash -c 'k3d cluster create ${CLUSTER_NAME} -p "${CLUSTER_HTTP_PORT}:80@loadbalancer" --registry-use k3d-registry.localhost:5800 {posargs} && \ - kubectl --context k3d-${CLUSTER_NAME} create namespace syft || true' - + bash -c 'k3d cluster create ${CLUSTER_NAME} \ + -p "${CLUSTER_HTTP_PORT}:80@loadbalancer" \ + --registry-use k3d-registry.localhost:5800 {posargs} && \ + kubectl --context k3d-${CLUSTER_NAME} create namespace syft || true' + # -p "${SIGNOZ_PORT}:3301@loadbalancer" \ ; patch coredns tox -e dev.k8s.patch.coredns @@ -1008,15 +1028,22 @@ changedir = {toxinidir}/packages/grid passenv = HOME, USER, DEVSPACE_PROFILE setenv= CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} + TRACING = {env:TRACING:False} allowlist_externals = bash commands = ; deploy syft helm charts bash -c 'echo "profile=$DEVSPACE_PROFILE"' + bash -c "echo Running with TRACING=$TRACING; date" bash -c '\ if [[ -n "${DEVSPACE_PROFILE}" ]]; then export DEVSPACE_PROFILE="-p ${DEVSPACE_PROFILE}"; fi && \ + if [[ "${TRACING}" == "True" ]]; then DEVSPACE_PROFILE="${DEVSPACE_PROFILE} -p tracing"; fi && \ + if [[ "${TRACING}" == "True" ]]; then echo "TRACING PROFILE ENABLED"; fi && \ devspace deploy -b --kube-context k3d-${CLUSTER_NAME} --no-warn ${DEVSPACE_PROFILE} --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' + # if TRACING is enabled start signoz + ; bash -c 'if [[ "${TRACING}" == "True" ]]; then tox -e dev.k8s.install.signoz; fi' + [testenv:dev.k8s.hotreload] description = Start development with hot-reload in Kubernetes changedir = {toxinidir}/packages/grid