From 75f1e483d3073cbabddb4079457551541d83b699 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Wed, 19 Jun 2024 21:27:45 +0530 Subject: [PATCH 1/9] fix logging --- packages/grid/backend/grid/logger/__init__.py | 0 packages/grid/backend/grid/logger/config.py | 59 ----- packages/grid/backend/grid/logger/handler.py | 108 --------- packages/grid/backend/grid/logging.yaml | 46 ++++ packages/grid/backend/grid/main.py | 22 +- packages/grid/backend/grid/start.sh | 2 +- packages/syft/setup.cfg | 1 - packages/syft/src/syft/__init__.py | 142 ++++++------ packages/syft/src/syft/client/api.py | 4 +- packages/syft/src/syft/client/client.py | 6 +- .../syft/src/syft/client/domain_client.py | 10 +- packages/syft/src/syft/client/registry.py | 17 +- packages/syft/src/syft/client/syncing.py | 9 +- packages/syft/src/syft/node/node.py | 54 ++--- packages/syft/src/syft/node/routes.py | 14 +- packages/syft/src/syft/node/server.py | 10 +- .../src/syft/service/action/action_object.py | 68 +++--- .../src/syft/service/action/action_types.py | 10 +- .../syft/service/network/network_service.py | 4 +- .../src/syft/service/network/node_peer.py | 11 +- .../syft/src/syft/service/network/utils.py | 12 +- .../syft/service/notifier/notifier_service.py | 20 +- packages/syft/src/syft/service/queue/queue.py | 7 +- .../syft/src/syft/service/queue/zmq_queue.py | 209 +++++++++--------- .../syft/src/syft/service/request/request.py | 9 +- packages/syft/src/syft/service/service.py | 5 +- .../src/syft/service/settings/settings.py | 7 +- .../syft/src/syft/service/sync/diff_state.py | 11 +- .../src/syft/service/sync/sync_service.py | 4 +- .../syft/src/syft/service/worker/utils.py | 20 +- .../service/worker/worker_pool_service.py | 5 +- .../src/syft/store/blob_storage/__init__.py | 7 +- .../src/syft/store/blob_storage/seaweedfs.py | 5 +- .../syft/src/syft/store/document_store.py | 1 - packages/syft/src/syft/store/locks.py | 4 +- .../src/syft/store/sqlite_document_store.py | 5 +- packages/syft/src/syft/types/grid_url.py | 5 +- packages/syft/src/syft/types/syft_object.py | 8 +- packages/syft/src/syft/types/uid.py | 10 +- packages/syft/src/syft/util/logger.py | 134 ----------- .../components/tabulator_template.py | 6 +- packages/syft/src/syft/util/table.py | 6 +- packages/syft/src/syft/util/telemetry.py | 8 +- packages/syft/src/syft/util/util.py | 35 ++- ruff.toml | 1 + 45 files changed, 464 insertions(+), 677 deletions(-) delete mode 100644 packages/grid/backend/grid/logger/__init__.py delete mode 100644 packages/grid/backend/grid/logger/config.py delete mode 100644 packages/grid/backend/grid/logger/handler.py create mode 100644 packages/grid/backend/grid/logging.yaml delete mode 100644 packages/syft/src/syft/util/logger.py diff --git a/packages/grid/backend/grid/logger/__init__.py b/packages/grid/backend/grid/logger/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/packages/grid/backend/grid/logger/config.py b/packages/grid/backend/grid/logger/config.py deleted file mode 100644 index 000a9c9c713..00000000000 --- a/packages/grid/backend/grid/logger/config.py +++ /dev/null @@ -1,59 +0,0 @@ -"""This file defines the configuration for `loguru` which is used as the python logging client. 
-For more information refer to `loguru` documentation: https://loguru.readthedocs.io/en/stable/overview.html -""" - -# stdlib -from datetime import time -from datetime import timedelta -from enum import Enum -from functools import lru_cache - -# third party -from pydantic_settings import BaseSettings - - -# LOGURU_LEVEL type for version>3.8 -class LogLevel(Enum): - """Types of logging levels.""" - - TRACE = "TRACE" - DEBUG = "DEBUG" - INFO = "INFO" - SUCCESS = "SUCCESS" - WARNING = "WARNING" - ERROR = "ERROR" - CRITICAL = "CRITICAL" - - -class LogConfig(BaseSettings): - """Configuration for the logging client.""" - - # Logging format - LOGURU_FORMAT: str = ( - "{time:YYYY-MM-DD HH:mm:ss} | " - "{level: <8} | " - "{name}:{function}:{line}: " - "{message}" - ) - - LOGURU_LEVEL: str = LogLevel.INFO.value - LOGURU_SINK: str | None = "/var/log/pygrid/grid.log" - LOGURU_COMPRESSION: str | None = None - LOGURU_ROTATION: str | int | time | timedelta | None = None - LOGURU_RETENTION: str | int | timedelta | None = None - LOGURU_COLORIZE: bool | None = True - LOGURU_SERIALIZE: bool | None = False - LOGURU_BACKTRACE: bool | None = True - LOGURU_DIAGNOSE: bool | None = False - LOGURU_ENQUEUE: bool | None = True - LOGURU_AUTOINIT: bool | None = False - - -@lru_cache -def get_log_config() -> LogConfig: - """Returns the configuration for the logging client. - - Returns: - LogConfig: configuration for the logging client. - """ - return LogConfig() diff --git a/packages/grid/backend/grid/logger/handler.py b/packages/grid/backend/grid/logger/handler.py deleted file mode 100644 index 7f198bbcece..00000000000 --- a/packages/grid/backend/grid/logger/handler.py +++ /dev/null @@ -1,108 +0,0 @@ -# future -from __future__ import annotations - -# stdlib -from functools import lru_cache -import logging -from pprint import pformat -import sys - -# third party -import loguru -from loguru import logger - -# relative -from .config import get_log_config - - -class LogHandler: - def __init__(self) -> None: - self.config = get_log_config() - - def format_record(self, record: loguru.Record) -> str: - """ - Custom loguru log message format for handling JSON (in record['extra']) - """ - format_string: str = self.config.LOGURU_FORMAT - - if record["extra"] is not None: - for key in record["extra"].keys(): - record["extra"][key] = pformat( - record["extra"][key], indent=2, compact=False, width=88 - ) - format_string += "\n{extra[" + key + "]}" - - format_string += "{exception}\n" - - return format_string - - def init_logger(self) -> None: - """ - Redirects all registered std logging handlers to a loguru sink. - Call init_logger() on fastapi startup. 
- """ - intercept_handler = InterceptHandler() - - # Generalizes log level for all root loggers, including third party - logging.root.setLevel(self.config.LOGURU_LEVEL) - logging.root.handlers = [intercept_handler] - - for log in logging.root.manager.loggerDict.keys(): - log_instance = logging.getLogger(log) - log_instance.handlers = [] - log_instance.propagate = True - - logger.configure( - handlers=[ - { - "sink": sys.stdout, - "level": self.config.LOGURU_LEVEL, - "serialize": self.config.LOGURU_SERIALIZE, - "format": self.format_record, - } - ], - ) - - try: - if ( - self.config.LOGURU_SINK is not ("sys.stdout" or "sys.stderr") - and self.config.LOGURU_SINK is not None - ): - logger.add( - self.config.LOGURU_SINK, - retention=self.config.LOGURU_RETENTION, - rotation=self.config.LOGURU_ROTATION, - compression=self.config.LOGURU_COMPRESSION, - ) - logger.debug(f"Logging to {self.config.LOGURU_SINK}") - - except Exception as err: - logger.debug( - f"Failed creating a new sink. Check your log config. error: {err}" - ) - - -class InterceptHandler(logging.Handler): - """ - Check https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging - """ - - def emit(self, record: logging.LogRecord) -> None: - try: - level = logger.level(record.levelname).name - except ValueError: - level = record.levelno - - frame, depth = logging.currentframe(), 2 - while frame.f_code.co_filename == logging.__file__: - frame = frame.f_back # type: ignore - depth += 1 - - logger.opt(depth=depth, exception=record.exc_info).log( - level, record.getMessage() - ) - - -@lru_cache -def get_log_handler() -> LogHandler: - return LogHandler() diff --git a/packages/grid/backend/grid/logging.yaml b/packages/grid/backend/grid/logging.yaml new file mode 100644 index 00000000000..b41eb783038 --- /dev/null +++ b/packages/grid/backend/grid/logging.yaml @@ -0,0 +1,46 @@ +version: 1 +disable_existing_loggers: True +formatters: + default: + format: "%(asctime)s - %(levelname)s - %(name)s - %(message)s" + datefmt: "%Y-%m-%d %H:%M:%S" + uvicorn.default: + "()": uvicorn.logging.DefaultFormatter + format: "%(asctime)s - %(levelname)s - %(name)s - %(message)s" + uvicorn.access: + "()": "uvicorn.logging.AccessFormatter" + format: "%(asctime)s - %(levelname)s - %(name)s - %(message)s" + datefmt: "%Y-%m-%d %H:%M:%S" +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stdout + uvicorn.default: + formatter: uvicorn.default + class: logging.StreamHandler + stream: ext://sys.stdout + uvicorn.access: + formatter: uvicorn.access + class: logging.StreamHandler + stream: ext://sys.stdout +loggers: + uvicorn.error: + level: INFO + handlers: + - uvicorn.default + propagate: no + uvicorn.access: + level: INFO + handlers: + - uvicorn.access + propagate: no + syft: + level: INFO + handlers: + - default + propagate: no +root: + level: INFO + handlers: + - default diff --git a/packages/grid/backend/grid/main.py b/packages/grid/backend/grid/main.py index 9ca43dadee8..459448c5f01 100644 --- a/packages/grid/backend/grid/main.py +++ b/packages/grid/backend/grid/main.py @@ -1,7 +1,6 @@ -# stdlib - # stdlib from contextlib import asynccontextmanager +import logging from typing import Any # third party @@ -16,7 +15,15 @@ from grid.api.router import api_router from grid.core.config import settings from grid.core.node import worker -from grid.logger.handler import get_log_handler + + +class EndpointFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return 
record.getMessage().find("/api/v2/?probe=livenessProbe") == -1 + + +logger = logging.getLogger("uvicorn.error") +logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) @asynccontextmanager @@ -25,7 +32,7 @@ async def lifespan(app: FastAPI) -> Any: yield finally: worker.stop() - print("Worker Stop !!!") + logger.info("Worker Stop") app = FastAPI( @@ -34,7 +41,6 @@ async def lifespan(app: FastAPI) -> Any: lifespan=lifespan, ) -app.add_event_handler("startup", get_log_handler().init_logger) # Set all CORS enabled origins if settings.BACKEND_CORS_ORIGINS: @@ -47,13 +53,13 @@ async def lifespan(app: FastAPI) -> Any: ) app.include_router(api_router, prefix=settings.API_V2_STR) -print("Included routes, app should now be reachable") +logger.info("Included routes, app should now be reachable") if settings.DEV_MODE: - print("Staging protocol changes...") + logger.info("Staging protocol changes...") status = stage_protocol_changes() - print(status) + logger.info(f"Staging protocol result: {status}") # needed for Google Kubernetes Engine LoadBalancer Healthcheck diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index bcb36c5e5a9..4b3d5de4cf2 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -33,4 +33,4 @@ export NODE_TYPE=$NODE_TYPE echo "NODE_UID=$NODE_UID" echo "NODE_TYPE=$NODE_TYPE" -exec $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" +exec $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-config=$APPDIR/grid/logging.yaml --log-level $LOG_LEVEL "$APP_MODULE" diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 59bfee973ea..944e4986751 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -30,7 +30,6 @@ syft = bcrypt==4.1.2 boto3==1.34.56 forbiddenfruit==0.1.4 - loguru==0.7.2 packaging>=23.0 pyarrow==15.0.0 pycapnp==2.0.0 diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index d7183898935..7aa5789aa30 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -9,79 +9,78 @@ from typing import Any # relative -from .abstract_node import NodeSideType # noqa: F401 -from .abstract_node import NodeType # noqa: F401 -from .client.client import connect # noqa: F401 -from .client.client import login # noqa: F401 -from .client.client import login_as_guest # noqa: F401 -from .client.client import register # noqa: F401 -from .client.domain_client import DomainClient # noqa: F401 -from .client.gateway_client import GatewayClient # noqa: F401 -from .client.registry import DomainRegistry # noqa: F401 -from .client.registry import EnclaveRegistry # noqa: F401 -from .client.registry import NetworkRegistry # noqa: F401 -from .client.search import Search # noqa: F401 -from .client.search import SearchResults # noqa: F401 -from .client.user_settings import UserSettings # noqa: F401 -from .client.user_settings import settings # noqa: F401 -from .custom_worker.config import DockerWorkerConfig # noqa: F401 -from .custom_worker.config import PrebuiltWorkerConfig # noqa: F401 -from .node.credentials import SyftSigningKey # noqa: F401 -from .node.domain import Domain # noqa: F401 -from .node.enclave import Enclave # noqa: F401 -from .node.gateway import Gateway # noqa: F401 -from .node.server import serve_node # noqa: F401 -from .node.server import serve_node as bind_worker # noqa: F401 -from .node.worker import Worker # noqa: F401 -from .orchestra import Orchestra as orchestra # noqa: 
F401 -from .protocol.data_protocol import bump_protocol_version # noqa: F401 -from .protocol.data_protocol import check_or_stage_protocol # noqa: F401 -from .protocol.data_protocol import get_data_protocol # noqa: F401 -from .protocol.data_protocol import stage_protocol_changes # noqa: F401 -from .serde import NOTHING # noqa: F401 -from .serde.deserialize import _deserialize as deserialize # noqa: F401 -from .serde.serializable import serializable # noqa: F401 -from .serde.serialize import _serialize as serialize # noqa: F401 -from .service.action.action_data_empty import ActionDataEmpty # noqa: F401 -from .service.action.action_object import ActionObject # noqa: F401 -from .service.action.plan import Plan # noqa: F401 -from .service.action.plan import planify # noqa: F401 -from .service.api.api import api_endpoint # noqa: F401 -from .service.api.api import api_endpoint_method # noqa: F401 -from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint # noqa: F401 -from .service.code.user_code import UserCodeStatus # noqa: F401; noqa: F401 -from .service.code.user_code import syft_function # noqa: F401; noqa: F401 -from .service.code.user_code import syft_function_single_use # noqa: F401; noqa: F401 -from .service.data_subject import DataSubjectCreate as DataSubject # noqa: F401 -from .service.dataset.dataset import Contributor # noqa: F401 -from .service.dataset.dataset import CreateAsset as Asset # noqa: F401 -from .service.dataset.dataset import CreateDataset as Dataset # noqa: F401 -from .service.notification.notifications import NotificationStatus # noqa: F401 -from .service.policy.policy import CustomInputPolicy # noqa: F401 -from .service.policy.policy import CustomOutputPolicy # noqa: F401 -from .service.policy.policy import ExactMatch # noqa: F401 -from .service.policy.policy import SingleExecutionExactOutput # noqa: F401 -from .service.policy.policy import UserInputPolicy # noqa: F401 -from .service.policy.policy import UserOutputPolicy # noqa: F401 -from .service.project.project import ProjectSubmit as Project # noqa: F401 -from .service.request.request import SubmitRequest as Request # noqa: F401 -from .service.response import SyftError # noqa: F401 -from .service.response import SyftNotReady # noqa: F401 -from .service.response import SyftSuccess # noqa: F401 -from .service.user.roles import Roles as roles # noqa: F401 -from .service.user.user_service import UserService # noqa: F401 +from .abstract_node import NodeSideType +from .abstract_node import NodeType +from .client.client import connect +from .client.client import login +from .client.client import login_as_guest +from .client.client import register +from .client.domain_client import DomainClient +from .client.gateway_client import GatewayClient +from .client.registry import DomainRegistry +from .client.registry import EnclaveRegistry +from .client.registry import NetworkRegistry +from .client.search import Search +from .client.search import SearchResults +from .client.user_settings import UserSettings +from .client.user_settings import settings +from .custom_worker.config import DockerWorkerConfig +from .custom_worker.config import PrebuiltWorkerConfig +from .node.credentials import SyftSigningKey +from .node.domain import Domain +from .node.enclave import Enclave +from .node.gateway import Gateway +from .node.server import serve_node +from .node.server import serve_node as bind_worker +from .node.worker import Worker +from .orchestra import Orchestra as orchestra +from .protocol.data_protocol import 
bump_protocol_version +from .protocol.data_protocol import check_or_stage_protocol +from .protocol.data_protocol import get_data_protocol +from .protocol.data_protocol import stage_protocol_changes +from .serde import NOTHING +from .serde.deserialize import _deserialize as deserialize +from .serde.serializable import serializable +from .serde.serialize import _serialize as serialize +from .service.action.action_data_empty import ActionDataEmpty +from .service.action.action_object import ActionObject +from .service.action.plan import Plan +from .service.action.plan import planify +from .service.api.api import api_endpoint +from .service.api.api import api_endpoint_method +from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint +from .service.code.user_code import UserCodeStatus +from .service.code.user_code import syft_function +from .service.code.user_code import syft_function_single_use +from .service.data_subject import DataSubjectCreate as DataSubject +from .service.dataset.dataset import Contributor +from .service.dataset.dataset import CreateAsset as Asset +from .service.dataset.dataset import CreateDataset as Dataset +from .service.notification.notifications import NotificationStatus +from .service.policy.policy import CustomInputPolicy +from .service.policy.policy import CustomOutputPolicy +from .service.policy.policy import ExactMatch +from .service.policy.policy import SingleExecutionExactOutput +from .service.policy.policy import UserInputPolicy +from .service.policy.policy import UserOutputPolicy +from .service.project.project import ProjectSubmit as Project +from .service.request.request import SubmitRequest as Request +from .service.response import SyftError +from .service.response import SyftNotReady +from .service.response import SyftSuccess +from .service.user.roles import Roles as roles +from .service.user.user_service import UserService from .stable_version import LATEST_STABLE_SYFT from .types.syft_object import SyftObject -from .types.twin_object import TwinObject # noqa: F401 -from .types.uid import UID # noqa: F401 -from .util import filterwarnings # noqa: F401 -from .util import logger # noqa: F401 -from .util import options # noqa: F401 -from .util.autoreload import disable_autoreload # noqa: F401 -from .util.autoreload import enable_autoreload # noqa: F401 -from .util.telemetry import instrument # noqa: F401 -from .util.util import autocache # noqa: F401 -from .util.util import get_root_data_path # noqa: F401 +from .types.twin_object import TwinObject +from .types.uid import UID +from .util import filterwarnings +from .util import options +from .util.autoreload import disable_autoreload +from .util.autoreload import enable_autoreload +from .util.telemetry import instrument +from .util.util import autocache +from .util.util import get_root_data_path from .util.version_compare import make_requires requires = make_requires(LATEST_STABLE_SYFT, __version__) @@ -92,7 +91,6 @@ sys.path.append(str(Path(__file__))) -logger.start() try: # third party diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index c4a3a1b40a9..0e61053a768 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -226,6 +226,9 @@ def sign(self, credentials: SyftSigningKey) -> SignedSyftAPICall: signature=signed_message.signature, ) + def __repr__(self) -> str: + return f"SyftAPICall(path={self.path}, args={self.args}, kwargs={self.kwargs}, blocking={self.blocking})" + @instrument @serializable() @@ -1266,7 
+1269,6 @@ def monkey_patch_getdef(self: Any, obj: Any, oname: str = "") -> str | None: Inspector._getdef_bak = Inspector._getdef Inspector._getdef = types.MethodType(monkey_patch_getdef, Inspector) except Exception: - # print("Failed to monkeypatch IPython Signature Override") pass # nosec diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 2cec54d0e52..79f8dd21198 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -7,6 +7,7 @@ from enum import Enum from getpass import getpass import json +import logging from typing import Any from typing import TYPE_CHECKING from typing import cast @@ -48,7 +49,6 @@ from ..types.grid_url import GridURL from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.uid import UID -from ..util.logger import debug from ..util.telemetry import instrument from ..util.util import prompt_warning_message from ..util.util import thread_ident @@ -62,6 +62,8 @@ from .connection import NodeConnection from .protocol import SyftProtocol +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from ..service.network.node_peer import NodePeer @@ -77,7 +79,7 @@ def upgrade_tls(url: GridURL, response: Response) -> GridURL: if response.url.startswith("https://") and url.protocol == "http": # we got redirected to https https_url = GridURL.from_url(response.url).with_path("") - debug(f"GridURL Upgraded to HTTPS. {https_url}") + logger.debug(f"GridURL Upgraded to HTTPS. {https_url}") return https_url except Exception as e: print(f"Failed to upgrade to HTTPS. {e}") diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 5cdaa88906d..b54ebf0bcf0 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -2,14 +2,15 @@ from __future__ import annotations # stdlib +import logging from pathlib import Path import re from string import Template +import traceback from typing import TYPE_CHECKING from typing import cast # third party -from loguru import logger import markdown from result import Result from tqdm import tqdm @@ -41,6 +42,8 @@ from .connection import NodeConnection from .protocol import SyftProtocol +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from ..orchestra import NodeHandle @@ -271,8 +274,9 @@ def upload_files( return ActionObject.from_obj(result).send(self) except Exception as err: - logger.debug("upload_files: Error creating action_object: {}", err) - return SyftError(message=f"Failed to upload files: {err}") + return SyftError( + message=f"Failed to upload files: {err}.\n{traceback.format_exc()}" + ) def connect_to_gateway( self, diff --git a/packages/syft/src/syft/client/registry.py b/packages/syft/src/syft/client/registry.py index 4128af452d8..4f239e2265d 100644 --- a/packages/syft/src/syft/client/registry.py +++ b/packages/syft/src/syft/client/registry.py @@ -4,6 +4,7 @@ # stdlib from concurrent import futures import json +import logging import os from typing import Any @@ -18,10 +19,9 @@ from ..service.response import SyftException from ..types.grid_url import GridURL from ..util.constants import DEFAULT_TIMEOUT -from ..util.logger import error -from ..util.logger import warning from .client import SyftClient as Client +logger = logging.getLogger(__name__) NETWORK_REGISTRY_URL = ( "https://raw.githubusercontent.com/OpenMined/NetworkRegistry/main/gateways.json" ) @@ -43,7 +43,7 @@ def __init__(self) -> None: 
network_json=network_json, version="2.0.0" ) except Exception as e: - warning( + logger.warning( f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. Exception: {e}" ) @@ -64,7 +64,7 @@ def load_network_registry_json() -> dict: return network_json except Exception as e: - warning( + logger.warning( f"Failed to get Network Registry from {NETWORK_REGISTRY_REPO}. Exception: {e}" ) return {} @@ -169,7 +169,6 @@ def create_client(network: dict[str, Any]) -> Client: client = connect(url=str(grid_url)) return client.guest() except Exception as e: - error(f"Failed to login with: {network}. {e}") raise SyftException(f"Failed to login with: {network}. {e}") def __getitem__(self, key: str | int) -> Client: @@ -194,7 +193,7 @@ def __init__(self) -> None: ) self._get_all_domains() except Exception as e: - warning( + logger.warning( f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. {e}" ) @@ -263,7 +262,7 @@ def online_domains(self) -> list[tuple[NodePeer, NodeMetadataJSON | None]]: try: network_client = NetworkRegistry.create_client(network) except Exception as e: - print(f"Error in creating network client with exception {e}") + logger.error(f"Error in creating network client {e}") continue domains: list[NodePeer] = network_client.domains.retrieve_nodes() @@ -334,7 +333,6 @@ def create_client(self, peer: NodePeer) -> Client: try: return peer.guest_client except Exception as e: - error(f"Failed to login to: {peer}. {e}") raise SyftException(f"Failed to login to: {peer}. {e}") def __getitem__(self, key: str | int) -> Client: @@ -364,7 +362,7 @@ def __init__(self) -> None: enclaves_json = response.json() self.all_enclaves = enclaves_json["2.0.0"]["enclaves"] except Exception as e: - warning( + logger.warning( f"Failed to get Enclave Registry, go checkout: {ENCLAVE_REGISTRY_REPO}. {e}" ) @@ -433,7 +431,6 @@ def create_client(enclave: dict[str, Any]) -> Client: client = connect(url=str(grid_url)) return client.guest() except Exception as e: - error(f"Failed to login with: {enclave}. {e}") raise SyftException(f"Failed to login with: {enclave}. 
{e}") def __getitem__(self, key: str | int) -> Client: diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index c0b4dd8196e..2a17dbc2d55 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -2,6 +2,7 @@ # stdlib from collections.abc import Collection +import logging # relative from ..abstract_node import NodeSideType @@ -21,6 +22,8 @@ from .sync_decision import SyncDecision from .sync_decision import SyncDirection +logger = logging.getLogger(__name__) + def compare_states( from_state: SyncState, @@ -174,7 +177,7 @@ def handle_sync_batch( ) sync_instructions.append(instruction) - print(f"Decision: Syncing {len(sync_instructions)} objects") + logger.debug(f"Decision: Syncing {len(sync_instructions)} objects") # Apply empty state to source side to signal that we are done syncing res_src = src_client.apply_state(src_resolved_state) @@ -206,7 +209,7 @@ def handle_ignore_batch( for other_batch in other_ignore_batches: other_batch.decision = SyncDecision.IGNORE - print(f"Ignoring other batch with root {other_batch.root_type.__name__}") + logger.debug(f"Ignoring other batch with root {other_batch.root_type.__name__}") src_client = obj_diff_batch.source_client tgt_client = obj_diff_batch.target_client @@ -240,7 +243,7 @@ def handle_unignore_batch( other_batches = [b for b in all_batches if b is not obj_diff_batch] other_unignore_batches = get_other_unignore_batches(obj_diff_batch, other_batches) for other_batch in other_unignore_batches: - print(f"Ignoring other batch with root {other_batch.root_type.__name__}") + logger.debug(f"Ignoring other batch with root {other_batch.root_type.__name__}") other_batch.decision = None src_resolved_state.add_unignored(other_batch.root_id) tgt_resolved_state.add_unignored(other_batch.root_id) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 31ead514b2b..58a3798c5fd 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -8,6 +8,7 @@ from functools import partial import hashlib import json +import logging import os from pathlib import Path import shutil @@ -19,7 +20,6 @@ from typing import Any # third party -from loguru import logger from nacl.signing import SigningKey from result import Err from result import Result @@ -140,6 +140,8 @@ from .credentials import SyftVerifyKey from .worker_settings import WorkerSettings +logger = logging.getLogger(__name__) + # if user code needs to be serded and its not available we can call this to refresh # the code for a specific node UID and thread CODE_RELOADER: dict[int, Callable] = {} @@ -464,7 +466,7 @@ def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig: path = self.get_temp_dir("db") file_name: str = f"{self.id}.sqlite" if self.dev_mode: - print(f"{store_type}'s SQLite DB path: {path/file_name}") + logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") return SQLiteStoreConfig( client_config=SQLiteStoreClientConfig( filename=file_name, @@ -535,7 +537,7 @@ def create_queue_config( queue_config_ = queue_config elif queue_port is not None or n_consumers > 0 or create_producer: if not create_producer and queue_port is None: - print("No queue port defined to bind consumers.") + logger.warn("No queue port defined to bind consumers.") queue_config_ = ZMQQueueConfig( client_config=ZMQClientConfig( create_producer=create_producer, @@ -590,7 +592,7 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None: else: 
# Create consumer for given worker pool syft_worker_uid = get_syft_worker_uid() - print( + logger.info( f"Running as consumer with uid={syft_worker_uid} service={service_name}" ) @@ -750,9 +752,8 @@ def find_and_migrate_data(self) -> None: ) if object_pending_migration: - print( - "Object in Document Store that needs migration: ", - object_pending_migration, + logger.debug( + f"Object in Document Store that needs migration: {object_pending_migration}" ) # Migrate data for objects in document store @@ -762,7 +763,7 @@ def find_and_migrate_data(self) -> None: if object_partition is None: continue - print(f"Migrating data for: {canonical_name} table.") + logger.debug(f"Migrating data for: {canonical_name} table.") migration_status = object_partition.migrate_data( to_klass=object_type, context=context ) @@ -779,9 +780,8 @@ def find_and_migrate_data(self) -> None: ) if action_object_pending_migration: - print( - "Object in Action Store that needs migration: ", - action_object_pending_migration, + logger.info( + f"Object in Action Store that needs migration: {action_object_pending_migration}", ) # Migrate data for objects in action store @@ -795,7 +795,7 @@ def find_and_migrate_data(self) -> None: raise Exception( f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" ) - print("Data Migrated to latest version !!!") + logger.info("Data Migrated to latest version !!!") @property def guest_client(self) -> SyftClient: @@ -817,7 +817,7 @@ def get_guest_client(self, verbose: bool = True) -> SyftClient: ) if self.node_type: message += f"side {self.node_type.value.capitalize()} > as GUEST" - print(message) + logger.debug(message) client_type = connection.get_client_type() if isinstance(client_type, SyftError): @@ -1265,6 +1265,7 @@ def handle_api_call_with_unsigned_result( _private_api_path = user_config_registry.private_path_for(api_call.path) method = self.get_service_method(_private_api_path) try: + logger.info(f"API Call: {api_call}") result = method(context, *api_call.args, **api_call.kwargs) except PySyftException as e: return e.handle() @@ -1604,7 +1605,9 @@ def create_initial_settings(self, admin_email: str) -> NodeSettings | None: try: settings_stash = SettingsStash(store=self.document_store) if self.signing_key is None: - print("create_initial_settings failed as there is no signing key") + logger.debug( + "create_initial_settings failed as there is no signing key" + ) return None settings_exists = settings_stash.get_all(self.signing_key.verify_key).ok() if settings_exists: @@ -1639,7 +1642,7 @@ def create_initial_settings(self, admin_email: str) -> NodeSettings | None: return result.ok() return None except Exception as e: - print(f"create_initial_settings failed with error {e}") + logger.error("create_initial_settings failed", exc_info=e) return None @@ -1679,7 +1682,7 @@ def create_admin_new( else: raise Exception(f"Could not create user: {result}") except Exception as e: - print("Unable to create new admin", e) + logger.error("Unable to create new admin", exc_info=e) return None @@ -1739,11 +1742,12 @@ def create_default_worker_pool(node: Node) -> SyftError | None: if isinstance(default_worker_pool, SyftError): logger.error( - f"Failed to get default worker pool {default_pool_name}. Error: {default_worker_pool.message}" + f"Failed to get default worker pool {default_pool_name}. 
" + f"Error: {default_worker_pool.message}" ) return default_worker_pool - print(f"Creating default worker image with tag='{default_worker_tag}'") + logger.info(f"Creating default worker image with tag='{default_worker_tag}'") # Get/Create a default worker SyftWorkerImage default_image = create_default_image( credentials=credentials, @@ -1752,11 +1756,11 @@ def create_default_worker_pool(node: Node) -> SyftError | None: in_kubernetes=in_kubernetes(), ) if isinstance(default_image, SyftError): - print("Failed to create default worker image: ", default_image.message) + logger.error(f"Failed to create default worker image: {default_image.message}") return default_image if not default_image.is_built: - print(f"Building default worker image with tag={default_worker_tag}") + logger.info(f"Building default worker image with tag={default_worker_tag}") image_build_method = node.get_service_method(SyftWorkerImageService.build) # Build the Image for given tag result = image_build_method( @@ -1767,11 +1771,11 @@ def create_default_worker_pool(node: Node) -> SyftError | None: ) if isinstance(result, SyftError): - print("Failed to build default worker image: ", result.message) + logger.error(f"Failed to build default worker image: {result.message}") return None # Create worker pool if it doesn't exists - print( + logger.info( "Setting up worker pool" f"name={default_pool_name} " f"workers={worker_count} " @@ -1802,17 +1806,17 @@ def create_default_worker_pool(node: Node) -> SyftError | None: ) if isinstance(result, SyftError): - print(f"Default worker pool error. {result.message}") + logger.info(f"Default worker pool error. {result.message}") return None for n in range(worker_to_add_): container_status = result[n] if container_status.error: - print( + logger.error( f"Failed to create container: Worker: {container_status.worker}," f"Error: {container_status.error}" ) return None - print("Created default worker pool.") + logger.info("Created default worker pool.") return None diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 5b25774ff18..8be45245190 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -1,6 +1,7 @@ # stdlib import base64 import binascii +import logging from typing import Annotated # third party @@ -12,7 +13,6 @@ from fastapi import Response from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse -from loguru import logger from pydantic import ValidationError import requests @@ -34,6 +34,8 @@ from .credentials import UserLoginCredentials from .worker import Worker +logger = logging.getLogger(__name__) + def make_routes(worker: Worker) -> APIRouter: if TRACE_MODE: @@ -42,8 +44,8 @@ def make_routes(worker: Worker) -> APIRouter: # third party from opentelemetry import trace from opentelemetry.propagate import extract - except Exception: - print("Failed to import opentelemetry") + except Exception as e: + logger.error("Failed to import opentelemetry", exc_info=e) router = APIRouter() @@ -171,7 +173,7 @@ def handle_login(email: str, password: str, node: AbstractNode) -> Response: result = method(context=context) if isinstance(result, SyftError): - logger.bind(payload={"email": email}).error(result.message) + logger.error(f"Login Error: {result.message}. 
user={email}") response = result else: user_private_key = result @@ -196,7 +198,9 @@ def handle_register(data: bytes, node: AbstractNode) -> Response: result = method(new_user=user_create) if isinstance(result, SyftError): - logger.bind(payload={"user": user_create}).error(result.message) + logger.error( + f"Register Error: {result.message}. user={user_create.model_dump()}" + ) response = SyftError(message=f"{result.message}") else: response = result diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 3da97e4b0a2..43b8359a1f9 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -2,7 +2,6 @@ import asyncio from collections.abc import Callable from enum import Enum -import logging import multiprocessing import os import platform @@ -144,14 +143,7 @@ async def _run_uvicorn( except Exception: # nosec print(f"Failed to kill python process on port: {port}") - log_level = "critical" - if dev_mode: - log_level = "info" - logging.getLogger("uvicorn").setLevel(logging.CRITICAL) - logging.getLogger("uvicorn.access").setLevel(logging.CRITICAL) - config = uvicorn.Config( - app, host=host, port=port, log_level=log_level, reload=dev_mode - ) + config = uvicorn.Config(app, host=host, port=port, reload=dev_mode) server = uvicorn.Server(config) await server.serve() diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index dc9eb40e81e..6e6043f4e3e 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -7,6 +7,7 @@ from enum import Enum import inspect from io import BytesIO +import logging from pathlib import Path import sys import threading @@ -46,7 +47,6 @@ from ...types.syncable_object import SyncableSyftObject from ...types.uid import LineageID from ...types.uid import UID -from ...util.logger import debug from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..response import SyftException @@ -59,6 +59,8 @@ from .action_types import action_type_for_type from .action_types import action_types +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from ..sync.diff_state import AttrDiff @@ -443,9 +445,10 @@ def make_action_side_effect( action_type=context.action_type, ) context.action = action - except Exception: - print(f"make_action_side_effect failed with {traceback.format_exc()}") - return Err(f"make_action_side_effect failed with {traceback.format_exc()}") + except Exception as e: + msg = "make_action_side_effect failed" + logger.error(msg, exc_info=e) + return Err(f"{msg} with {traceback.format_exc()}") return Ok((context, args, kwargs)) @@ -521,7 +524,7 @@ def convert_to_pointers( arg.syft_node_uid = node_uid r = arg._save_to_blob_storage() if isinstance(r, SyftError): - print(r.message) + logger.error(r.message) arg = api.services.action.set(arg) arg_list.append(arg) @@ -539,7 +542,7 @@ def convert_to_pointers( arg.syft_node_uid = node_uid r = arg._save_to_blob_storage() if isinstance(r, SyftError): - print(r.message) + logger.error(r.message) arg = api.services.action.set(arg) kwarg_dict[k] = arg @@ -772,9 +775,8 @@ def reload_cache(self) -> SyftError | None: uid=self.syft_blob_storage_entry_id ) if isinstance(blob_retrieval_object, SyftError): - print( - "Could not fetch actionobject data\n", - blob_retrieval_object, + logger.error( + f"Could not fetch actionobject data: {blob_retrieval_object}" ) return 
blob_retrieval_object # relative @@ -839,13 +841,15 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: blob_deposit_object.blob_storage_entry_id ) else: - print("cannot save to blob storage") + logger.warn("cannot save to blob storage. allocate_method=None") self.syft_action_data_type = type(data) self._set_reprs(data) self.syft_has_bool_attr = hasattr(data, "__bool__") else: - debug("skipping writing action object to store, passed data was empty.") + logger.debug( + "skipping writing action object to store, passed data was empty." + ) self.syft_action_data_cache = data @@ -1575,7 +1579,7 @@ def _syft_run_pre_hooks__( if result.is_ok(): context, result_args, result_kwargs = result.ok() else: - debug(f"Pre-hook failed with {result.err()}") + logger.debug(f"Pre-hook failed with {result.err()}") if name not in self._syft_dont_wrap_attrs(): if HOOK_ALWAYS in self.syft_pre_hooks__: for hook in self.syft_pre_hooks__[HOOK_ALWAYS]: @@ -1584,7 +1588,7 @@ def _syft_run_pre_hooks__( context, result_args, result_kwargs = result.ok() else: msg = result.err().replace("\\n", "\n") - debug(f"Pre-hook failed with {msg}") + logger.debug(f"Pre-hook failed with {msg}") if self.is_pointer: if name not in self._syft_dont_wrap_attrs(): @@ -1595,7 +1599,7 @@ def _syft_run_pre_hooks__( context, result_args, result_kwargs = result.ok() else: msg = result.err().replace("\\n", "\n") - debug(f"Pre-hook failed with {msg}") + logger.debug(f"Pre-hook failed with {msg}") return context, result_args, result_kwargs @@ -1610,7 +1614,7 @@ def _syft_run_post_hooks__( if result.is_ok(): new_result = result.ok() else: - debug(f"Post hook failed with {result.err()}") + logger.debug(f"Post hook failed with {result.err()}") if name not in self._syft_dont_wrap_attrs(): if HOOK_ALWAYS in self.syft_post_hooks__: @@ -1619,7 +1623,7 @@ def _syft_run_post_hooks__( if result.is_ok(): new_result = result.ok() else: - debug(f"Post hook failed with {result.err()}") + logger.debug(f"Post hook failed with {result.err()}") if self.is_pointer: if name not in self._syft_dont_wrap_attrs(): @@ -1629,7 +1633,7 @@ def _syft_run_post_hooks__( if result.is_ok(): new_result = result.ok() else: - debug(f"Post hook failed with {result.err()}") + logger.debug(f"Post hook failed with {result.err()}") return new_result @@ -1721,7 +1725,7 @@ def _syft_wrap_attribute_for_bool_on_nonbools(self, name: str) -> Any: "[_wrap_attribute_for_bool_on_nonbools] self.syft_action_data already implements the bool operator" ) - debug("[__getattribute__] Handling bool on nonbools") + logger.debug("[__getattribute__] Handling bool on nonbools") context = PreHookContext( obj=self, op_name=name, @@ -1754,7 +1758,7 @@ def _syft_wrap_attribute_for_properties(self, name: str) -> Any: raise RuntimeError( "[_wrap_attribute_for_properties] Use this only on properties" ) - debug(f"[__getattribute__] Handling property {name} ") + logger.debug(f"[__getattribute__] Handling property {name}") context = PreHookContext( obj=self, @@ -1778,7 +1782,7 @@ def _syft_wrap_attribute_for_methods(self, name: str) -> Any: def fake_func(*args: Any, **kwargs: Any) -> Any: return ActionDataEmpty(syft_internal_type=self.syft_internal_type) - debug(f"[__getattribute__] Handling method {name} ") + logger.debug(f"[__getattribute__] Handling method {name}") if ( issubclass(self.syft_action_data_type, ActionDataEmpty) and name not in action_data_empty_must_run @@ -1815,20 +1819,20 @@ def _base_wrapper(*args: Any, **kwargs: Any) -> Any: return post_result if inspect.ismethod(original_func) or 
inspect.ismethoddescriptor(original_func): - debug("Running method: ", name) + logger.debug(f"Running method: {name}") def wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any: return _base_wrapper(*args, **kwargs) wrapper = types.MethodType(wrapper, type(self)) else: - debug("Running non-method: ", name) + logger.debug(f"Running non-method: {name}") wrapper = _base_wrapper try: wrapper.__doc__ = original_func.__doc__ - debug( + logger.debug( "Found original signature for ", name, inspect.signature(original_func), @@ -1837,7 +1841,7 @@ def wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any: original_func ) except Exception: - debug("name", name, "has no signature") + logger.debug(f"name={name} has no signature") # third party return wrapper @@ -1931,7 +1935,7 @@ def is_link(self) -> bool: def __setattr__(self, name: str, value: Any) -> Any: defined_on_self = name in self.__dict__ or name in self.__private_attributes__ - debug(">> ", name, ", defined_on_self = ", defined_on_self) + logger.debug(f">> {name} defined_on_self={defined_on_self}") # use the custom defined version if defined_on_self: @@ -2180,13 +2184,13 @@ def __int__(self) -> float: def debug_original_func(name: str, func: Callable) -> None: - debug(f"{name} func is:") - debug("inspect.isdatadescriptor", inspect.isdatadescriptor(func)) - debug("inspect.isgetsetdescriptor", inspect.isgetsetdescriptor(func)) - debug("inspect.isfunction", inspect.isfunction(func)) - debug("inspect.isbuiltin", inspect.isbuiltin(func)) - debug("inspect.ismethod", inspect.ismethod(func)) - debug("inspect.ismethoddescriptor", inspect.ismethoddescriptor(func)) + logger.debug(f"{name} func is:") + logger.debug(f"inspect.isdatadescriptor = {inspect.isdatadescriptor(func)}") + logger.debug(f"inspect.isgetsetdescriptor = {inspect.isgetsetdescriptor(func)}") + logger.debug(f"inspect.isfunction = {inspect.isfunction(func)}") + logger.debug(f"inspect.isbuiltin = {inspect.isbuiltin(func)}") + logger.debug(f"inspect.ismethod = {inspect.ismethod(func)}") + logger.debug(f"inspect.ismethoddescriptor = {inspect.ismethoddescriptor(func)}") def is_action_data_empty(obj: Any) -> bool: diff --git a/packages/syft/src/syft/service/action/action_types.py b/packages/syft/src/syft/service/action/action_types.py index 9721a48ec8e..c7bd730d557 100644 --- a/packages/syft/src/syft/service/action/action_types.py +++ b/packages/syft/src/syft/service/action/action_types.py @@ -1,10 +1,12 @@ # stdlib +import logging from typing import Any # relative -from ...util.logger import debug from .action_data_empty import ActionDataEmpty +logger = logging.getLogger(__name__) + action_types: dict = {} @@ -21,7 +23,9 @@ def action_type_for_type(obj_or_type: Any) -> type: obj_or_type = type(obj_or_type) if obj_or_type not in action_types: - debug(f"WARNING: No Type for {obj_or_type}, returning {action_types[Any]}") + logger.debug( + f"WARNING: No Type for {obj_or_type}, returning {action_types[Any]}" + ) return action_types.get(obj_or_type, action_types[Any]) @@ -36,7 +40,7 @@ def action_type_for_object(obj: Any) -> type: _type = type(obj) if _type not in action_types: - debug(f"WARNING: No Type for {_type}, returning {action_types[Any]}") + logger.debug(f"WARNING: No Type for {_type}, returning {action_types[Any]}") return action_types[Any] return action_types[_type] diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index ac329420168..c32374ade31 100644 --- 
a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -1,11 +1,11 @@ # stdlib from collections.abc import Callable from enum import Enum +import logging import secrets from typing import Any # third party -from loguru import logger from result import Result # relative @@ -56,6 +56,8 @@ from .routes import NodeRouteType from .routes import PythonNodeRoute +logger = logging.getLogger(__name__) + VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) NodeTypePartitionKey = PartitionKey(key="node_type", type_=NodeType) OrderByNamePartitionKey = PartitionKey(key="name", type_=str) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index c2db506ba23..5835cf7aa9e 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -1,6 +1,7 @@ # stdlib from collections.abc import Callable from enum import Enum +import logging # third party from result import Err @@ -35,6 +36,8 @@ from .routes import connection_to_route from .routes import route_to_connection +logger = logging.getLogger(__name__) + @serializable() class NodePeerConnectionStatus(Enum): @@ -245,7 +248,6 @@ def client_with_context( self, context: NodeServiceContext ) -> Result[type[SyftClient], str]: # third party - from loguru import logger if len(self.node_routes) < 1: raise ValueError(f"No routes to peer: {self}") @@ -255,12 +257,11 @@ def client_with_context( try: client_type = connection.get_client_type() except Exception as e: - logger.error( - f"Failed to establish a connection with {self.node_type} '{self.name}'. Exception: {e}" - ) - return Err( + msg = ( f"Failed to establish a connection with {self.node_type} '{self.name}'" ) + logger.error(msg, exc_info=e) + return Err(msg) if isinstance(client_type, SyftError): return Err(client_type.message) return Ok( diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index b03bc589d15..476411bc6e6 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -1,11 +1,9 @@ # stdlib +import logging import threading import time from typing import cast -# third party -from loguru import logger - # relative from ...serde.serializable import serializable from ...types.datetime import DateTime @@ -17,6 +15,8 @@ from .node_peer import NodePeerConnectionStatus from .node_peer import NodePeerUpdate +logger = logging.getLogger(__name__) + @serializable(without=["thread"]) class PeerHealthCheckTask: @@ -63,9 +63,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | No peer_update.ping_status = NodePeerConnectionStatus.TIMEOUT peer_client = None except Exception as e: - logger.error( - f"Failed to create client for peer: {peer} with exception {e}" - ) + logger.error(f"Failed to create client for peer: {peer}", exc_info=e) peer_update.ping_status = NodePeerConnectionStatus.TIMEOUT peer_client = None @@ -97,7 +95,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | No ) if result.is_err(): - logger.info(f"Failed to update peer in stash: {result.err()}") + logger.error(f"Failed to update peer in stash: {result.err()}") return None diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index aedb59b2e24..4c10708f0f0 100644 --- 
a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -2,6 +2,9 @@ # stdlib +# stdlib +import logging + # third party from pydantic import EmailStr from result import Err @@ -22,6 +25,8 @@ from .notifier_enums import NOTIFIERS from .notifier_stash import NotifierStash +logger = logging.getLogger(__name__) + @serializable() class NotifierService(AbstractService): @@ -109,7 +114,7 @@ def turn_on( message="You must provide both server and port to enable notifications." ) - print("[LOG] Got notifier from db") + logging.debug("Got notifier from db") # If no new credentials provided, check for existing ones if not (email_username and email_password): if not (notifier.email_username and notifier.email_password): @@ -119,10 +124,9 @@ def turn_on( + ".settings.enable_notifications(email=<>, password=<>)" ) else: - print("[LOG] No new credentials provided. Using existing ones.") + logging.debug("No new credentials provided. Using existing ones.") email_password = notifier.email_password email_username = notifier.email_username - print("[LOG] Validating credentials...") validation_result = notifier.validate_email_credentials( username=email_username, @@ -132,6 +136,7 @@ def turn_on( ) if validation_result.is_err(): + logging.error(f"Invalid SMTP credentials {validation_result.err()}") return SyftError( message="Invalid SMTP credentials. Please check your username and password." ) @@ -160,8 +165,8 @@ def turn_on( notifier.email_sender = email_sender notifier.active = True - print( - "[LOG] Email credentials are valid. Updating the notifier settings in the db." + logging.debug( + "Email credentials are valid. Updating the notifier settings in the db." ) result = self.stash.update(credentials=context.credentials, settings=notifier) @@ -260,9 +265,8 @@ def init_notifier( sender_not_set = not email_sender and not notifier.email_sender if validation_result.is_err() or sender_not_set: - print( - "Ops something went wrong while trying to setup your notification system.", - "Please check your credentials and configuration.", + logger.error( + f"Notifier validation error - {validation_result.err()}.", ) notifier.active = False else: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 968e4b7c975..8793b49ba0a 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -1,13 +1,12 @@ # stdlib +import logging from multiprocessing import Process import threading from threading import Thread import time from typing import Any -from typing import cast # third party -from loguru import logger import psutil from result import Err from result import Ok @@ -34,6 +33,8 @@ from .queue_stash import QueueItem from .queue_stash import Status +logger = logging.getLogger(__name__) + class MonitorThread(threading.Thread): def __init__( @@ -297,7 +298,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: queue_item.node_uid = worker.id job_item.status = JobStatus.PROCESSING - job_item.node_uid = cast(UID, worker.id) + job_item.node_uid = worker.id # type: ignore[assignment] job_item.updated_at = DateTime.now() if syft_worker_id is not None: diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 4559832f199..08ff386696e 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -2,6 +2,7 @@ from 
binascii import hexlify from collections import defaultdict import itertools +import logging import socketserver import sys import threading @@ -10,7 +11,6 @@ from typing import Any # third party -from loguru import logger from pydantic import field_validator import zmq from zmq import Frame @@ -61,6 +61,8 @@ # Lock for working on ZMQ socket ZMQ_SOCKET_LOCK = threading.Lock() +logger = logging.getLogger(__name__) + class QueueMsgProtocol: W_WORKER = b"MDPW01" @@ -128,6 +130,13 @@ def get_expiry(self) -> float: def reset_expiry(self) -> None: self.expiry_t.reset() + def __str__(self) -> str: + svc = self.service.name if self.service else None + return ( + f"Worker(addr={self.address!r}, id={self.identity!r}, service={svc}, " + f"syft_worker_id={self.syft_worker_id!r})" + ) + @serializable() class ZMQProducer(QueueProducer): @@ -177,7 +186,7 @@ def close(self) -> None: try: self.poll_workers.unregister(self.socket) except Exception as e: - logger.exception("Failed to unregister poller. {}", e) + logger.exception("Failed to unregister poller.", exc_info=e) finally: if self.thread: self.thread.join(THREAD_TIMEOUT_SEC) @@ -232,7 +241,7 @@ def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bo return True return value except Exception as e: - logger.exception("Failed to resolve action objects. {}", e) + logger.exception("Failed to resolve action objects.", exc_info=e) return True def unwrap_nested_actionobjects(self, data: Any) -> Any: @@ -367,9 +376,7 @@ def read_items(self) -> None: res = self.queue_stash.update(item.syft_client_verify_key, item) if res.is_err(): logger.error( - "Failed to update queue item={} error={}", - item, - res.err(), + f"Failed to update queue item={item} error={res.err()}" ) elif item.status == Status.PROCESSING: # Evaluate Retry condition here @@ -384,9 +391,7 @@ def read_items(self) -> None: res = self.queue_stash.update(item.syft_client_verify_key, item) if res.is_err(): logger.error( - "Failed to update queue item={} error={}", - item, - res.err(), + f"Failed to update queue item={item} error={res.err()}" ) def run(self) -> None: @@ -398,18 +403,18 @@ def run(self) -> None: def send(self, worker: bytes, message: bytes | list[bytes]) -> None: worker_obj = self.require_worker(worker) - self.send_to_worker(worker=worker_obj, msg=message) + self.send_to_worker(worker_obj, QueueMsgProtocol.W_REQUEST, message) def bind(self, endpoint: str) -> None: """Bind producer to endpoint.""" self.socket.bind(endpoint) - logger.info("Producer endpoint: {}", endpoint) + logger.info(f"ZMQProducer endpoint: {endpoint}") def send_heartbeats(self) -> None: """Send heartbeats to idle workers if it's time""" if self.heartbeat_t.has_expired(): for worker in self.waiting: - self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT, None, None) + self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT) self.heartbeat_t.reset() def purge_workers(self) -> None: @@ -420,22 +425,15 @@ def purge_workers(self) -> None: # work on a copy of the iterator for worker in list(self.waiting): if worker.has_expired(): - logger.info( - "Deleting expired Worker id={} uid={} expiry={} now={}", - worker.identity, - worker.syft_worker_id, - worker.get_expiry(), - Timeout.now(), - ) + logger.info(f"Deleting expired worker id={worker}") self.delete_worker(worker, False) def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState ) -> None: if self.worker_stash is None: - # TODO: fix the mypy issue logger.error( # type: ignore[unreachable] - f"Worker 
stash is not defined for ZMQProducer : {self.queue_name} - {self.id}" + f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" ) return @@ -455,14 +453,13 @@ def update_consumer_state_for_worker( ) if res.is_err(): logger.error( - "Failed to update consumer state for worker id={} to state: {} error={}", - syft_worker_id, - consumer_state, - res.err(), + f"Failed to update consumer state for worker id={syft_worker_id} " + f"to state: {consumer_state} error={res.err()}", ) except Exception as e: logger.error( - f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}. Error: {e}" + f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}", + exc_info=e, ) def worker_waiting(self, worker: Worker) -> None: @@ -487,13 +484,12 @@ def dispatch(self, service: Service, msg: bytes) -> None: msg = service.requests.pop(0) worker = service.waiting.pop(0) self.waiting.remove(worker) - self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, None, msg) + self.send_to_worker(worker, QueueMsgProtocol.W_REQUEST, msg) def send_to_worker( self, worker: Worker, - command: bytes = QueueMsgProtocol.W_REQUEST, - option: bytes | None = None, + command: bytes, msg: bytes | list | None = None, ) -> None: """Send message to worker. @@ -510,50 +506,60 @@ def send_to_worker( elif not isinstance(msg, list): msg = [msg] - # Stack routing and protocol envelopes to start of message - # and routing envelope - if option is not None: - msg = [option] + msg - msg = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] + msg + # ZMQProducer send frames: [address, empty, header, command, ...data] + core = [worker.address, b"", QueueMsgProtocol.W_WORKER, command] + msg = core + msg + + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer send: {core}") - logger.debug("Send: {}", msg) with ZMQ_SOCKET_LOCK: try: self.socket.send_multipart(msg) except zmq.ZMQError as e: - logger.error("Failed to send message to producer. 
{}", e) + logger.error("ZMQProducer send error", exc_info=e) def _run(self) -> None: - while True: - if self._stop.is_set(): - return + try: + while True: + if self._stop.is_set(): + logger.info("ZMQProducer thread stopped") + return - for service in self.services.values(): - self.dispatch(service, None) + for service in self.services.values(): + self.dispatch(service, None) - items = None + items = None - try: - items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) - except Exception as e: - logger.exception("Failed to poll items: {}", e) + try: + items = self.poll_workers.poll(ZMQ_POLLER_TIMEOUT_MSEC) + except Exception as e: + logger.exception("ZMQProducer poll error", exc_info=e) - if items: - msg = self.socket.recv_multipart() + if items: + msg = self.socket.recv_multipart() + + if len(msg) < 3: + logger.error(f"ZMQProducer invalid recv: {msg}") + continue - logger.debug("Recieve: {}", msg) + # ZMQProducer recv frames: [address, empty, header, command, ...data] + (address, _, header, command, *data) = msg - address = msg.pop(0) - empty = msg.pop(0) # noqa: F841 - header = msg.pop(0) + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQProducer recv: {msg[:4]}") - if header == QueueMsgProtocol.W_WORKER: - self.process_worker(address, msg) - else: - logger.error("Invalid message header: {}", header) + if header == QueueMsgProtocol.W_WORKER: + self.process_worker(address, command, data) + else: + logger.error(f"Invalid message header: {header}") - self.send_heartbeats() - self.purge_workers() + self.send_heartbeats() + self.purge_workers() + except Exception as e: + logger.exception("ZMQProducer thread exception", exc_info=e) def require_worker(self, address: bytes) -> Worker: """Finds the worker (creates if necessary).""" @@ -564,16 +570,13 @@ def require_worker(self, address: bytes) -> Worker: self.workers[identity] = worker return worker - def process_worker(self, address: bytes, msg: list[bytes]) -> None: - command = msg.pop(0) - + def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> None: worker_ready = hexlify(address) in self.workers - worker = self.require_worker(address) if QueueMsgProtocol.W_READY == command: - service_name = msg.pop(0).decode() - syft_worker_id = msg.pop(0).decode() + service_name = data.pop(0).decode() + syft_worker_id = data.pop(0).decode() if worker_ready: # Not first command in session or Reserved service name # If worker was already present, then we disconnect it first @@ -589,18 +592,7 @@ def process_worker(self, address: bytes, msg: list[bytes]) -> None: self.services[service_name] = service if service is not None: worker.service = service - logger.info( - "New Worker service={}, id={}, uid={}", - service.name, - worker.identity, - worker.syft_worker_id, - ) - else: - logger.info( - "New Worker service=None, id={}, uid={}", - worker.identity, - worker.syft_worker_id, - ) + logger.info(f"New worker: {worker}") worker.syft_worker_id = UID(syft_worker_id) self.worker_waiting(worker) @@ -611,19 +603,18 @@ def process_worker(self, address: bytes, msg: list[bytes]) -> None: # if not already present self.worker_waiting(worker) else: - # extract the syft worker id and worker pool name from the message - # Get the corresponding worker pool and worker - # update the status to be unhealthy + logger.info(f"Got heartbeat, but worker not ready. 
{worker}") self.delete_worker(worker, True) elif QueueMsgProtocol.W_DISCONNECT == command: + logger.info(f"Removing disconnected worker: {worker}") self.delete_worker(worker, False) else: - logger.error("Invalid command: {}", command) + logger.error(f"Invalid command: {command!r}") def delete_worker(self, worker: Worker, disconnect: bool) -> None: """Deletes worker from all data structures, and deletes worker.""" if disconnect: - self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT, None, None) + self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT) if worker.service and worker in worker.service.waiting: worker.service.waiting.remove(worker) @@ -680,13 +671,12 @@ def reconnect_to_producer(self) -> None: self.socket.connect(self.address) self.poller.register(self.socket, zmq.POLLIN) - logger.info("Connecting Worker id={} to broker addr={}", self.id, self.address) + logger.info(f"Connecting Worker id={self.id} to broker addr={self.address}") # Register queue with the producer self.send_to_producer( QueueMsgProtocol.W_READY, - self.service_name.encode(), - [str(self.syft_worker_id).encode()], + [self.service_name.encode(), str(self.syft_worker_id).encode()], ) def post_init(self) -> None: @@ -704,7 +694,7 @@ def close(self) -> None: try: self.poller.unregister(self.socket) except Exception as e: - logger.exception("Failed to unregister worker. {}", e) + logger.exception("Failed to unregister worker.", exc_info=e) finally: if self.thread is not None: self.thread.join(timeout=THREAD_TIMEOUT_SEC) @@ -715,8 +705,7 @@ def close(self) -> None: def send_to_producer( self, - command: str, - option: bytes | None = None, + command: bytes, msg: bytes | list | None = None, ) -> None: """Send message to producer. @@ -732,23 +721,25 @@ def send_to_producer( elif not isinstance(msg, list): msg = [msg] - if option: - msg = [option] + msg + # ZMQConsumer send frames: [empty, header, command, ...data] + core = [b"", QueueMsgProtocol.W_WORKER, command] + msg = core + msg - msg = [b"", QueueMsgProtocol.W_WORKER, command] + msg - logger.debug("Send: msg={}", msg) + if command != QueueMsgProtocol.W_HEARTBEAT: + logger.info(f"ZMQ Consumer send: {core}") with ZMQ_SOCKET_LOCK: try: self.socket.send_multipart(msg) except zmq.ZMQError as e: - logger.error("Failed to send message to producer. {}", e) + logger.error("ZMQConsumer send error", exc_info=e) def _run(self) -> None: """Send reply, if any, to producer and wait for next request.""" try: while True: if self._stop.is_set(): + logger.info("ZMQConsumer thread stopped") return try: @@ -757,39 +748,38 @@ def _run(self) -> None: logger.info("Context terminated") return except Exception as e: - logger.error("Poll error={}", e) + logger.error("ZMQ poll error", exc_info=e) continue if items: - # Message format: - # [b"", "
", "", "", ""] msg = self.socket.recv_multipart() - logger.debug("Recieve: {}", msg) - # mark as alive self.set_producer_alive() if len(msg) < 3: - logger.error("Invalid message: {}", msg) + logger.error(f"ZMQConsumer invalid recv: {msg}") continue - empty = msg.pop(0) # noqa: F841 - header = msg.pop(0) # noqa: F841 + # Message frames recieved by consumer: + # [empty, header, command, ...data] + (_, _, command, *data) = msg - command = msg.pop(0) + if command != QueueMsgProtocol.W_HEARTBEAT: + # log everything except the last frame which contains serialized data + logger.info(f"ZMQConsumer recv: {msg[:-4]}") if command == QueueMsgProtocol.W_REQUEST: # Call Message Handler try: - message = msg.pop() + message = data.pop() self.associate_job(message) self.message_handler.handle_message( message=message, syft_worker_id=self.syft_worker_id, ) except Exception as e: - logger.exception("Error while handling message. {}", e) + logger.exception("Couldn't handle message", exc_info=e) finally: self.clear_job() elif command == QueueMsgProtocol.W_HEARTBEAT: @@ -797,7 +787,7 @@ def _run(self) -> None: elif command == QueueMsgProtocol.W_DISCONNECT: self.reconnect_to_producer() else: - logger.error("Invalid command: {}", command) + logger.error(f"ZMQConsumer invalid command: {command}") else: if not self.is_producer_alive(): logger.info("Producer check-alive timed out. Reconnecting.") @@ -808,12 +798,11 @@ def _run(self) -> None: except zmq.ZMQError as e: if e.errno == zmq.ETERM: - logger.info("Consumer connection terminated") + logger.info("zmq.ETERM") else: - logger.exception("Consumer error. {}", e) - raise e - - logger.info("Worker finished") + logger.exception("zmq.ZMQError", exc_info=e) + except Exception as e: + logger.exception("ZMQConsumer thread exception", exc_info=e) def set_producer_alive(self) -> None: self.producer_ping_t.reset() @@ -836,7 +825,7 @@ def associate_job(self, message: Frame) -> None: queue_item = _deserialize(message, from_bytes=True) self._set_worker_job(queue_item.job_id) except Exception as e: - logger.exception("Could not associate job. {}", e) + logger.exception("Could not associate job", exc_info=e) def clear_job(self) -> None: self._set_worker_job(None) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 882dd243ec4..d82e1207727 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -3,6 +3,7 @@ from enum import Enum import hashlib import inspect +import logging from typing import Any # third party @@ -58,6 +59,8 @@ from ..response import SyftSuccess from ..user.user import UserView +logger = logging.getLogger(__name__) + @serializable() class RequestStatus(Enum): @@ -158,7 +161,7 @@ def _run( permission=self.apply_permission_type, ) if apply: - print( + logger.debug( "ADDING PERMISSION", requesting_permission_action_obj, id_action ) action_store.add_permission(requesting_permission_action_obj) @@ -182,7 +185,7 @@ def _run( ) return Ok(SyftSuccess(message=f"{type(self)} Success")) except Exception as e: - print(f"failed to apply {type(self)}", e) + logger.error(f"failed to apply {type(self)}", exc_info=e) return Err(SyftError(message=str(e))) def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: @@ -1317,7 +1320,7 @@ def _run( self.linked_obj.update_with_context(context, updated_status) return Ok(SyftSuccess(message=f"{type(self)} Success")) except Exception as e: - print(f"failed to apply {type(self)}. 
{e}") + logger.error(f"failed to apply {type(self)}", exc_info=e) return Err(SyftError(message=str(e))) def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index cda115cb8b4..c92695e2f6a 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -8,6 +8,7 @@ from functools import partial import inspect from inspect import Parameter +import logging from typing import Any from typing import TYPE_CHECKING @@ -43,6 +44,8 @@ from .user.user_roles import ServiceRole from .warnings import APIEndpointWarning +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from ..client.api import APIModule @@ -491,5 +494,5 @@ def from_api_or_context( ) return partial(service_method, node_context) else: - print("Could not get method from api or context") + logger.error("Could not get method from api or context") return None diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index 94adfbf307c..2db395ce9e5 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +import logging from typing import Any # third party @@ -29,6 +30,8 @@ from ...util.schema import DEFAULT_WELCOME_MSG from ..response import SyftInfo +logger = logging.getLogger(__name__) + @serializable() class NodeSettingsUpdateV4(PartialSyftObject): @@ -54,8 +57,8 @@ def validate_node_side_type(cls, v: str) -> type[Empty]: as information might be leaked." try: display(SyftInfo(message=msg)) - except Exception: - print(SyftInfo(message=msg)) + except Exception as e: + logger.error(msg, exc_info=e) return Empty diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 24d89af2fd6..031f2240376 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import enum import html +import logging import operator import textwrap from typing import Any @@ -13,7 +14,6 @@ from typing import TYPE_CHECKING # third party -from loguru import logger import pandas as pd from rich import box from rich.console import Console @@ -61,6 +61,8 @@ from ..user.user import UserView from .sync_state import SyncState +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from .resolve_widget import PaginatedResolveWidget @@ -509,7 +511,6 @@ def _repr_html_(self) -> str: obj_repr += diff.__repr__() + "
" obj_repr = obj_repr.replace("\n", "
") - # print("New lines", res) attr_text = f"

{self.object_type} ObjectDiff:

\n{obj_repr}" return base_str + attr_text @@ -1060,7 +1061,7 @@ def stage_change(self) -> None: other_batch.decision == SyncDecision.IGNORE and other_batch.root_id in required_dependencies ): - print(f"ignoring other batch ({other_batch.root_type.__name__})") + logger.debug(f"ignoring other batch ({other_batch.root_type.__name__})") other_batch.decision = None @@ -1282,7 +1283,7 @@ def apply_previous_ignore_state( if hash(batch) == batch_hash: batch.decision = SyncDecision.IGNORE else: - print( + logger.debug( f"""A batch with type {batch.root_type.__name__} was previously ignored but has changed It will be available for review again.""" ) @@ -1409,7 +1410,7 @@ def _create_batches( # TODO: Figure out nested user codes, do we even need that? root_ids.append(diff.object_id) # type: ignore - elif ( + elif ( # type: ignore[unreachable] isinstance(diff_obj, Job) # type: ignore and diff_obj.parent_job_id is None # ignore Job objects created by TwinAPIEndpoint diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index db50c2a7a61..e452b8b0b8e 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -1,9 +1,9 @@ # stdlib from collections import defaultdict +import logging from typing import Any # third party -from loguru import logger from result import Err from result import Ok from result import Result @@ -36,6 +36,8 @@ from .sync_stash import SyncStash from .sync_state import SyncState +logger = logging.getLogger(__name__) + def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> Any: if isinstance(item, ActionObject): diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index c952cbe8c13..c9b930c353c 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -1,5 +1,6 @@ # stdlib import contextlib +import logging import os from pathlib import Path import socket @@ -34,6 +35,8 @@ from .worker_pool import WorkerOrchestrationType from .worker_pool import WorkerStatus +logger = logging.getLogger(__name__) + DEFAULT_WORKER_IMAGE_TAG = "openmined/default-worker-image-cpu:0.0.1" DEFAULT_WORKER_POOL_NAME = "default-pool" K8S_NODE_CREDS_NAME = "node-creds" @@ -261,9 +264,9 @@ def run_workers_in_threads( address=address, ) except Exception as e: - print( - "Failed to start consumer for " - f"pool={pool_name} worker={worker_name}. Error: {e}" + logger.error( + f"Failed to start consumer for pool={pool_name} worker={worker_name}", + exc_info=e, ) worker.status = WorkerStatus.STOPPED error = str(e) @@ -335,12 +338,7 @@ def create_kubernetes_pool( pool = None try: - print( - "Creating new pool " - f"name={pool_name} " - f"tag={tag} " - f"replicas={replicas}" - ) + logger.info(f"Creating new pool name={pool_name} tag={tag} replicas={replicas}") env_vars, mount_secrets = prepare_kubernetes_pool_env( runner, @@ -391,7 +389,7 @@ def scale_kubernetes_pool( return SyftError(message=f"Pool does not exist. 
name={pool_name}") try: - print(f"Scaling pool name={pool_name} to replicas={replicas}") + logger.info(f"Scaling pool name={pool_name} to replicas={replicas}") runner.scale_pool(pool_name=pool_name, replicas=replicas) except Exception as e: return SyftError(message=f"Failed to scale workers {e}") @@ -520,7 +518,7 @@ def run_containers( if not worker_image.is_built: return SyftError(message="Image must be built before running it.") - print(f"Starting workers with start_idx={start_idx} count={number}") + logger.info(f"Starting workers with start_idx={start_idx} count={number}") if orchestration == WorkerOrchestrationType.DOCKER: with contextlib.closing(docker.from_env()) as client: diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index d42645a19bb..a44cf2e2d82 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -1,4 +1,5 @@ # stdlib +import logging from typing import Any # third party @@ -45,6 +46,8 @@ from .worker_service import WorkerService from .worker_stash import WorkerStash +logger = logging.getLogger(__name__) + @serializable() class SyftWorkerPoolService(AbstractService): @@ -527,7 +530,7 @@ def scale( uid=worker.object_uid, ) if delete_result.is_err(): - print(f"Failed to delete worker: {worker.object_uid}") + logger.error(f"Failed to delete worker: {worker.object_uid}") # update worker_pool worker_pool.max_count = number diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index 15658ad4c8c..b9677eda95b 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -44,6 +44,7 @@ from collections.abc import Callable from collections.abc import Generator from io import BytesIO +import logging from typing import Any # third party @@ -74,6 +75,8 @@ from ...types.transforms import make_set_default from ...types.uid import UID +logger = logging.getLogger(__name__) + DEFAULT_TIMEOUT = 10 MAX_RETRIES = 20 @@ -138,11 +141,11 @@ def syft_iter_content( return # If successful, exit the function except requests.exceptions.RequestException as e: if attempt < max_retries: - print( + logger.debug( f"Attempt {attempt}/{max_retries} failed: {e} at byte {current_byte}. Retrying..." ) else: - print(f"Max retries reached. 
Failed with error: {e}") + logger.error(f"Max retries reached - {e}") raise diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index a63ed8a2d67..03c6f442c26 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -1,6 +1,7 @@ # stdlib from collections.abc import Generator from io import BytesIO +import logging import math from queue import Queue import threading @@ -40,6 +41,8 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...util.constants import DEFAULT_TIMEOUT +logger = logging.getLogger(__name__) + MAX_QUEUE_SIZE = 100 WRITE_EXPIRATION_TIME = 900 # seconds DEFAULT_FILE_PART_SIZE = 1024**3 # 1GB @@ -149,7 +152,7 @@ def add_chunks_to_queue( etags.append({"ETag": etag, "PartNumber": part_no}) except requests.RequestException as e: - print(e) + logger.error(f"Failed to upload file to SeaweedFS - {e}") return SyftError(message=str(e)) mark_write_complete_method = from_api_or_context( diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index fea96e6d456..3d69024c7d4 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -350,7 +350,6 @@ def store_query_keys(self, objs: Any) -> QueryKeys: def _thread_safe_cbk(self, cbk: Callable, *args: Any, **kwargs: Any) -> Any | Err: locked = self.lock.acquire(blocking=True) if not locked: - print("FAILED TO LOCK") return Err( f"Failed to acquire lock for the operation {self.lock.lock_name} ({self.lock._lock})" ) diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index 6a29f6efdfb..48ae6ca1178 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -1,5 +1,6 @@ # stdlib from collections import defaultdict +import logging import threading import time from typing import Any @@ -11,6 +12,7 @@ # relative from ..serde.serializable import serializable +logger = logging.getLogger(__name__) THREAD_FILE_LOCKS: dict[int, dict[str, int]] = defaultdict(dict) @@ -190,7 +192,7 @@ def acquire(self, blocking: bool = True) -> bool: elapsed = time.time() - start_time else: return True - print( + logger.debug( f"Timeout elapsed after {self.timeout} seconds while trying to acquiring lock." ) # third party diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 96a0b70b81f..e68b2f13710 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -4,6 +4,7 @@ # stdlib from collections import defaultdict from copy import deepcopy +import logging from pathlib import Path import sqlite3 import tempfile @@ -33,6 +34,8 @@ from .locks import NoLockingConfig from .locks import SyftLock +logger = logging.getLogger(__name__) + # here we can create a single connection per cache_key # since pytest is concurrent processes, we need to isolate each connection # by its filename and optionally the thread that its running in @@ -350,7 +353,7 @@ def __del__(self) -> None: try: self._close() except Exception as e: - print(f"Could not close connection. 
Error: {e}") + logger.error("Could not close connection", exc_info=e) @serializable() diff --git a/packages/syft/src/syft/types/grid_url.py b/packages/syft/src/syft/types/grid_url.py index 91cf53e46d7..040969c2730 100644 --- a/packages/syft/src/syft/types/grid_url.py +++ b/packages/syft/src/syft/types/grid_url.py @@ -3,6 +3,7 @@ # stdlib import copy +import logging import os import re from urllib.parse import urlparse @@ -15,6 +16,8 @@ from ..serde.serializable import serializable from ..util.util import verify_tls +logger = logging.getLogger(__name__) + @serializable(attrs=["protocol", "host_or_ip", "port", "path", "query"]) class GridURL: @@ -43,7 +46,7 @@ def from_url(cls, url: str | GridURL) -> GridURL: query=getattr(parts, "query", ""), ) except Exception as e: - print(f"Failed to convert url: {url} to GridURL. {e}") + logger.error(f"Failed to convert url: {url} to GridURL. {e}") raise e def __init__( diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index e4daf3a779f..9df3f22300c 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -10,6 +10,7 @@ from hashlib import sha256 import inspect from inspect import Signature +import logging import types from types import NoneType from types import UnionType @@ -48,6 +49,8 @@ from .syft_metaclass import PartialModelMetaclass from .uid import UID +logger = logging.getLogger(__name__) + if TYPE_CHECKING: # relative from ..client.api import SyftAPI @@ -611,8 +614,9 @@ def _syft_keys_types_dict(cls, attr_name: str) -> dict[str, type]: if isinstance(method, types.FunctionType): type_ = method.__annotations__["return"] except Exception as e: - print( - f"Failed to get attribute from key {key} type for {cls} storage. {e}" + logger.error( + f"Failed to get attribute from key {key} type for {cls} storage.", + exc_info=e, ) raise e # EmailStr seems to be lost every time the value is set even with a validator diff --git a/packages/syft/src/syft/types/uid.py b/packages/syft/src/syft/types/uid.py index b4aab67302e..cd3a0dafba5 100644 --- a/packages/syft/src/syft/types/uid.py +++ b/packages/syft/src/syft/types/uid.py @@ -5,6 +5,7 @@ from collections.abc import Callable from collections.abc import Sequence import hashlib +import logging from typing import Any import uuid from uuid import UUID as uuid_type @@ -14,8 +15,8 @@ # relative from ..serde.serializable import serializable -from ..util.logger import critical -from ..util.logger import traceback_and_raise + +logger = logging.getLogger(__name__) @serializable(attrs=["value"]) @@ -81,9 +82,8 @@ def from_string(value: str) -> UID: try: return UID(value=uuid.UUID(value)) except ValueError as e: - critical(f"Unable to convert {value} to UUID. {e}") - traceback_and_raise(e) - raise + logger.critical(f"Unable to convert {value} to UUID. 
{e}") + raise e @staticmethod def with_seed(value: str) -> UID: diff --git a/packages/syft/src/syft/util/logger.py b/packages/syft/src/syft/util/logger.py deleted file mode 100644 index d9f0611a6c6..00000000000 --- a/packages/syft/src/syft/util/logger.py +++ /dev/null @@ -1,134 +0,0 @@ -# stdlib -from collections.abc import Callable -import logging -import os -import sys -from typing import Any -from typing import NoReturn -from typing import TextIO - -# third party -from loguru import logger - -LOG_FORMAT = "[{time}][{level}][{module}]][{process.id}] {message}" - -logger.remove() -DEFAULT_SINK = "syft_{time}.log" - - -def remove() -> None: - logger.remove() - - -def add( - sink: None | str | os.PathLike | TextIO | logging.Handler = None, - level: str = "ERROR", -) -> None: - sink = DEFAULT_SINK if sink is None else sink - try: - logger.add( - sink=sink, - format=LOG_FORMAT, - enqueue=True, - colorize=False, - diagnose=True, - backtrace=True, - rotation="10 MB", - retention="1 day", - level=level, - ) - except BaseException: - logger.add( - sink=sink, - format=LOG_FORMAT, - colorize=False, - diagnose=True, - backtrace=True, - level=level, - ) - - -def start() -> None: - add(sink=sys.stderr, level="CRITICAL") - - -def stop() -> None: - logger.stop() - - -def traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn: - try: - if verbose: - logger.opt(lazy=True).exception(e) - else: - logger.opt(lazy=True).critical(e) - except BaseException as ex: - logger.debug("failed to print exception", ex) - if not issubclass(type(e), Exception): - e = Exception(e) - raise e - - -def create_log_and_print_function(level: str) -> Callable: - def log_and_print(*args: Any, **kwargs: Any) -> None: - try: - method = getattr(logger.opt(lazy=True), level, None) - if "print" in kwargs and kwargs["print"] is True: - del kwargs["print"] - print(*args, **kwargs) - if "end" in kwargs: - # clean up extra end for printinga - del kwargs["end"] - - if method is not None: - method(*args, **kwargs) - else: - raise Exception(f"no method {level} on logger") - except BaseException as e: - msg = f"failed to log exception. {e}" - try: - logger.debug(msg) - - except Exception as e: - print(f"{msg}. 
{e}") - - return log_and_print - - -def traceback(*args: Any, **kwargs: Any) -> None: - # caller = inspect.getframeinfo(inspect.stack()[1][0]) - # print(f"traceback:{caller.filename}:{caller.function}:{caller.lineno}") - return create_log_and_print_function(level="exception")(*args, **kwargs) - - -def critical(*args: Any, **kwargs: Any) -> None: - # caller = inspect.getframeinfo(inspect.stack()[1][0]) - # print(f"critical:{caller.filename}:{caller.function}:{caller.lineno}:{args}") - return create_log_and_print_function(level="critical")(*args, **kwargs) - - -def error(*args: Any, **kwargs: Any) -> None: - # caller = inspect.getframeinfo(inspect.stack()[1][0]) - # print(f"error:{caller.filename}:{caller.function}:{caller.lineno}") - return create_log_and_print_function(level="error")(*args, **kwargs) - - -def warning(*args: Any, **kwargs: Any) -> None: - return create_log_and_print_function(level="warning")(*args, **kwargs) - - -def info(*args: Any, **kwargs: Any) -> None: - return create_log_and_print_function(level="info")(*args, **kwargs) - - -def debug(*args: Any) -> None: - debug_msg = " ".join([str(a) for a in args]) - return logger.debug(debug_msg) - - -def _debug(*args: Any, **kwargs: Any) -> None: - return create_log_and_print_function(level="debug")(*args, **kwargs) - - -def trace(*args: Any, **kwargs: Any) -> None: - return create_log_and_print_function(level="trace")(*args, **kwargs) diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index ee0576cc206..f623e95b480 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -1,5 +1,6 @@ # stdlib import json +import logging import secrets from typing import Any @@ -7,7 +8,6 @@ from IPython.display import HTML from IPython.display import display import jinja2 -from loguru import logger # relative from ...assets import load_css @@ -16,6 +16,8 @@ from ...table import prepare_table_data from ..icons import Icon +logger = logging.getLogger(__name__) + DEFAULT_ID_WIDTH = 110 env = jinja2.Environment(loader=jinja2.PackageLoader("syft", "assets/jinja")) # nosec @@ -145,7 +147,7 @@ def build_tabulator_table( return table_html except Exception as e: - logger.debug("error building table", e) + logger.debug("error building table", exc_info=e) return None diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 998e022bdbd..fc5df24578c 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -3,17 +3,17 @@ from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Set +import logging import re from typing import Any -# third party -from loguru import logger - # relative from .notebook_ui.components.table_template import TABLE_INDEX_KEY from .notebook_ui.components.table_template import create_table_template from .util import full_name_with_qualname +logger = logging.getLogger(__name__) + def _syft_in_mro(self: Any, item: Any) -> bool: if hasattr(type(item), "mro") and type(item) != type: diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index 32a57dd0534..d03f240a1de 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -1,9 +1,12 @@ # stdlib from collections.abc import Callable +import logging import os from typing 
import Any from typing import TypeVar +logger = logging.getLogger(__name__) + def str_to_bool(bool_str: str | None) -> bool: result = False @@ -27,7 +30,6 @@ def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: instrument = noop else: try: - print("OpenTelemetry Tracing enabled") service_name = os.environ.get("SERVICE_NAME", "client") jaeger_host = os.environ.get("JAEGER_HOST", "localhost") jaeger_port = int(os.environ.get("JAEGER_PORT", "14268")) @@ -74,6 +76,6 @@ def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: from .trace_decorator import instrument as _instrument instrument = _instrument - except Exception: # nosec - print("Failed to import opentelemetry") + except Exception as e: + logger.error("Failed to import opentelemetry", exc_info=e) instrument = noop diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index b0affa2b1a0..bbdba2a2e60 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -10,6 +10,7 @@ import functools import hashlib from itertools import repeat +import logging import multiprocessing import multiprocessing as mp from multiprocessing import set_start_method @@ -37,11 +38,7 @@ from nacl.signing import VerifyKey import requests -# relative -from .logger import critical -from .logger import debug -from .logger import error -from .logger import traceback_and_raise +logger = logging.getLogger(__name__) DATASETS_URL = "https://raw.githubusercontent.com/OpenMined/datasets/main" PANDAS_DATA = f"{DATASETS_URL}/pandas_cookbook" @@ -57,9 +54,9 @@ def full_name_with_qualname(klass: type) -> str: if not hasattr(klass, "__module__"): return f"builtins.{get_qualname_for(klass)}" return f"{klass.__module__}.{get_qualname_for(klass)}" - except Exception: + except Exception as e: # try name as backup - print("Failed to get FQN for:", klass, type(klass)) + logger.error(f"Failed to get FQN for: {klass} {type(klass)}", exc_info=e) return full_name_with_name(klass=klass) @@ -70,7 +67,7 @@ def full_name_with_name(klass: type) -> str: return f"builtins.{get_name_for(klass)}" return f"{klass.__module__}.{get_name_for(klass)}" except Exception as e: - print("Failed to get FQN for:", klass, type(klass)) + logger.error(f"Failed to get FQN for: {klass} {type(klass)}", exc_info=e) raise e @@ -107,7 +104,7 @@ def extract_name(klass: type) -> str: return fqn.split(".")[-1] return fqn except Exception as e: - print(f"Failed to get klass name {klass}") + logger.error(f"Failed to get klass name {klass}", exc_info=e) raise e else: raise ValueError(f"Failed to match regex for klass {klass}") @@ -117,9 +114,7 @@ def validate_type(_object: object, _type: type, optional: bool = False) -> Any: if isinstance(_object, _type) or (optional and (_object is None)): return _object - traceback_and_raise( - f"Object {_object} should've been of type {_type}, not {_object}." - ) + raise Exception(f"Object {_object} should've been of type {_type}, not {_object}.") def validate_field(_object: object, _field: str) -> Any: @@ -128,7 +123,7 @@ def validate_field(_object: object, _field: str) -> Any: if object is not None: return object - traceback_and_raise(f"Object {_object} has no {_field} field set.") + raise Exception(f"Object {_object} has no {_field} field set.") def get_fully_qualified_name(obj: object) -> str: @@ -150,7 +145,7 @@ def get_fully_qualified_name(obj: object) -> str: try: fqn += "." 
+ obj.__class__.__name__ except Exception as e: - error(f"Failed to get FQN: {e}") + logger.error(f"Failed to get FQN: {e}") return fqn @@ -175,7 +170,7 @@ def key_emoji(key: object) -> str: hex_chars = bytes(key).hex()[-8:] return char_emoji(hex_chars=hex_chars) except Exception as e: - error(f"Fail to get key emoji: {e}") + logger.error(f"Fail to get key emoji: {e}") pass return "ALL" @@ -332,7 +327,7 @@ def find_available_port( sock.close() except Exception as e: - print(f"Failed to check port {port}. {e}") + logger.error(f"Failed to check port {port}. {e}") sock.close() if search is False and port_available is False: @@ -446,7 +441,7 @@ def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: except Exception as e: # sometimes the object doesn't have a __module__ so you need to use the type # like: collections.OrderedDict - debug( + logger.debug( f"Unable to get get_fully_qualified_name of {type(obj)} trying type. {e}" ) fqn = get_fully_qualified_name(obj=type(obj)) @@ -457,10 +452,8 @@ def obj2pointer_type(obj: object | None = None, fqn: str | None = None) -> type: try: ref = get_loaded_syft().lib_ast.query(fqn, obj_type=type(obj)) - except Exception as e: - log = f"Cannot find {type(obj)} {fqn} in lib_ast. {e}" - critical(log) - raise Exception(log) + except Exception: + raise Exception(f"Cannot find {type(obj)} {fqn} in lib_ast.") return ref.pointer_type diff --git a/ruff.toml b/ruff.toml index 3dccdf65b91..bdf2c46b9cf 100644 --- a/ruff.toml +++ b/ruff.toml @@ -24,6 +24,7 @@ ignore = [ [lint.per-file-ignores] "*.ipynb" = ["E402"] +"__init__.py" = ["F401"] [lint.pycodestyle] max-line-length = 120 From 36cf7a35e56044531d14e10f1263698f8323305e Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Mon, 24 Jun 2024 16:48:26 +0530 Subject: [PATCH 2/9] remove some more __init__ F401 --- packages/syft/src/syft/serde/__init__.py | 6 ++--- .../src/syft/service/data_subject/__init__.py | 2 +- packages/syft/src/syft/store/__init__.py | 4 +-- packages/syft/src/syft/util/__init__.py | 2 +- packages/syft/tests/conftest.py | 26 +++++++++---------- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/serde/__init__.py b/packages/syft/src/syft/serde/__init__.py index 666be78ca11..00122b4769f 100644 --- a/packages/syft/src/syft/serde/__init__.py +++ b/packages/syft/src/syft/serde/__init__.py @@ -1,4 +1,4 @@ # relative -from .array import NOTHING # noqa: F401 F811 -from .recursive import NOTHING # noqa: F401 F811 -from .third_party import NOTHING # noqa: F401 F811 +from .array import NOTHING # noqa: F811 +from .recursive import NOTHING # noqa: F811 +from .third_party import NOTHING # noqa: F811 diff --git a/packages/syft/src/syft/service/data_subject/__init__.py b/packages/syft/src/syft/service/data_subject/__init__.py index f628bc5d753..f232044493c 100644 --- a/packages/syft/src/syft/service/data_subject/__init__.py +++ b/packages/syft/src/syft/service/data_subject/__init__.py @@ -1,2 +1,2 @@ # relative -from .data_subject import DataSubjectCreate # noqa: F401 +from .data_subject import DataSubjectCreate diff --git a/packages/syft/src/syft/store/__init__.py b/packages/syft/src/syft/store/__init__.py index 2369be33ea4..9260d13f956 100644 --- a/packages/syft/src/syft/store/__init__.py +++ b/packages/syft/src/syft/store/__init__.py @@ -1,3 +1,3 @@ # relative -from .mongo_document_store import MongoDict # noqa: F401 -from .mongo_document_store import MongoStoreConfig # noqa: F401 +from .mongo_document_store import MongoDict +from .mongo_document_store 
import MongoStoreConfig diff --git a/packages/syft/src/syft/util/__init__.py b/packages/syft/src/syft/util/__init__.py index f6394760c7b..aec1f392faf 100644 --- a/packages/syft/src/syft/util/__init__.py +++ b/packages/syft/src/syft/util/__init__.py @@ -1,2 +1,2 @@ # relative -from .schema import generate_json_schemas # noqa: F401 +from .schema import generate_json_schemas diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index c160034b532..2d781f817d7 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -25,19 +25,19 @@ # relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import dict_action_store # noqa: F401 -from .syft.stores.store_fixtures_test import dict_document_store # noqa: F401 -from .syft.stores.store_fixtures_test import dict_queue_stash # noqa: F401 -from .syft.stores.store_fixtures_test import dict_store_partition # noqa: F401 -from .syft.stores.store_fixtures_test import mongo_action_store # noqa: F401 -from .syft.stores.store_fixtures_test import mongo_document_store # noqa: F401 -from .syft.stores.store_fixtures_test import mongo_queue_stash # noqa: F401 -from .syft.stores.store_fixtures_test import mongo_store_partition # noqa: F401 -from .syft.stores.store_fixtures_test import sqlite_action_store # noqa: F401 -from .syft.stores.store_fixtures_test import sqlite_document_store # noqa: F401 -from .syft.stores.store_fixtures_test import sqlite_queue_stash # noqa: F401 -from .syft.stores.store_fixtures_test import sqlite_store_partition # noqa: F401 -from .syft.stores.store_fixtures_test import sqlite_workspace # noqa: F401 +from .syft.stores.store_fixtures_test import dict_action_store +from .syft.stores.store_fixtures_test import dict_document_store +from .syft.stores.store_fixtures_test import dict_queue_stash +from .syft.stores.store_fixtures_test import dict_store_partition +from .syft.stores.store_fixtures_test import mongo_action_store +from .syft.stores.store_fixtures_test import mongo_document_store +from .syft.stores.store_fixtures_test import mongo_queue_stash +from .syft.stores.store_fixtures_test import mongo_store_partition +from .syft.stores.store_fixtures_test import sqlite_action_store +from .syft.stores.store_fixtures_test import sqlite_document_store +from .syft.stores.store_fixtures_test import sqlite_queue_stash +from .syft.stores.store_fixtures_test import sqlite_store_partition +from .syft.stores.store_fixtures_test import sqlite_workspace def patch_protocol_file(filepath: Path): From dbafb43f443594cb08a7dd366d256bbfe2efc7f4 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Mon, 24 Jun 2024 17:03:35 +0530 Subject: [PATCH 3/9] add queue message handler logging --- packages/syft/src/syft/service/queue/queue.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index a6b1308b895..7515d10be54 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -220,12 +220,10 @@ def handle_message_multiprocessing( else: raise Exception(f"Unknown result type: {type(result)}") - except Exception as e: # nosec + except Exception as e: status = Status.ERRORED job_status = JobStatus.ERRORED - # stdlib - - logger.error(f"Error while handle message multiprocessing: {e}") + logger.error("Unhandled error 
in handle_message_multiprocessing", exc_info=e) queue_item.result = result queue_item.resolved = True @@ -257,7 +255,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: # relative from ...node.node import Node - queue_item = deserialize(message, from_bytes=True) + queue_item: QueueItem = deserialize(message, from_bytes=True) worker_settings = queue_item.worker_settings queue_config = worker_settings.queue_config @@ -306,6 +304,12 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: if isinstance(job_result, SyftError): raise Exception(f"{job_result.err()}") + logger.info( + f"Handling queue item: id={queue_item.id}, method={queue_item.method} " + f"args={queue_item.args}, kwargs={queue_item.kwargs} " + f"service={queue_item.service}, as_thread={queue_config.thread_workers}" + ) + if queue_config.thread_workers: thread = Thread( target=handle_message_multiprocessing, @@ -316,7 +320,6 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: else: # if psutil.pid_exists(job_item.job_pid): # psutil.Process(job_item.job_pid).terminate() - process = Process( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), From f4f124057bf3f196f5180509145e511eb8c6b48c Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Mon, 24 Jun 2024 17:15:06 +0530 Subject: [PATCH 4/9] undo type hint that made linter v angry --- packages/syft/src/syft/service/queue/queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 7515d10be54..c85b94468f3 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -255,7 +255,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: # relative from ...node.node import Node - queue_item: QueueItem = deserialize(message, from_bytes=True) + queue_item = deserialize(message, from_bytes=True) worker_settings = queue_item.worker_settings queue_config = worker_settings.queue_config From 55e8882f526b793bdca38a52900bad9ceb64efa2 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 24 Jun 2024 15:21:54 +0200 Subject: [PATCH 5/9] fix Job cache order + add early permission check for enqueueing mock Job --- packages/syft/src/syft/node/node.py | 44 ++++++++++++++++++- .../syft/src/syft/service/request/request.py | 4 +- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 53b9c0d36dd..7fdd4de6037 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -4,6 +4,7 @@ # stdlib from collections import OrderedDict from collections.abc import Callable +from datetime import MINYEAR from datetime import datetime from functools import partial import hashlib @@ -17,6 +18,7 @@ from time import sleep import traceback from typing import Any +from typing import cast # third party from loguru import logger @@ -64,6 +66,7 @@ from ..service.job.job_service import JobService from ..service.job.job_stash import Job from ..service.job.job_stash import JobStash +from ..service.job.job_stash import JobStatus from ..service.job.job_stash import JobType from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService @@ -101,6 +104,7 @@ from ..service.sync.sync_service import SyncService from ..service.user.user import User from ..service.user.user import UserCreate +from ..service.user.user 
import UserView from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..service.user.user_stash import UserStash @@ -123,6 +127,7 @@ from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig +from ..types.datetime import DATETIME_FORMAT from ..types.syft_metaclass import Empty from ..types.syft_object import PartialSyftObject from ..types.syft_object import SYFT_OBJECT_VERSION_2 @@ -1458,14 +1463,35 @@ def add_queueitem_to_queue( return result return job + def _sort_jobs(self, jobs: list[Job]) -> list[Job]: + job_datetimes = {} + for job in jobs: + try: + d = datetime.strptime(job.creation_time, DATETIME_FORMAT) + except Exception: + d = datetime(MINYEAR, 1, 1) + job_datetimes[job.id] = d + + jobs.sort( + key=lambda job: (job.status != JobStatus.COMPLETED, job_datetimes[job.id]), + reverse=True, + ) + + return jobs + def _get_existing_user_code_jobs( self, context: AuthedServiceContext, user_code_id: UID ) -> list[Job] | SyftError: job_service = self.get_service("jobservice") - return job_service.get_by_user_code_id( + jobs = job_service.get_by_user_code_id( context=context, user_code_id=user_code_id ) + if isinstance(jobs, SyftError): + return jobs + + return self._sort_jobs(jobs) + def _is_usercode_call_on_owned_kwargs( self, context: AuthedServiceContext, @@ -1502,6 +1528,14 @@ def add_api_call_to_queue( action = Action.from_api_call(unsigned_call) user_code_id = action.user_code_id + user = self.get_service(UserService).get_current_user(context) + if isinstance(user, SyftError): + return user + user = cast(UserView, user) + + is_execution_on_owned_kwargs_allowed = ( + user.mock_execution_permission or context.role == ServiceRole.ADMIN + ) is_usercode_call_on_owned_kwargs = self._is_usercode_call_on_owned_kwargs( context, unsigned_call, user_code_id ) @@ -1527,6 +1561,14 @@ def add_api_call_to_queue( message="Please wait for the admin to allow the execution of this code" ) + elif ( + is_usercode_call_on_owned_kwargs + and not is_execution_on_owned_kwargs_allowed + ): + return SyftError( + message="You do not have the permissions for mock execution, please contact the admin" + ) + return self.add_action_to_queue( action, api_call.credentials, parent_job_id=parent_job_id ) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index f174ef269b9..ac767a350d3 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -857,7 +857,9 @@ def _create_output_history_for_deposited_result( input_policy = code.input_policy if input_policy is not None: for input_ in input_policy.inputs.values(): - input_ids.update(input_) + # Skip inputs with type Constant + if isinstance(input_, UID): + input_ids.update(input_) res = api.services.code.store_execution_output( user_code_id=code.id, outputs=result, From 6dddba725b17f1c2f3259db8b318ed55276b0472 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 24 Jun 2024 15:26:40 +0200 Subject: [PATCH 6/9] fix link --- packages/syft/src/syft/service/request/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index ac767a350d3..16cc146d578 100644 --- a/packages/syft/src/syft/service/request/request.py +++ 
b/packages/syft/src/syft/service/request/request.py @@ -853,7 +853,7 @@ def _create_output_history_for_deposited_result( if isinstance(api, SyftError): return api - input_ids = {} + input_ids = {} # type: ignore input_policy = code.input_policy if input_policy is not None: for input_ in input_policy.inputs.values(): From e50202709b8e83e0c811ef2122ea243b49bf219f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 24 Jun 2024 15:37:54 +0200 Subject: [PATCH 7/9] revert --- packages/syft/src/syft/service/request/request.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 16cc146d578..f174ef269b9 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -853,13 +853,11 @@ def _create_output_history_for_deposited_result( if isinstance(api, SyftError): return api - input_ids = {} # type: ignore + input_ids = {} input_policy = code.input_policy if input_policy is not None: for input_ in input_policy.inputs.values(): - # Skip inputs with type Constant - if isinstance(input_, UID): - input_ids.update(input_) + input_ids.update(input_) res = api.services.code.store_execution_output( user_code_id=code.id, outputs=result, From 13650c4b30c23bd95e64f15fa4e8e17cdb3237e0 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 24 Jun 2024 16:02:29 +0200 Subject: [PATCH 8/9] completed_job.wait() waits forever reproduce https://github.com/OpenMined/Heartbeat/issues/1541 --- .../service/sync/sync_resolve_single_test.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index d68124e9b4d..83fba7fd168 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -1,5 +1,8 @@ # third party +# third party +import numpy as np + # syft absolute import syft import syft as sy @@ -60,6 +63,28 @@ def run_and_deposit_result(client): return job +def create_dataset(client): + mock = np.random.random(5) + private = np.random.random(5) + + dataset = sy.Dataset( + name=sy.util.util.random_name().lower(), + description="Lorem ipsum dolor sit amet, consectetur adipiscing elit", + asset_list=[ + sy.Asset( + name="numpy-data", + mock=mock, + data=private, + shape=private.shape, + mock_is_real=True, + ) + ], + ) + + client.upload_dataset(dataset) + return dataset + + @syft.syft_function_single_use() def compute() -> int: return 42 @@ -110,6 +135,60 @@ def compute() -> int: assert res == compute(syft_no_node=True) +def test_diff_state_with_dataset(low_worker, high_worker): + low_client: DomainClient = low_worker.root_client + client_low_ds = get_ds_client(low_client) + high_client: DomainClient = high_worker.root_client + + _ = create_dataset(high_client) + _ = create_dataset(low_client) + + @sy.syft_function_single_use() + def compute_mean(data) -> int: + return data.mean() + + _ = client_low_ds.code.request_code_execution(compute_mean) + + result = client_low_ds.code.compute_mean(blocking=False) + assert isinstance(result, SyftError), "DS cannot start a job on low side" + + diff_state_before, diff_state_after = compare_and_resolve( + from_client=low_client, to_client=high_client + ) + + assert not diff_state_before.is_same + + assert diff_state_after.is_same + + # 
run_and_deposit_result(high_client) + data_high = high_client.datasets[0].assets[0] + result = high_client.code.compute_mean(data=data_high, blocking=True) + high_client.requests[0].deposit_result(result) + + diff_state_before, diff_state_after = compare_and_resolve( + from_client=high_client, to_client=low_client + ) + + high_state = high_client.get_sync_state() + low_state = high_client.get_sync_state() + assert high_state.get_previous_state_diff().is_same + assert low_state.get_previous_state_diff().is_same + assert diff_state_after.is_same + + client_low_ds.refresh() + + # check loading results for both blocking and non-blocking case + res_blocking = client_low_ds.code.compute_mean(blocking=True) + res_non_blocking = client_low_ds.code.compute_mean(blocking=False).wait() + + # expected_result = compute_mean(syft_no_node=True, data=) + assert ( + res_blocking + == res_non_blocking + == high_client.datasets[0].assets[0].data.mean() + ) + + def test_sync_with_error(low_worker, high_worker): """Check syncing with an error in a syft function""" low_client: DomainClient = low_worker.root_client From 03ef6ec90d24c38fe677b3be2d7ca417d1c8af3f Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 24 Jun 2024 16:03:12 +0200 Subject: [PATCH 9/9] fix Job.resolved when the result is deposited fixes https://github.com/OpenMined/Heartbeat/issues/1541 --- packages/syft/src/syft/service/job/job_service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 654d83d3cc3..368992ceaa5 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -314,6 +314,7 @@ def create_job_for_user_code_id( status: JobStatus = JobStatus.CREATED, add_code_owner_read_permissions: bool = True, ) -> Job | SyftError: + is_resolved = status in [JobStatus.COMPLETED, JobStatus.ERRORED] job = Job( id=UID(), node_uid=context.node.id, @@ -324,6 +325,7 @@ def create_job_for_user_code_id( log_id=UID(), job_pid=None, user_code_id=user_code_id, + resolved=is_resolved, ) user_code_service = context.node.get_service("usercodeservice") user_code = user_code_service.get_by_uid(context=context, uid=user_code_id)