From 11e71273945631c17ae10a431ecdff9376abf0a7 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 27 Mar 2024 15:30:46 +1000 Subject: [PATCH 001/309] Added initial prototype for rathole --- packages/grid/rathole/client.toml | 6 +++ packages/grid/rathole/domain.dockerfile | 9 ++++ packages/grid/rathole/nginx.conf | 6 +++ packages/grid/rathole/rathole.dockerfile | 53 ++++++++++++++++++++++++ packages/grid/rathole/server.toml | 6 +++ packages/grid/rathole/start-client.sh | 4 ++ packages/grid/rathole/start-server.sh | 2 + 7 files changed, 86 insertions(+) create mode 100644 packages/grid/rathole/client.toml create mode 100644 packages/grid/rathole/domain.dockerfile create mode 100644 packages/grid/rathole/nginx.conf create mode 100644 packages/grid/rathole/rathole.dockerfile create mode 100644 packages/grid/rathole/server.toml create mode 100755 packages/grid/rathole/start-client.sh create mode 100755 packages/grid/rathole/start-server.sh diff --git a/packages/grid/rathole/client.toml b/packages/grid/rathole/client.toml new file mode 100644 index 00000000000..ba8b835a569 --- /dev/null +++ b/packages/grid/rathole/client.toml @@ -0,0 +1,6 @@ +[client] +remote_addr = "host.docker.internal:2333" # public IP and port of gateway + +[client.services.domain] +token = "domain-specific-rathole-secret" +local_addr = "localhost:8000" # nginx proxy diff --git a/packages/grid/rathole/domain.dockerfile b/packages/grid/rathole/domain.dockerfile new file mode 100644 index 00000000000..cdb657540e8 --- /dev/null +++ b/packages/grid/rathole/domain.dockerfile @@ -0,0 +1,9 @@ +ARG PYTHON_VERSION="3.12" +FROM python:${PYTHON_VERSION}-bookworm +RUN apt update && apt install -y netcat-openbsd vim +WORKDIR /app +CMD ["python3", "-m", "http.server", "8000"] +EXPOSE 8000 + +# docker build -f domain.dockerfile . -t domain +# docker run -it -p 8080:8000 domain diff --git a/packages/grid/rathole/nginx.conf b/packages/grid/rathole/nginx.conf new file mode 100644 index 00000000000..af1a3a752d7 --- /dev/null +++ b/packages/grid/rathole/nginx.conf @@ -0,0 +1,6 @@ +server { + listen 8000; + location / { + proxy_pass http://host.docker.internal:8080; + } +} \ No newline at end of file diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile new file mode 100644 index 00000000000..4ae1648e6d6 --- /dev/null +++ b/packages/grid/rathole/rathole.dockerfile @@ -0,0 +1,53 @@ +ARG RATHOLE_VERSION="0.5.0" +ARG PYTHON_VERSION="3.12" + +FROM rust as build +ARG RATHOLE_VERSION +ARG FEATURES +RUN apt update && apt install -y git +RUN git clone -b v${RATHOLE_VERSION} https://github.com/rapiz1/rathole + +WORKDIR /rathole +RUN cargo build --locked --release --features ${FEATURES:-default} + +FROM python:${PYTHON_VERSION}-bookworm +ARG RATHOLE_VERSION +ENV MODE="client" +COPY --from=build /rathole/target/release/rathole /app/rathole +RUN apt update && apt install -y netcat-openbsd vim +WORKDIR /app +COPY ./start-client.sh /app/start-client.sh +COPY ./start-server.sh /app/start-server.sh +COPY ./client.toml /app/client.toml +COPY ./server.toml /app/server.toml +COPY ./nginx.conf /etc/nginx/conf.d/default.conf + +CMD ["sh", "-c", "/app/start-$MODE.sh"] +EXPOSE 2333/udp +EXPOSE 2333 + +# build and run a fake domain to simulate a normal http container service +# docker build -f domain.dockerfile . -t domain +# docker run -it -d -p 8080:8000 domain + +# check the web server is running on 8080 +# curl localhost:8080 + +# build and run the rathole container +# docker build -f rathole.dockerfile . 
-t rathole + +# run the rathole server +# docker run -it -p 8001:8001 -p 8002:8002 -p 2333:2333 -e MODE=server rathole + +# check nothing is on port 8001 yet +# curl localhost:8001 + +# run the rathole client +# docker run -it -e MODE=client rathole + +# try port 8001 now +# curl localhost:8001 + +# add another client and edit the server.toml and client.toml for port 8002 + + diff --git a/packages/grid/rathole/server.toml b/packages/grid/rathole/server.toml new file mode 100644 index 00000000000..8145491e7cc --- /dev/null +++ b/packages/grid/rathole/server.toml @@ -0,0 +1,6 @@ +[server] +bind_addr = "0.0.0.0:2333" # public open port + +[server.services.domain] +token = "domain-specific-rathole-secret" +bind_addr = "0.0.0.0:8001" diff --git a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh new file mode 100755 index 00000000000..60ace3fa31b --- /dev/null +++ b/packages/grid/rathole/start-client.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +apt update && apt install -y nginx +nginx & +/app/rathole client.toml diff --git a/packages/grid/rathole/start-server.sh b/packages/grid/rathole/start-server.sh new file mode 100755 index 00000000000..700b081a60d --- /dev/null +++ b/packages/grid/rathole/start-server.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +/app/rathole server.toml From 91a326a142524bd2140e4990d69517f4a58c95ac Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 15 Apr 2024 17:40:26 +0530 Subject: [PATCH 002/309] update dockerfile with additional hosts for docker internal --- packages/grid/rathole/rathole.dockerfile | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 4ae1648e6d6..c4eb717696b 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -28,7 +28,9 @@ EXPOSE 2333 # build and run a fake domain to simulate a normal http container service # docker build -f domain.dockerfile . -t domain -# docker run -it -d -p 8080:8000 domain +# docker run --name domain1 -it -d -p 8080:8000 domain + + # check the web server is running on 8080 # curl localhost:8080 @@ -37,13 +39,13 @@ EXPOSE 2333 # docker build -f rathole.dockerfile . 
-t rathole # run the rathole server -# docker run -it -p 8001:8001 -p 8002:8002 -p 2333:2333 -e MODE=server rathole +# docker run --add-host host.docker.internal:host-gateway --name rathole-server -it -p 8001:8001 -p 8002:8002 -p 2333:2333 -e MODE=server rathole # check nothing is on port 8001 yet # curl localhost:8001 # run the rathole client -# docker run -it -e MODE=client rathole +# docker run --add-host host.docker.internal:host-gateway --name rathole-client -it -e MODE=client rathole # try port 8001 now # curl localhost:8001 From 191193d0e75ddb496f5a7c67764e343610e5f74f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 18 Apr 2024 17:59:07 +0530 Subject: [PATCH 003/309] define intial frame for the rathole server app --- packages/grid/rathole/main.py | 72 +++++++++++++++++++++++++++++++++ packages/grid/rathole/models.py | 11 +++++ 2 files changed, 83 insertions(+) create mode 100644 packages/grid/rathole/main.py create mode 100644 packages/grid/rathole/models.py diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/main.py new file mode 100644 index 00000000000..55212b2fcbd --- /dev/null +++ b/packages/grid/rathole/main.py @@ -0,0 +1,72 @@ +# stdlib +import os +import sys + +# third party +from fastapi import FastAPI +from fastapi import status +from loguru import logger + +# relative +from .models import RatholeConfig +from .models import ResponseModel + +# Logging Configuration +log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() +logger.remove() +logger.add(sys.stderr, colorize=True, level=log_level) + +app = FastAPI(title="Rathole") + + +async def healthcheck() -> bool: + return True + + +@app.get( + "/healthcheck", + response_model=ResponseModel, + status_code=status.HTTP_200_OK, +) +async def healthcheck_endpoint() -> ResponseModel: + res = await healthcheck() + if res: + return ResponseModel(message="OK") + else: + return ResponseModel(message="FAIL") + + +@app.post( + "/config/", + response_model=ResponseModel, + status_code=status.HTTP_201_CREATED, +) +async def add_config(config: RatholeConfig) -> ResponseModel: + return ResponseModel(message="Config added successfully") + + +@app.delete( + "/config/{uuid}", + response_model=ResponseModel, + status_code=status.HTTP_200_OK, +) +async def remove_config(uuid: str) -> ResponseModel: + return ResponseModel(message="Config removed successfully") + + +@app.put( + "/config/{uuid}", + response_model=ResponseModel, + status_code=status.HTTP_200_OK, +) +async def update_config() -> ResponseModel: + return ResponseModel(message="Config updated successfully") + + +@app.get( + "/config/{uuid}", + response_model=RatholeConfig, + status_code=status.HTTP_201_CREATED, +) +async def get_config(uuid: str) -> RatholeConfig: + pass diff --git a/packages/grid/rathole/models.py b/packages/grid/rathole/models.py new file mode 100644 index 00000000000..5feb12cbdaa --- /dev/null +++ b/packages/grid/rathole/models.py @@ -0,0 +1,11 @@ +# third party +from pydantic import BaseModel + + +class ResponseModel(BaseModel): + message: str + + +class RatholeConfig(BaseModel): + uuid: str + secret_token: str From 02694ec8f658772a66c704eb1ecf73c2072fd45b Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 22 Apr 2024 15:22:06 +0530 Subject: [PATCH 004/309] define RatholeConfig model - added a tomlreaderwriter to read/write with locks - added a RatholeClientToml Reader/Writer - define a RatholeServerTole reader/writer --- packages/grid/rathole/models.py | 6 ++ packages/grid/rathole/nginx.conf | 17 +++- packages/grid/rathole/toml_writer.py 
| 26 +++++ packages/grid/rathole/utils.py | 137 +++++++++++++++++++++++++++ 4 files changed, 182 insertions(+), 4 deletions(-) create mode 100644 packages/grid/rathole/toml_writer.py create mode 100644 packages/grid/rathole/utils.py diff --git a/packages/grid/rathole/models.py b/packages/grid/rathole/models.py index 5feb12cbdaa..27091510896 100644 --- a/packages/grid/rathole/models.py +++ b/packages/grid/rathole/models.py @@ -9,3 +9,9 @@ class ResponseModel(BaseModel): class RatholeConfig(BaseModel): uuid: str secret_token: str + local_addr_host: str + local_addr_port: int + + @property + def local_address(self) -> str: + return f"{self.local_addr_host}:{self.local_addr_port}" diff --git a/packages/grid/rathole/nginx.conf b/packages/grid/rathole/nginx.conf index af1a3a752d7..b660e980400 100644 --- a/packages/grid/rathole/nginx.conf +++ b/packages/grid/rathole/nginx.conf @@ -1,6 +1,15 @@ -server { - listen 8000; - location / { - proxy_pass http://host.docker.internal:8080; +http { + server { + listen 8000; + location / { + proxy_pass http://host.docker.internal:8080; + } + } + + server { + listen 8001; + location / { + proxy_pass http://host.docker.internal:8081; + } } } \ No newline at end of file diff --git a/packages/grid/rathole/toml_writer.py b/packages/grid/rathole/toml_writer.py new file mode 100644 index 00000000000..09c44801432 --- /dev/null +++ b/packages/grid/rathole/toml_writer.py @@ -0,0 +1,26 @@ +# stdlib +from pathlib import Path +import tomllib + +# third party +from filelock import FileLock + +FILE_LOCK_TIMEOUT = 30 + + +class TomlReaderWriter: + def __init__(self, lock: FileLock, filename: Path | str) -> None: + self.filename = Path(filename) + self.timeout = FILE_LOCK_TIMEOUT + self.lock = lock + + def write(self, toml_dict: dict) -> None: + with self.lock.acquire(timeout=self.timeout): + with open(str(self.filename), "wb") as fp: + tomllib.dump(toml_dict, fp) + + def read(self) -> dict: + with self.lock.acquire(timeout=self.timeout): + with open(str(self.filename), "rb") as fp: + toml = tomllib.load(fp) + return toml diff --git a/packages/grid/rathole/utils.py b/packages/grid/rathole/utils.py new file mode 100644 index 00000000000..166eee41c32 --- /dev/null +++ b/packages/grid/rathole/utils.py @@ -0,0 +1,137 @@ +# stdlib + +# third party +from filelock import FileLock + +# relative +from .models import RatholeConfig +from .toml_writer import TomlReaderWriter + +lock = FileLock("rathole.toml.lock") + + +class RatholeClientToml: + filename: str = "client.toml" + + def __init__(self) -> None: + self.client_toml = TomlReaderWriter(lock=lock, filename=self.filename) + + def set_remote_addr(self, remote_host: str) -> None: + """Add a new remote address to the client toml file.""" + + toml = self.client_toml.read() + + # Add the new remote address + if "client" not in toml: + toml["client"] = {} + + toml["client"]["remote_addr"] = remote_host + + if remote_host not in toml["client"]["remote"]: + toml["client"]["remote"].append(remote_host) + + self.client_toml.write(toml_dict=toml) + + def add_config(self, config: RatholeConfig) -> None: + """Add a new config to the toml file.""" + + toml = self.client_toml.read() + + # Add the new config + if "services" not in toml["client"]: + toml["client"]["services"] = {} + + if config.uuid not in toml["client"]["services"]: + toml["client"]["services"][config.uuid] = {} + + toml["client"]["services"][config.uuid] = { + "token": config.secret_token, + "local_addr": config.local_address, + } + + self.client_toml.write(toml) + + def 
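A note on the TomlReaderWriter above: the standard-library tomllib module (Python 3.11+) only parses TOML, exposing load() and loads() but no dump(), so writing the dict back to disk needs a separate writer. A minimal sketch of the read/write pair, assuming a third-party writer such as tomli_w (not listed in this series' requirements.txt) were added as a dependency:

    # sketch only: tomllib handles reads, tomli_w (assumed extra dependency) handles writes
    from pathlib import Path
    import tomllib
    import tomli_w
    from filelock import FileLock

    FILE_LOCK_TIMEOUT = 30

    class TomlReaderWriter:
        def __init__(self, lock: FileLock, filename: Path | str) -> None:
            self.filename = Path(filename)
            self.timeout = FILE_LOCK_TIMEOUT
            self.lock = lock

        def write(self, toml_dict: dict) -> None:
            with self.lock.acquire(timeout=self.timeout):
                with open(self.filename, "wb") as fp:
                    tomli_w.dump(toml_dict, fp)  # tomllib itself has no dump()

        def read(self) -> dict:
            with self.lock.acquire(timeout=self.timeout):
                with open(self.filename, "rb") as fp:
                    return tomllib.load(fp)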
remove_config(self, uuid: str) -> None: + """Remove a config from the toml file.""" + + toml = self.client_toml.read() + + # Remove the config + if "services" not in toml["client"]: + return + + if uuid not in toml["client"]["services"]: + return + + del toml["client"]["services"][uuid] + + self.client_toml.write(toml) + + def update_config(self, config: RatholeConfig) -> None: + """Update a config in the toml file.""" + + toml = self.client_toml.read() + + # Update the config + if "services" not in toml["client"]: + return + + if config.uuid not in toml["client"]["services"]: + return + + toml["client"]["services"][config.uuid] = { + "token": config.secret_token, + "local_addr": config.local_address, + } + + self.client_toml.write(toml) + + def get_config(self, uuid: str) -> RatholeConfig | None: + """Get a config from the toml file.""" + + toml = self.client_toml.read() + + # Get the config + if "services" not in toml["client"]: + return None + + if uuid not in toml["client"]["services"]: + return None + + service = toml["client"]["services"][uuid] + + return RatholeConfig( + uuid=uuid, + secret_token=service["token"], + local_addr_host=service["local_addr"].split(":")[0], + local_addr_port=service["local_addr"].split(":")[1], + ) + + def _validate(self) -> bool: + if not self.client_toml.filename.exists(): + return False + + toml = self.client_toml.read() + + if not toml["client"]["remote_addr"]: + return False + + for uuid, config in toml["client"]["services"].items(): + if not uuid: + return False + + if not config["token"] and not config["local_addr"]: + return False + + return True + + @property + def is_valid(self) -> bool: + return self._validate() + + +class ServerTomlReaderWriter: + filename: str = "server.toml" + + def __init__(self) -> None: + self.server_toml = TomlReaderWriter(lock=lock, filename=self.filename) From 4d26dab81669b5df8604ac22f278ef53ea365589 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 22 Apr 2024 16:11:04 +0530 Subject: [PATCH 005/309] implement server toml reader/writer class --- packages/grid/rathole/utils.py | 89 +++++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/packages/grid/rathole/utils.py b/packages/grid/rathole/utils.py index 166eee41c32..e65896de439 100644 --- a/packages/grid/rathole/utils.py +++ b/packages/grid/rathole/utils.py @@ -120,7 +120,7 @@ def _validate(self) -> bool: if not uuid: return False - if not config["token"] and not config["local_addr"]: + if not config["token"] or not config["local_addr"]: return False return True @@ -130,8 +130,93 @@ def is_valid(self) -> bool: return self._validate() -class ServerTomlReaderWriter: +class RatholeServerToml: filename: str = "server.toml" def __init__(self) -> None: self.server_toml = TomlReaderWriter(lock=lock, filename=self.filename) + + def set_bind_address(self, bind_address: str) -> None: + """Set the bind address in the server toml file.""" + + toml = self.server_toml.read() + + # Set the bind address + toml["server"]["bind_addr"] = bind_address + + self.server_toml.write(toml) + + def add_config(self, config: RatholeConfig) -> None: + """Add a new config to the toml file.""" + + toml = self.server_toml.read() + + # Add the new config + if "services" not in toml["server"]: + toml["server"]["services"] = {} + + if config.uuid not in toml["server"]["services"]: + toml["server"]["services"][config.uuid] = {} + + toml["server"]["services"][config.uuid] = { + "token": config.secret_token, + "bind_addr": config.local_address, + } + + 
self.server_toml.write(toml) + + def remove_config(self, uuid: str) -> None: + """Remove a config from the toml file.""" + + toml = self.server_toml.read() + + # Remove the config + if "services" not in toml["server"]: + return + + if uuid not in toml["server"]["services"]: + return + + del toml["server"]["services"][uuid] + + self.server_toml.write(toml) + + def update_config(self, config: RatholeConfig) -> None: + """Update a config in the toml file.""" + + toml = self.server_toml.read() + + # Update the config + if "services" not in toml["server"]: + return + + if config.uuid not in toml["server"]["services"]: + return + + toml["server"]["services"][config.uuid] = { + "token": config.secret_token, + "bind_addr": config.local_address, + } + + self.server_toml.write(toml) + + def _validate(self) -> bool: + if not self.server_toml.filename.exists(): + return False + + toml = self.server_toml.read() + + if not toml["server"]["bind_addr"]: + return False + + for uuid, config in toml["server"]["services"].items(): + if not uuid: + return False + + if not config["token"] or not config["bind_addr"]: + return False + + return True + + def is_valid(self) -> bool: + return self._validate() From a875114506f8e1adf8f5a9ee3a6587acafcc0b94 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 22 Apr 2024 16:21:10 +0530 Subject: [PATCH 006/309] integrate rathole toml client/server manager with Rathole fastapi endpoints --- packages/grid/rathole/main.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/main.py index 55212b2fcbd..c23a664f98d 100644 --- a/packages/grid/rathole/main.py +++ b/packages/grid/rathole/main.py @@ -1,4 +1,5 @@ # stdlib +from enum import Enum import os import sys @@ -10,6 +11,8 @@ # relative from .models import RatholeConfig from .models import ResponseModel +from .utils import RatholeClientToml +from .utils import RatholeServerToml # Logging Configuration log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() @@ -19,6 +22,21 @@ app = FastAPI(title="Rathole") +class RatholeServiceType(Enum): + CLIENT = "client" + SERVER = "server" + + +ServiceType = os.getenv("RATHOLE_SERVICE_TYPE") + + +RatholeTomlManager = ( + RatholeServerToml() + if ServiceType == RatholeServiceType.SERVER.value + else RatholeClientToml() +) + + async def healthcheck() -> bool: return True @@ -42,6 +60,7 @@ async def healthcheck_endpoint() -> ResponseModel: status_code=status.HTTP_201_CREATED, ) async def add_config(config: RatholeConfig) -> ResponseModel: + RatholeTomlManager.add_config(config) return ResponseModel(message="Config added successfully") @@ -51,6 +70,7 @@ async def add_config(config: RatholeConfig) -> ResponseModel: status_code=status.HTTP_200_OK, ) async def remove_config(uuid: str) -> ResponseModel: + RatholeTomlManager.remove_config(uuid) return ResponseModel(message="Config removed successfully") @@ -59,7 +79,8 @@ async def remove_config(uuid: str) -> ResponseModel: response_model=ResponseModel, status_code=status.HTTP_200_OK, ) -async def update_config() -> ResponseModel: +async def update_config(config: RatholeConfig) -> ResponseModel: + RatholeTomlManager.update_config(config=config) return ResponseModel(message="Config updated successfully") @@ -69,4 +90,4 @@ async def update_config() -> ResponseModel: status_code=status.HTTP_201_CREATED, ) async def get_config(uuid: str) -> RatholeConfig: - pass + return RatholeTomlManager.get_config(uuid) From 59eff342f6251516b14fb00edb75f1437097fd01 Mon 
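With the endpoints above, the TOML managers become drivable over HTTP. A rough usage sketch with the requests package; the base URL, service id and token are illustrative and depend on how the rathole container ends up being exposed (later commits serve the app with uvicorn on APP_PORT, default 5555):

    import requests

    base = "http://localhost:5555"  # illustrative; wherever the FastAPI app is reachable

    cfg = {
        "uuid": "test-domain-1",                           # illustrative service id
        "secret_token": "domain-specific-rathole-secret",  # illustrative token
        "local_addr_host": "localhost",
        "local_addr_port": 8000,
    }

    print(requests.post(f"{base}/config/", json=cfg).json())       # {'message': 'Config added successfully'}
    print(requests.get(f"{base}/config/{cfg['uuid']}").json())     # fetches the stored config back
    print(requests.delete(f"{base}/config/{cfg['uuid']}").json())  # {'message': 'Config removed successfully'}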
Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 22 Apr 2024 17:36:24 +0530 Subject: [PATCH 007/309] add a default value to Rathole service type --- packages/grid/rathole/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/main.py index c23a664f98d..b1b1fe59a09 100644 --- a/packages/grid/rathole/main.py +++ b/packages/grid/rathole/main.py @@ -22,17 +22,17 @@ app = FastAPI(title="Rathole") -class RatholeServiceType(Enum): +class RatholeMode(Enum): CLIENT = "client" SERVER = "server" -ServiceType = os.getenv("RATHOLE_SERVICE_TYPE") +ServiceType = os.getenv("RATHOLE_MODE", "client").lower() RatholeTomlManager = ( RatholeServerToml() - if ServiceType == RatholeServiceType.SERVER.value + if ServiceType == RatholeMode.SERVER.value else RatholeClientToml() ) From 6ed6b100d7ac9426093509d11a3379257e4aeacb Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 23 Apr 2024 14:35:26 +0530 Subject: [PATCH 008/309] added nginx conf builder --- packages/grid/rathole/nginx_builder.py | 53 ++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 packages/grid/rathole/nginx_builder.py diff --git a/packages/grid/rathole/nginx_builder.py b/packages/grid/rathole/nginx_builder.py new file mode 100644 index 00000000000..53ef4ae64d7 --- /dev/null +++ b/packages/grid/rathole/nginx_builder.py @@ -0,0 +1,53 @@ +# stdlib +from pathlib import Path + +# third party +from filelock import FileLock +import nginx +from nginx import Conf + + +class NginxConfigBuilder: + def __init__(self, filename: str | Path) -> None: + self.filename = Path(filename) + self.lock = FileLock(f"{filename}.lock") + self.lock_timeout = 30 + + def read(self) -> Conf: + with self.lock.acquire(timeout=self.lock_timeout): + conf = nginx.loadf(self.filename) + + return conf + + def write(self, conf: Conf) -> None: + with self.lock.acquire(timeout=self.lock_timeout): + nginx.dumpf(conf, self.filename) + + def add_server(self, listen_port: int, location: str, proxy_pass: str) -> None: + conf = self.read() + server = conf.servers.add() + server.listen = listen_port + location = server.locations.add() + location.path = location + location.proxy_pass = proxy_pass + self.write(conf) + + def remove_server(self, listen_port: int) -> None: + conf = self.read() + for server in conf.servers: + if server.listen == listen_port: + conf.servers.remove(server) + break + self.write(conf) + + def modify_location_for_port( + self, listen_port: int, location: str, proxy_pass: str + ) -> None: + conf = self.read() + for server in conf.servers: + if server.listen == listen_port: + for loc in server.locations: + if loc.path == location: + loc.proxy_pass = proxy_pass + break + self.write(conf) From 0084befb2088d728856981f2d87d6ff690ecdfdc Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 23 Apr 2024 23:43:33 +0530 Subject: [PATCH 009/309] Fix RatholeConfigBuilder class --- packages/grid/rathole/nginx_builder.py | 65 ++++++++++++++++++-------- packages/grid/rathole/toml_writer.py | 2 +- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/packages/grid/rathole/nginx_builder.py b/packages/grid/rathole/nginx_builder.py index 53ef4ae64d7..3a25d34bf8a 100644 --- a/packages/grid/rathole/nginx_builder.py +++ b/packages/grid/rathole/nginx_builder.py @@ -7,9 +7,13 @@ from nginx import Conf -class NginxConfigBuilder: +class RatholeNginxConfigBuilder: def __init__(self, filename: str | Path) -> None: - self.filename = Path(filename) + self.filename = 
Path(filename).absolute() + + if not self.filename.exists(): + self.filename.touch() + self.lock = FileLock(f"{filename}.lock") self.lock_timeout = 30 @@ -24,30 +28,53 @@ def write(self, conf: Conf) -> None: nginx.dumpf(conf, self.filename) def add_server(self, listen_port: int, location: str, proxy_pass: str) -> None: - conf = self.read() - server = conf.servers.add() - server.listen = listen_port - location = server.locations.add() - location.path = location - location.proxy_pass = proxy_pass - self.write(conf) + n_config = self.read() + server_to_modify = self.find_server_with_listen_port(listen_port) + + if server_to_modify is not None: + server_to_modify.add( + nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)) + ) + else: + server = nginx.Server( + nginx.Key("listen", listen_port), + nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)), + ) + n_config.add(server) + self.write(n_config) def remove_server(self, listen_port: int) -> None: conf = self.read() for server in conf.servers: - if server.listen == listen_port: - conf.servers.remove(server) - break + for child in server.children: + if child.name == "listen" and int(child.value) == listen_port: + conf.remove(server) + break self.write(conf) - def modify_location_for_port( + def find_server_with_listen_port(self, listen_port: int) -> nginx.Server | None: + conf = self.read() + for server in conf.servers: + for child in server.children: + if child.name == "listen" and int(child.value) == listen_port: + return server + return None + + def modify_proxy_for_port( self, listen_port: int, location: str, proxy_pass: str ) -> None: conf = self.read() - for server in conf.servers: - if server.listen == listen_port: - for loc in server.locations: - if loc.path == location: - loc.proxy_pass = proxy_pass - break + server_to_modify = self.find_server_with_listen_port(listen_port) + + if server_to_modify is None: + raise ValueError(f"Server with listen port {listen_port} not found") + + for location in server_to_modify.locations: + if location.value != location: + continue + for key in location.keys: + if key.name == "proxy_pass": + key.value = proxy_pass + break + self.write(conf) diff --git a/packages/grid/rathole/toml_writer.py b/packages/grid/rathole/toml_writer.py index 09c44801432..a0e79aff627 100644 --- a/packages/grid/rathole/toml_writer.py +++ b/packages/grid/rathole/toml_writer.py @@ -10,7 +10,7 @@ class TomlReaderWriter: def __init__(self, lock: FileLock, filename: Path | str) -> None: - self.filename = Path(filename) + self.filename = Path(filename).absolute() self.timeout = FILE_LOCK_TIMEOUT self.lock = lock From bbb093f3b13886d03b87c57056707fec8f63df13 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 24 Apr 2024 14:29:44 +0530 Subject: [PATCH 010/309] - pass server name to nginx builder - integrate nginx builder with toml manager - add requirements.txt - update rathole dockerfile Co-authored-by: Khoa Nguyen --- packages/grid/rathole/__init__.py | 0 packages/grid/rathole/main.py | 2 +- packages/grid/rathole/models.py | 3 ++- packages/grid/rathole/nginx_builder.py | 14 +++++++++++++- packages/grid/rathole/rathole.dockerfile | 7 +++++++ packages/grid/rathole/requirements.txt | 5 +++++ packages/grid/rathole/start-client.sh | 11 +++++++++++ packages/grid/rathole/utils.py | 14 ++++++++++++++ 8 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 packages/grid/rathole/__init__.py create mode 100644 packages/grid/rathole/requirements.txt diff --git a/packages/grid/rathole/__init__.py 
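For orientation, a small usage sketch of the RatholeNginxConfigBuilder defined above; the file path and upstream address are illustrative, and the later utils.py wiring calls add_server in the same way:

    # assumes the class above is importable, e.g. from nginx_builder import RatholeNginxConfigBuilder
    builder = RatholeNginxConfigBuilder("nginx.conf")  # file is created if missing
    builder.add_server(8001, location="/", proxy_pass="http://backend:80")
    # dumps roughly: server { listen 8001; location / { proxy_pass http://backend:80; } }
    builder.remove_server(8001)  # drops the block again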
b/packages/grid/rathole/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/main.py index b1b1fe59a09..fe4c8f8a081 100644 --- a/packages/grid/rathole/main.py +++ b/packages/grid/rathole/main.py @@ -27,7 +27,7 @@ class RatholeMode(Enum): SERVER = "server" -ServiceType = os.getenv("RATHOLE_MODE", "client").lower() +ServiceType = os.getenv("MODE", "client").lower() RatholeTomlManager = ( diff --git a/packages/grid/rathole/models.py b/packages/grid/rathole/models.py index 27091510896..c5921738885 100644 --- a/packages/grid/rathole/models.py +++ b/packages/grid/rathole/models.py @@ -11,7 +11,8 @@ class RatholeConfig(BaseModel): secret_token: str local_addr_host: str local_addr_port: int + server_name: str | None = None @property def local_address(self) -> str: - return f"{self.local_addr_host}:{self.local_addr_port}" + return f"http://{self.local_addr_host}:{self.local_addr_port}" diff --git a/packages/grid/rathole/nginx_builder.py b/packages/grid/rathole/nginx_builder.py index 3a25d34bf8a..3d1bd14ce1f 100644 --- a/packages/grid/rathole/nginx_builder.py +++ b/packages/grid/rathole/nginx_builder.py @@ -27,7 +27,13 @@ def write(self, conf: Conf) -> None: with self.lock.acquire(timeout=self.lock_timeout): nginx.dumpf(conf, self.filename) - def add_server(self, listen_port: int, location: str, proxy_pass: str) -> None: + def add_server( + self, + listen_port: int, + location: str, + proxy_pass: str, + server_name: str | None = None, + ) -> None: n_config = self.read() server_to_modify = self.find_server_with_listen_port(listen_port) @@ -35,12 +41,18 @@ def add_server(self, listen_port: int, location: str, proxy_pass: str) -> None: server_to_modify.add( nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)) ) + if server_name is not None: + server_to_modify.add(nginx.Key("server_name", server_name)) else: server = nginx.Server( nginx.Key("listen", listen_port), nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)), ) + if server_name is not None: + server.add(nginx.Key("server_name", server_name)) + n_config.add(server) + self.write(n_config) def remove_server(self, listen_port: int) -> None: diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index c4eb717696b..3916a0eb12f 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -13,6 +13,7 @@ RUN cargo build --locked --release --features ${FEATURES:-default} FROM python:${PYTHON_VERSION}-bookworm ARG RATHOLE_VERSION ENV MODE="client" +ENV APP_LOG_LEVEL="info" COPY --from=build /rathole/target/release/rathole /app/rathole RUN apt update && apt install -y netcat-openbsd vim WORKDIR /app @@ -21,7 +22,13 @@ COPY ./start-server.sh /app/start-server.sh COPY ./client.toml /app/client.toml COPY ./server.toml /app/server.toml COPY ./nginx.conf /etc/nginx/conf.d/default.conf +COPY ./main.py /app/main.py +COPY ./nginx_builder.py /app/nginx_builder.py +COPY ./utils.py /app/utils.py +COPY ./requirements.txt /app/requirements.txt + +RUN pip install --user -r requirements.txt CMD ["sh", "-c", "/app/start-$MODE.sh"] EXPOSE 2333/udp EXPOSE 2333 diff --git a/packages/grid/rathole/requirements.txt b/packages/grid/rathole/requirements.txt new file mode 100644 index 00000000000..4b379d83e8e --- /dev/null +++ b/packages/grid/rathole/requirements.txt @@ -0,0 +1,5 @@ +fastapi==0.110.0 +filelock==3.13.4 +loguru==0.7.2 +python-nginx +uvicorn[standard]==0.27.1 diff --git 
a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh index 60ace3fa31b..850e71ab6e3 100755 --- a/packages/grid/rathole/start-client.sh +++ b/packages/grid/rathole/start-client.sh @@ -1,4 +1,15 @@ #!/usr/bin/env bash + +APP_MODULE=main:app +LOG_LEVEL=${LOG_LEVEL:-info} +HOST=${HOST:-0.0.0.0} +PORT=${PORT:-80} +RELOAD="" +DEBUG_CMD="" + + + apt update && apt install -y nginx nginx & +exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" & /app/rathole client.toml diff --git a/packages/grid/rathole/utils.py b/packages/grid/rathole/utils.py index e65896de439..485e4ae7e23 100644 --- a/packages/grid/rathole/utils.py +++ b/packages/grid/rathole/utils.py @@ -5,6 +5,7 @@ # relative from .models import RatholeConfig +from .nginx_builder import RatholeNginxConfigBuilder from .toml_writer import TomlReaderWriter lock = FileLock("rathole.toml.lock") @@ -15,6 +16,7 @@ class RatholeClientToml: def __init__(self) -> None: self.client_toml = TomlReaderWriter(lock=lock, filename=self.filename) + self.nginx_mananger = RatholeNginxConfigBuilder("nginx.conf") def set_remote_addr(self, remote_host: str) -> None: """Add a new remote address to the client toml file.""" @@ -51,6 +53,10 @@ def add_config(self, config: RatholeConfig) -> None: self.client_toml.write(toml) + self.nginx_mananger.add_server( + config.local_addr_port, location="/", proxy_pass="http://backend:80" + ) + def remove_config(self, uuid: str) -> None: """Remove a config from the toml file.""" @@ -135,6 +141,7 @@ class RatholeServerToml: def __init__(self) -> None: self.server_toml = TomlReaderWriter(lock=lock, filename=self.filename) + self.nginx_manager = RatholeNginxConfigBuilder("nginx.conf") def set_bind_address(self, bind_address: str) -> None: """Set the bind address in the server toml file.""" @@ -165,6 +172,13 @@ def add_config(self, config: RatholeConfig) -> None: self.server_toml.write(toml) + self.nginx_manager.add_server( + config.local_addr_port, + location="/", + proxy_pass=config.local_address, + server_name=f"{config.server_name}.local*", + ) + def remove_config(self, uuid: str) -> None: """Remove a config from the toml file.""" From 17e9a2f52410deba585c633734b26ddea2065599 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 25 Apr 2024 09:57:40 +0700 Subject: [PATCH 011/309] fix relative import issues --- packages/grid/rathole/main.py | 10 ++++------ packages/grid/rathole/rathole.dockerfile | 4 +++- packages/grid/rathole/start-client.sh | 2 -- packages/grid/rathole/utils.py | 8 +++----- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/main.py index fe4c8f8a081..dfb5f6bd2a3 100644 --- a/packages/grid/rathole/main.py +++ b/packages/grid/rathole/main.py @@ -7,12 +7,10 @@ from fastapi import FastAPI from fastapi import status from loguru import logger - -# relative -from .models import RatholeConfig -from .models import ResponseModel -from .utils import RatholeClientToml -from .utils import RatholeServerToml +from models import RatholeConfig +from models import ResponseModel +from utils import RatholeClientToml +from utils import RatholeServerToml # Logging Configuration log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 3916a0eb12f..5402b4291c1 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -22,12 +22,14 @@ COPY 
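Putting the toml-manager and nginx-builder integration together: adding a config through the client-side manager updates both client.toml and the nginx config fronting the tunnelled service. A rough sketch; the service id and token are illustrative, and the imports assume the flat module layout used around this point in the series:

    from models import RatholeConfig
    from utils import RatholeClientToml

    manager = RatholeClientToml()
    manager.add_config(
        RatholeConfig(
            uuid="test-domain-1",                           # illustrative
            secret_token="domain-specific-rathole-secret",  # illustrative
            local_addr_host="localhost",
            local_addr_port=8000,
        )
    )
    # writes a [client.services.test-domain-1] entry into client.toml and an nginx
    # server block listening on the service's local port (8000 here) that proxies
    # to http://backend:80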
./start-server.sh /app/start-server.sh COPY ./client.toml /app/client.toml COPY ./server.toml /app/server.toml COPY ./nginx.conf /etc/nginx/conf.d/default.conf +COPY ./__init__.py /app/__init__.py COPY ./main.py /app/main.py +COPY ./models.py /app/models.py COPY ./nginx_builder.py /app/nginx_builder.py COPY ./utils.py /app/utils.py +COPY ./toml_writer.py /app/toml_writer.py COPY ./requirements.txt /app/requirements.txt - RUN pip install --user -r requirements.txt CMD ["sh", "-c", "/app/start-$MODE.sh"] EXPOSE 2333/udp diff --git a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh index 850e71ab6e3..035d370a1dc 100755 --- a/packages/grid/rathole/start-client.sh +++ b/packages/grid/rathole/start-client.sh @@ -7,8 +7,6 @@ PORT=${PORT:-80} RELOAD="" DEBUG_CMD="" - - apt update && apt install -y nginx nginx & exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" & diff --git a/packages/grid/rathole/utils.py b/packages/grid/rathole/utils.py index 485e4ae7e23..c6d54c8d7fe 100644 --- a/packages/grid/rathole/utils.py +++ b/packages/grid/rathole/utils.py @@ -2,11 +2,9 @@ # third party from filelock import FileLock - -# relative -from .models import RatholeConfig -from .nginx_builder import RatholeNginxConfigBuilder -from .toml_writer import TomlReaderWriter +from models import RatholeConfig +from nginx_builder import RatholeNginxConfigBuilder +from toml_writer import TomlReaderWriter lock = FileLock("rathole.toml.lock") From 3f1be01132c4c824cf93b175370b7dd687843d38 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 24 Apr 2024 22:45:39 +0530 Subject: [PATCH 012/309] fix rathole server --- packages/grid/rathole/nginx.conf | 7 ------- packages/grid/rathole/rathole.dockerfile | 7 +------ packages/grid/rathole/{ => server}/__init__.py | 0 packages/grid/rathole/{ => server}/main.py | 17 ++++++++++------- packages/grid/rathole/{ => server}/models.py | 0 .../grid/rathole/{ => server}/nginx_builder.py | 0 .../grid/rathole/{ => server}/toml_writer.py | 0 packages/grid/rathole/{ => server}/utils.py | 0 packages/grid/rathole/start-client.sh | 6 +++--- 9 files changed, 14 insertions(+), 23 deletions(-) rename packages/grid/rathole/{ => server}/__init__.py (100%) rename packages/grid/rathole/{ => server}/main.py (83%) rename packages/grid/rathole/{ => server}/models.py (100%) rename packages/grid/rathole/{ => server}/nginx_builder.py (100%) rename packages/grid/rathole/{ => server}/toml_writer.py (100%) rename packages/grid/rathole/{ => server}/utils.py (100%) diff --git a/packages/grid/rathole/nginx.conf b/packages/grid/rathole/nginx.conf index b660e980400..2be02976de4 100644 --- a/packages/grid/rathole/nginx.conf +++ b/packages/grid/rathole/nginx.conf @@ -5,11 +5,4 @@ http { proxy_pass http://host.docker.internal:8080; } } - - server { - listen 8001; - location / { - proxy_pass http://host.docker.internal:8081; - } - } } \ No newline at end of file diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 5402b4291c1..91f3a5d11e9 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -22,13 +22,8 @@ COPY ./start-server.sh /app/start-server.sh COPY ./client.toml /app/client.toml COPY ./server.toml /app/server.toml COPY ./nginx.conf /etc/nginx/conf.d/default.conf -COPY ./__init__.py /app/__init__.py -COPY ./main.py /app/main.py -COPY ./models.py /app/models.py -COPY ./nginx_builder.py /app/nginx_builder.py -COPY ./utils.py /app/utils.py 
-COPY ./toml_writer.py /app/toml_writer.py COPY ./requirements.txt /app/requirements.txt +COPY ./server/ /app/server/ RUN pip install --user -r requirements.txt CMD ["sh", "-c", "/app/start-$MODE.sh"] diff --git a/packages/grid/rathole/__init__.py b/packages/grid/rathole/server/__init__.py similarity index 100% rename from packages/grid/rathole/__init__.py rename to packages/grid/rathole/server/__init__.py diff --git a/packages/grid/rathole/main.py b/packages/grid/rathole/server/main.py similarity index 83% rename from packages/grid/rathole/main.py rename to packages/grid/rathole/server/main.py index dfb5f6bd2a3..201618c6432 100644 --- a/packages/grid/rathole/main.py +++ b/packages/grid/rathole/server/main.py @@ -7,10 +7,10 @@ from fastapi import FastAPI from fastapi import status from loguru import logger -from models import RatholeConfig -from models import ResponseModel -from utils import RatholeClientToml -from utils import RatholeServerToml +from server.models import RatholeConfig +from server.models import ResponseModel +from server.utils import RatholeClientToml +from server.utils import RatholeServerToml # Logging Configuration log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() @@ -40,7 +40,7 @@ async def healthcheck() -> bool: @app.get( - "/healthcheck", + "/", response_model=ResponseModel, status_code=status.HTTP_200_OK, ) @@ -84,8 +84,11 @@ async def update_config(config: RatholeConfig) -> ResponseModel: @app.get( "/config/{uuid}", - response_model=RatholeConfig, + response_model=RatholeConfig | ResponseModel, status_code=status.HTTP_201_CREATED, ) async def get_config(uuid: str) -> RatholeConfig: - return RatholeTomlManager.get_config(uuid) + config = RatholeTomlManager.get_config(uuid) + if config is None: + return ResponseModel(message="Config not found") + return config diff --git a/packages/grid/rathole/models.py b/packages/grid/rathole/server/models.py similarity index 100% rename from packages/grid/rathole/models.py rename to packages/grid/rathole/server/models.py diff --git a/packages/grid/rathole/nginx_builder.py b/packages/grid/rathole/server/nginx_builder.py similarity index 100% rename from packages/grid/rathole/nginx_builder.py rename to packages/grid/rathole/server/nginx_builder.py diff --git a/packages/grid/rathole/toml_writer.py b/packages/grid/rathole/server/toml_writer.py similarity index 100% rename from packages/grid/rathole/toml_writer.py rename to packages/grid/rathole/server/toml_writer.py diff --git a/packages/grid/rathole/utils.py b/packages/grid/rathole/server/utils.py similarity index 100% rename from packages/grid/rathole/utils.py rename to packages/grid/rathole/server/utils.py diff --git a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh index 035d370a1dc..9b667f40e85 100755 --- a/packages/grid/rathole/start-client.sh +++ b/packages/grid/rathole/start-client.sh @@ -1,10 +1,10 @@ #!/usr/bin/env bash -APP_MODULE=main:app +APP_MODULE=server.main:app LOG_LEVEL=${LOG_LEVEL:-info} HOST=${HOST:-0.0.0.0} -PORT=${PORT:-80} -RELOAD="" +PORT=${PORT:-5555} +RELOAD="--reload" DEBUG_CMD="" apt update && apt install -y nginx From f5bd6ff9a0d6adbed330d687d43dc6cfa0507fb0 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 26 Apr 2024 11:04:17 +0530 Subject: [PATCH 013/309] add rathole service to docker compose and traefik template --- packages/grid/default.env | 1 + packages/grid/docker-compose.build.yml | 6 ++++++ packages/grid/docker-compose.dev.yml | 6 ++++++ packages/grid/docker-compose.yml | 13 ++++++++++++- 
packages/grid/traefik/docker/dynamic.yml | 13 ++++++------- packages/hagrid/hagrid/cli.py | 14 ++++++++++++++ 6 files changed, 45 insertions(+), 8 deletions(-) diff --git a/packages/grid/default.env b/packages/grid/default.env index fb9cbd9a88d..bfb20ef1194 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -21,6 +21,7 @@ TRAEFIK_PUBLIC_TAG=traefik-public STACK_NAME=grid-openmined-org DOCKER_IMAGE_BACKEND=openmined/grid-backend DOCKER_IMAGE_FRONTEND=openmined/grid-frontend +DOCKER_IMAGE_RATHOLE=openmined/grid-rathole DOCKER_IMAGE_TRAEFIK=traefik TRAEFIK_VERSION=v2.11.0 REDIS_VERSION=6.2 diff --git a/packages/grid/docker-compose.build.yml b/packages/grid/docker-compose.build.yml index 7dc60d3fe41..cd43380ec18 100644 --- a/packages/grid/docker-compose.build.yml +++ b/packages/grid/docker-compose.build.yml @@ -22,3 +22,9 @@ services: context: ${RELATIVE_PATH}../ dockerfile: ./grid/backend/backend.dockerfile target: "backend" + + rathole: + build: + context: ${RELATIVE_PATH}../ + dockerfile: ./grid/rathole/rathole.dockerfile + target: "rathole" diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml index d2b1f142053..c6a3c14d9e6 100644 --- a/packages/grid/docker-compose.dev.yml +++ b/packages/grid/docker-compose.dev.yml @@ -52,6 +52,12 @@ services: stdin_open: true tty: true + rathole: + volumes: + - ${RELATIVE_PATH}./rathole/server:/root/app/server + environment: + - DEV_MODE=True + # backend_stream: # volumes: # - ${RELATIVE_PATH}./backend/grid:/root/app/grid diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index c1b4599e300..425ed318565 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -50,13 +50,24 @@ services: - VERSION_HASH=${VERSION_HASH} - PORT=80 - HTTP_PORT=${HTTP_PORT} - - HTTPS_PORT=${HTTPS_PORT} + - HTTPS_PORT=${HTTPS_PORT}RELOAD - BACKEND_API_BASE_URL=${BACKEND_API_BASE_URL} extra_hosts: - "host.docker.internal:host-gateway" labels: - "orgs.openmined.syft=this is a syft frontend container" + rathole: + restart: always + image: "${DOCKER_IMAGE_RATHOLE?Variable not set}:${VERSION-latest}" + profiles: + - rathole + environment: + - SERVICE_NAME=rathole + - APP_LOG_LEVEL=${APP_LOG_LEVEL} + - MODE=${MODE} + - DEV_MODE=${DEV_MODE} + # redis: # restart: always # image: redis:${REDIS_VERSION?Variable not set} diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index cc6a7bb7ee4..3c6d0945077 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -48,14 +48,13 @@ http: middlewares: - "blob-storage-url" - "blob-storage-host" - vpn: - rule: "PathPrefix(`/vpn`)" + rathole: + rule: "PathPrefix(`/rathole`)" entryPoints: - web - - vpn - service: "headscale" + service: "rathole" middlewares: - - "vpn-url" + - "rathole-url" ping: rule: "PathPrefix(`/ping`)" entryPoints: @@ -73,7 +72,7 @@ http: stripprefix: prefixes: /blob forceslash: true - vpn-url: + rathole-url: stripprefix: - prefixes: /vpn + prefixes: /rathole forceslash: true diff --git a/packages/hagrid/hagrid/cli.py b/packages/hagrid/hagrid/cli.py index 8fab98b4a83..cb0a3e9376c 100644 --- a/packages/hagrid/hagrid/cli.py +++ b/packages/hagrid/hagrid/cli.py @@ -496,6 +496,11 @@ def clean(location: str) -> None: is_flag=True, help="Enable auto approval of association requests", ) +@click.option( + "--rathole", + is_flag=True, + help="Enable rathole service", +) def launch(args: tuple[str], **kwargs: Any) -> None: verb = 
get_launch_verb() try: @@ -1314,6 +1319,7 @@ def create_launch_cmd( parsed_kwargs["headless"] = headless parsed_kwargs["tls"] = bool(kwargs["tls"]) + parsed_kwargs["enable_rathole"] = bool(kwargs["rathole"]) parsed_kwargs["test"] = bool(kwargs["test"]) parsed_kwargs["dev"] = bool(kwargs["dev"]) @@ -2241,6 +2247,11 @@ def create_launch_docker_cmd( else bool(kwargs["use_blob_storage"]) ) + enable_rathole = bool(kwargs.get("enable_rathole")) or str(node_type.input) in [ + "network", + "gateway", + ] + # use a docker volume host_path = "credentials-data" @@ -2411,6 +2422,9 @@ def create_launch_docker_cmd( if use_blob_storage: cmd += " --profile blob-storage" + if enable_rathole: + cmd += " --profile rathole" + # no frontend container so expect bad gateway on the / route if not bool(kwargs["headless"]): cmd += " --profile frontend" From 284ca0585a33260d738d0fdb2596580298564f38 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 26 Apr 2024 11:20:21 +0530 Subject: [PATCH 014/309] set rathole mode in hagrid cli --- packages/grid/docker-compose.dev.yml | 4 ++++ packages/grid/docker-compose.yml | 9 +++++++++ packages/grid/rathole/start-client.sh | 4 ++-- packages/hagrid/hagrid/cli.py | 7 +++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml index c6a3c14d9e6..185cd461170 100644 --- a/packages/grid/docker-compose.dev.yml +++ b/packages/grid/docker-compose.dev.yml @@ -57,6 +57,10 @@ services: - ${RELATIVE_PATH}./rathole/server:/root/app/server environment: - DEV_MODE=True + stdin_open: true + tty: true + ports: + - "2333" # backend_stream: # volumes: diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index 425ed318565..1039a80bc1a 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -67,6 +67,15 @@ services: - APP_LOG_LEVEL=${APP_LOG_LEVEL} - MODE=${MODE} - DEV_MODE=${DEV_MODE} + - APP_PORT=${APP_PORT} + - RATHOLE_PORT=${RATHOLE_PORT:-2333} + extra_hosts: + - "host.docker.internal:host-gateway" + labels: + - "orgs.openmined.syft=this is a syft rathole container" + ports: + - "${APP_PORT}:${APP_PORT}" + - "${RATHOLE_PORT}:${RATHOLE_PORT}" # redis: # restart: always diff --git a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh index 9b667f40e85..9bbb7d9097d 100755 --- a/packages/grid/rathole/start-client.sh +++ b/packages/grid/rathole/start-client.sh @@ -3,11 +3,11 @@ APP_MODULE=server.main:app LOG_LEVEL=${LOG_LEVEL:-info} HOST=${HOST:-0.0.0.0} -PORT=${PORT:-5555} +PORT=${APP_PORT:-5555} RELOAD="--reload" DEBUG_CMD="" apt update && apt install -y nginx nginx & -exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" & +exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & /app/rathole client.toml diff --git a/packages/hagrid/hagrid/cli.py b/packages/hagrid/hagrid/cli.py index cb0a3e9376c..1bd901d42ba 100644 --- a/packages/hagrid/hagrid/cli.py +++ b/packages/hagrid/hagrid/cli.py @@ -2263,6 +2263,10 @@ def create_launch_docker_cmd( # we might need to change this for the hagrid template mode host_path = f"{RELATIVE_PATH}./backend/grid/storage/{snake_name}" + rathole_mode = ( + "client" if enable_rathole and str(node_type.input) in ["domain"] else "server" + ) + envs = { "RELEASE": "production", "COMPOSE_DOCKER_CLI_BUILD": 1, @@ -2358,6 +2362,9 @@ def create_launch_docker_cmd( if "enable_signup" in kwargs: 
envs["ENABLE_SIGNUP"] = kwargs["enable_signup"] + if enable_rathole: + envs["MODE"] = rathole_mode + cmd = "" args = [] for k, v in envs.items(): From ce2d36eeec1450e9d259670f61f1be88f41b8bb1 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 26 Apr 2024 15:17:55 +0530 Subject: [PATCH 015/309] fix rathole config in docker compose files - fix imports in rathole utils - add rathole loadbalancer to traefik --- packages/grid/docker-compose.build.yml | 5 ++--- packages/grid/docker-compose.dev.yml | 6 ++++-- packages/grid/docker-compose.pull.yml | 3 +++ packages/grid/docker-compose.yml | 7 ++----- packages/grid/rathole/server/utils.py | 8 +++++--- packages/grid/rathole/start-server.sh | 11 +++++++++++ packages/grid/traefik/docker/dynamic.yml | 4 ++-- 7 files changed, 29 insertions(+), 15 deletions(-) diff --git a/packages/grid/docker-compose.build.yml b/packages/grid/docker-compose.build.yml index cd43380ec18..a0175bc762a 100644 --- a/packages/grid/docker-compose.build.yml +++ b/packages/grid/docker-compose.build.yml @@ -25,6 +25,5 @@ services: rathole: build: - context: ${RELATIVE_PATH}../ - dockerfile: ./grid/rathole/rathole.dockerfile - target: "rathole" + context: ${RELATIVE_PATH}./rathole + dockerfile: rathole.dockerfile diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml index 185cd461170..6ac138cba31 100644 --- a/packages/grid/docker-compose.dev.yml +++ b/packages/grid/docker-compose.dev.yml @@ -54,13 +54,15 @@ services: rathole: volumes: - - ${RELATIVE_PATH}./rathole/server:/root/app/server + - ${RELATIVE_PATH}./rathole/:/root/app/ environment: - DEV_MODE=True + - APP_PORT=5555 + - APP_LOG_LEVEL=debug stdin_open: true tty: true ports: - - "2333" + - "2333:2333" # backend_stream: # volumes: diff --git a/packages/grid/docker-compose.pull.yml b/packages/grid/docker-compose.pull.yml index db2329b04df..7fea3571f8a 100644 --- a/packages/grid/docker-compose.pull.yml +++ b/packages/grid/docker-compose.pull.yml @@ -24,3 +24,6 @@ services: # Temporary fix until we refactor pull, build, launch UI step during hagrid launch worker: image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" + + rathole: + image: "${DOCKER_IMAGE_RATHOLE?Variable not set}:${VERSION-latest}" diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index 1039a80bc1a..9273a109a60 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -64,18 +64,15 @@ services: - rathole environment: - SERVICE_NAME=rathole - - APP_LOG_LEVEL=${APP_LOG_LEVEL} + - APP_LOG_LEVEL=${APP_LOG_LEVEL:-info} - MODE=${MODE} - DEV_MODE=${DEV_MODE} - - APP_PORT=${APP_PORT} + - APP_PORT=${APP_PORT:-5555} - RATHOLE_PORT=${RATHOLE_PORT:-2333} extra_hosts: - "host.docker.internal:host-gateway" labels: - "orgs.openmined.syft=this is a syft rathole container" - ports: - - "${APP_PORT}:${APP_PORT}" - - "${RATHOLE_PORT}:${RATHOLE_PORT}" # redis: # restart: always diff --git a/packages/grid/rathole/server/utils.py b/packages/grid/rathole/server/utils.py index c6d54c8d7fe..485e4ae7e23 100644 --- a/packages/grid/rathole/server/utils.py +++ b/packages/grid/rathole/server/utils.py @@ -2,9 +2,11 @@ # third party from filelock import FileLock -from models import RatholeConfig -from nginx_builder import RatholeNginxConfigBuilder -from toml_writer import TomlReaderWriter + +# relative +from .models import RatholeConfig +from .nginx_builder import RatholeNginxConfigBuilder +from .toml_writer import TomlReaderWriter lock = FileLock("rathole.toml.lock") diff --git 
a/packages/grid/rathole/start-server.sh b/packages/grid/rathole/start-server.sh index 700b081a60d..2ab8d52c7ed 100755 --- a/packages/grid/rathole/start-server.sh +++ b/packages/grid/rathole/start-server.sh @@ -1,2 +1,13 @@ #!/usr/bin/env bash + +APP_MODULE=server.main:app +LOG_LEVEL=${LOG_LEVEL:-info} +HOST=${HOST:-0.0.0.0} +PORT=${APP_PORT:-5555} +RELOAD="--reload" +DEBUG_CMD="" + +apt update && apt install -y nginx +nginx & +exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & /app/rathole server.toml diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index 3c6d0945077..830ad50ad86 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -20,10 +20,10 @@ http: loadBalancer: servers: - url: "http://seaweedfs:4001" - headscale: + rathole: loadBalancer: servers: - - url: "http://headscale:8080" + - url: "http://rathole:5555" routers: frontend: rule: "PathPrefix(`/`)" From c179ddb579e387fd1448d1239afbdf4dac2e46bb Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sun, 28 Apr 2024 18:29:46 +0530 Subject: [PATCH 016/309] fix client.toml --- packages/grid/docker-compose.dev.yml | 28 +------- packages/grid/docker-compose.pull.yml | 9 --- packages/grid/docker-compose.yml | 84 ------------------------ packages/grid/rathole/client.toml | 2 +- packages/grid/rathole/nginx.conf | 10 ++- packages/grid/traefik/docker/dynamic.yml | 18 ++++- 6 files changed, 23 insertions(+), 128 deletions(-) diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml index 6ac138cba31..bcde98488e7 100644 --- a/packages/grid/docker-compose.dev.yml +++ b/packages/grid/docker-compose.dev.yml @@ -17,16 +17,6 @@ services: environment: - FRONTEND_TARGET=grid-ui-development - # redis: - # ports: - # - "6379" - - # queue: - # image: rabbitmq:3-management - # ports: - # - "15672" # admin web port - # # - "5672" # AMQP port - mongo: ports: - "27017" @@ -62,23 +52,7 @@ services: stdin_open: true tty: true ports: - - "2333:2333" - - # backend_stream: - # volumes: - # - ${RELATIVE_PATH}./backend/grid:/root/app/grid - # - ${RELATIVE_PATH}../syft:/root/app/syft - # - ${RELATIVE_PATH}./data/package-cache:/root/.cache - # environment: - # - DEV_MODE=True - - # celeryworker: - # volumes: - # - ${RELATIVE_PATH}./backend/grid:/root/app/grid - # - ${RELATIVE_PATH}../syft/:/root/app/syft - # - ${RELATIVE_PATH}./data/package-cache:/root/.cache - # environment: - # - DEV_MODE=True + - 2333:2333 seaweedfs: volumes: diff --git a/packages/grid/docker-compose.pull.yml b/packages/grid/docker-compose.pull.yml index 7fea3571f8a..e68ed03d968 100644 --- a/packages/grid/docker-compose.pull.yml +++ b/packages/grid/docker-compose.pull.yml @@ -1,17 +1,8 @@ version: "3.8" services: - # redis: - # image: redis:${REDIS_VERSION?Variable not set} - - # queue: - # image: rabbitmq:${RABBITMQ_VERSION?Variable not Set}${RABBITMQ_MANAGEMENT:-} - seaweedfs: image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}:${VERSION-latest}" - # docker-host: - # image: qoomon/docker-host - proxy: image: ${DOCKER_IMAGE_TRAEFIK?Variable not set}:${TRAEFIK_VERSION?Variable not set} diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index 9273a109a60..8f7ec80765f 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -74,27 +74,6 @@ services: labels: - "orgs.openmined.syft=this is a syft rathole container" - # redis: - # restart: always - # image: 
redis:${REDIS_VERSION?Variable not set} - # volumes: - # - app-redis-data:/data - # - ./redis/redis.conf:/usr/local/etc/redis/redis.conf - # environment: - # - SERVICE_NAME=redis - # - RELEASE=${RELEASE:-production} - # env_file: - # - .env - - # queue: - # restart: always - # image: rabbitmq:3 - # environment: - # - SERVICE_NAME=queue - # - RELEASE=${RELEASE:-production} - # volumes: - # - ./rabbitmq/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf - worker: restart: always image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" @@ -183,69 +162,6 @@ services: labels: - "orgs.openmined.syft=this is a syft backend container" - # backend_stream: - # restart: always - # image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" - # depends_on: - # - proxy - # env_file: - # - .env - # environment: - # - SERVICE_NAME=backend_stream - # - RELEASE=${RELEASE:-production} - # - VERSION=${VERSION} - # - VERSION_HASH=${VERSION_HASH} - # - NODE_TYPE=${NODE_TYPE?Variable not set} - # - DOMAIN_NAME=${DOMAIN_NAME?Variable not set} - # - STACK_API_KEY=${STACK_API_KEY} - # - PORT=8011 - # - STREAM_QUEUE=1 - # - IGNORE_TLS_ERRORS=${IGNORE_TLS_ERRORS?False} - # - HTTP_PORT=${HTTP_PORT} - # - HTTPS_PORT=${HTTPS_PORT} - # - USE_BLOB_STORAGE=${USE_BLOB_STORAGE} - # - CONTAINER_HOST=${CONTAINER_HOST} - # - TRACE=${TRACE} - # - JAEGER_HOST=${JAEGER_HOST} - # - JAEGER_PORT=${JAEGER_PORT} - # - DEV_MODE=${DEV_MODE} - # network_mode: service:proxy - # volumes: - # - credentials-data:/root/data/creds/ - - # celeryworker: - # restart: always - # image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" - # depends_on: - # - proxy - # - queue - # env_file: - # - .env - # environment: - # - SERVICE_NAME=celeryworker - # - RELEASE=${RELEASE:-production} - # - VERSION=${VERSION} - # - VERSION_HASH=${VERSION_HASH} - # - NODE_TYPE=${NODE_TYPE?Variable not set} - # - DOMAIN_NAME=${DOMAIN_NAME?Variable not set} - # - C_FORCE_ROOT=1 - # - STACK_API_KEY=${STACK_API_KEY} - # - IGNORE_TLS_ERRORS=${IGNORE_TLS_ERRORS?False} - # - HTTP_PORT=${HTTP_PORT} - # - HTTPS_PORT=${HTTPS_PORT} - # - USE_BLOB_STORAGE=${USE_BLOB_STORAGE} - # - CONTAINER_HOST=${CONTAINER_HOST} - # - NETWORK_CHECK_INTERVAL=${NETWORK_CHECK_INTERVAL} - # - DOMAIN_CHECK_INTERVAL=${DOMAIN_CHECK_INTERVAL} - # - TRACE=${TRACE} - # - JAEGER_HOST=${JAEGER_HOST} - # - JAEGER_PORT=${JAEGER_PORT} - # - DEV_MODE=${DEV_MODE} - # command: "/app/grid/worker-start.sh" - # network_mode: service:proxy - # volumes: - # - credentials-data:/storage - seaweedfs: profiles: - blob-storage diff --git a/packages/grid/rathole/client.toml b/packages/grid/rathole/client.toml index ba8b835a569..581f1af023c 100644 --- a/packages/grid/rathole/client.toml +++ b/packages/grid/rathole/client.toml @@ -1,5 +1,5 @@ [client] -remote_addr = "host.docker.internal:2333" # public IP and port of gateway +remote_addr = "localhost:2333" # public IP and port of gateway [client.services.domain] token = "domain-specific-rathole-secret" diff --git a/packages/grid/rathole/nginx.conf b/packages/grid/rathole/nginx.conf index 2be02976de4..447c482a9e4 100644 --- a/packages/grid/rathole/nginx.conf +++ b/packages/grid/rathole/nginx.conf @@ -1,8 +1,6 @@ -http { - server { - listen 8000; - location / { - proxy_pass http://host.docker.internal:8080; - } +server { + listen 8000; + location / { + proxy_pass http://test-domain-r:8001; } } \ No newline at end of file diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index 830ad50ad86..af5284cedcb 100644 
--- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -24,6 +24,10 @@ http: loadBalancer: servers: - url: "http://rathole:5555" + ratholeforward: + loadBalancer: + servers: + - url: "http://rathole:2333" routers: frontend: rule: "PathPrefix(`/`)" @@ -49,12 +53,19 @@ http: - "blob-storage-url" - "blob-storage-host" rathole: - rule: "PathPrefix(`/rathole`)" + rules: "PathPrefix(`/rathole`)" entryPoints: - web service: "rathole" middlewares: - "rathole-url" + ratholeforward: + rules: "HostRegexp(`{subdomain:[a-z]+}.local.rathole`)" + entryPoints: + - web + service: "ratholeforward" + middlewares: + - "rathole-redirect" ping: rule: "PathPrefix(`/ping`)" entryPoints: @@ -76,3 +87,8 @@ http: stripprefix: prefixes: /rathole forceslash: true + rathole-redirect: + redirectregex: + regex: "^.+" + replacement: "http://rathole:2333/" + permanent: true From 895448b19475b11516ca649977248aa1f7723a44 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sun, 28 Apr 2024 18:31:35 +0530 Subject: [PATCH 017/309] fix client.toml and server.toml for testing --- packages/grid/rathole/client.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/grid/rathole/client.toml b/packages/grid/rathole/client.toml index 581f1af023c..0d878798c95 100644 --- a/packages/grid/rathole/client.toml +++ b/packages/grid/rathole/client.toml @@ -1,5 +1,5 @@ [client] -remote_addr = "localhost:2333" # public IP and port of gateway +remote_addr = "20.197.23.137:2333" # public IP and port of gateway [client.services.domain] token = "domain-specific-rathole-secret" From e04fb02103eb2e207d178078bd3089a417286fa8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 29 Apr 2024 14:53:57 +0530 Subject: [PATCH 018/309] fix loadbalancer in dynamic yml for rathole forward --- packages/grid/docker-compose.yml | 2 ++ packages/grid/traefik/docker/dynamic.yml | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index 8f7ec80765f..bc8ae309055 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -62,6 +62,8 @@ services: image: "${DOCKER_IMAGE_RATHOLE?Variable not set}:${VERSION-latest}" profiles: - rathole + depends_on: + - proxy environment: - SERVICE_NAME=rathole - APP_LOG_LEVEL=${APP_LOG_LEVEL:-info} diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index af5284cedcb..6304a76be66 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -27,7 +27,7 @@ http: ratholeforward: loadBalancer: servers: - - url: "http://rathole:2333" + - url: "http://rathole:80" routers: frontend: rule: "PathPrefix(`/`)" @@ -90,5 +90,5 @@ http: rathole-redirect: redirectregex: regex: "^.+" - replacement: "http://rathole:2333/" + replacement: "http://rathole:80/" permanent: true From b22309eafddf36a38e338132f081be53bcc28345 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 29 Apr 2024 15:06:34 +0530 Subject: [PATCH 019/309] fix port number for rathole forwarding --- packages/grid/traefik/docker/dynamic.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index 6304a76be66..e6d9b99c292 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -27,7 +27,7 @@ http: ratholeforward: loadBalancer: servers: - - url: "http://rathole:80" + - url: 
"http://rathole:8000" routers: frontend: rule: "PathPrefix(`/`)" @@ -90,5 +90,5 @@ http: rathole-redirect: redirectregex: regex: "^.+" - replacement: "http://rathole:80/" + replacement: "http://rathole:8000/" permanent: true From b61d70e478e93389e47ce4a92d9e6112d78b03a5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 29 Apr 2024 17:35:40 +0530 Subject: [PATCH 020/309] retry fixing traefik --- packages/grid/traefik/docker/dynamic.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index e6d9b99c292..8ab8bbf8317 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -64,8 +64,8 @@ http: entryPoints: - web service: "ratholeforward" - middlewares: - - "rathole-redirect" + # middlewares: + # - "rathole-redirect" ping: rule: "PathPrefix(`/ping`)" entryPoints: @@ -90,5 +90,5 @@ http: rathole-redirect: redirectregex: regex: "^.+" - replacement: "http://rathole:8000/" + replacement: "http://rathole:8000$${1}" permanent: true From 6f0759ead00a0e7e93f40335269fcc7d7f4c0ce5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 30 Apr 2024 11:00:19 +0530 Subject: [PATCH 021/309] add rathole service and statefilset yaml --- .../templates/rathole/rathole-service.yaml | 30 +++++++ .../rathole/rathole-statefulset.yaml | 89 +++++++++++++++++++ packages/grid/helm/syft/values.yaml | 24 +++++ 3 files changed, 143 insertions(+) create mode 100644 packages/grid/helm/syft/templates/rathole/rathole-service.yaml create mode 100644 packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml diff --git a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml new file mode 100644 index 00000000000..777e530262b --- /dev/null +++ b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml @@ -0,0 +1,30 @@ +apiVersion: v1 +kind: Service +metadata: + name: rathole + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: rathole +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: rathole + ports: + - name: nginx + protocol: TCP + port: 80 + targetPort: 80 + - name: api + protocol: TCP + port: 5555 + targetPort: 5555 + type: NodePort + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: rathole + ports: + - name: rathole + protocol: TCP + port: 2333 + targetPort: 2333 diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml new file mode 100644 index 00000000000..ff72a8cb593 --- /dev/null +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -0,0 +1,89 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: rathole + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: rathole +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: rathole + serviceName: rathole + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . 
| nindent 8 }} + app.kubernetes.io/component: rathole + {{- if .Values.rathole.podLabels }} + {{- toYaml .Values.rathole.podLabels | nindent 8 }} + {{- end }} + {{- if .Values.rathole.podAnnotations }} + annotations: {{- toYaml .Values.rathole.podAnnotations | nindent 8 }} + {{- end }} + spec: + {{- if .Values.rathole.nodeSelector }} + nodeSelector: {{- .Values.rathole.nodeSelector | toYaml | nindent 8 }} + {{- end }} + containers: + - name: rathole + image: {{ .Values.global.registry }}/openmined/grid-rathole:{{ .Values.global.version }} + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.rathole.resources "preset" .Values.rathole.resourcesPreset) | nindent 12 }} + env: + - name: SERVICE_NAME + value: "rathole" + - name: APP_LOG_LEVEL + value: {{ .Values.rathole.appLogLevel | quote }} + - name: MODE + value: {{ .Values.rathole.mode | quote }} + - name: DEV_MODE + value: {{ .Values.rathole.devMode | quote }} + - name: APP_PORT + value: {{ .Values.rathole.appPort | quote }} + - name: RATHOLE_PORT + value: {{ .Values.rathole.ratholePort | quote }} + {{- if .Values.rathole.env }} + {{- toYaml .Values.rathole.env | nindent 12 }} + {{- end }} + ports: + - name: rathole-port + containerPort: 2333 + - name: api-port + containerPort: 5555 + - name: nginx-port + containerPort: 80 + startupProbe: + httpGet: + path: /?probe=startupProbe + port: api-port + failureThreshold: 30 + livenessProbe: + httpGet: + path: /ping?probe=livenessProbe + port: api-port + periodSeconds: 15 + timeoutSeconds: 5 + failureThreshold: 3 + volumeMounts: + - name: rathole-data + mountPath: /data + readOnly: false + # TODO: Mount the .toml and nginx.conf files + + # Add any additional container configuration here + # such as environment variables, volumes, etc. 
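      # One possible shape for that TODO mount, sketched as comments only (the
      # exact names are assumptions at this point; later commits in this series
      # mount a ConfigMap called "rathole-config" carrying server.toml/client.toml):
      #   volumeMounts:
      #     - name: rathole-config
      #       mountPath: /app/server.toml
      #       subPath: server.toml
      #   volumes:
      #     - name: rathole-config
      #       configMap:
      #         name: rathole-config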
+ volumeClaimTemplates: + - metadata: + name: rathole-data + spec: + accessModes: [ "ReadWriteOnce" ] + resources: + requests: + storage: {{ .Values.rathole.volumeSize | quote }} + storageClassName: {{ .Values.rathole.storageClassName | quote }} \ No newline at end of file diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index b390f67996e..1396616272e 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -222,3 +222,27 @@ ingress: # ================================================================================= extraResources: [] + + +# ================================================================================= + +rathole: + # Extra environment vars + env: null + + ratholePort: 2333 + appPort: 5555 + mode: "client" + devMode: "false" + appLogLevel: "info" + + # Pod labels & annotations + podLabels: null + podAnnotations: null + + # Node selector for pods + nodeSelector: null + + # Pod Resource Limits + resourcesPreset: small + resources: null From bffeefab9b22deeb37c9dff5e87a9be2b42c74b5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 30 Apr 2024 21:55:16 +0530 Subject: [PATCH 022/309] add rathole to devspace and define a configmap to update server.toml --- packages/grid/devspace.yaml | 13 ++++++++++ .../templates/rathole/rathole-configmap.yaml | 11 ++++++++ .../rathole/rathole-statefulset.yaml | 25 +++++++++++++------ 3 files changed, 41 insertions(+), 8 deletions(-) create mode 100644 packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 75e2757e58a..ae0a17f7ec1 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -25,6 +25,7 @@ vars: DOCKER_IMAGE_BACKEND: openmined/grid-backend DOCKER_IMAGE_FRONTEND: openmined/grid-frontend DOCKER_IMAGE_SEAWEEDFS: openmined/grid-seaweedfs + DOCKER_IMAGE_RATHOLE: openmined/grid-rathole CONTAINER_REGISTRY: "docker.io" VERSION: "0.8.7-beta.2" PLATFORM: $(uname -m | grep -q 'arm64' && echo "arm64" || echo "amd64") @@ -58,6 +59,14 @@ images: context: ./seaweedfs tags: - dev-${DEVSPACE_TIMESTAMP} + rathole: + image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_RATHOLE}" + buildKit: + args: ["--platform", "linux/${PLATFORM}"] + dockerfile: ./rathole/rathole.dockerfile + context: ./rathole + tags: + - dev-${DEVSPACE_TIMESTAMP} # This is a list of `deployments` that DevSpace can create for this project deployments: @@ -109,6 +118,10 @@ dev: - path: ./backend/grid:/root/app/grid - path: ../syft:/root/app/syft ssh: {} + rathole: + labelSelector: + app.kubernetes.io/name: syft + app.kubernetes.io/component: rathole profiles: - name: gateway diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml new file mode 100644 index 00000000000..02aca907e6f --- /dev/null +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: rathole-config + labels: + {{- include "common.labels" . 
| nindent 4 }} + app.kubernetes.io/component: rathole +data: + myserver.toml: | + [server] + bind_addr = "0.0.0.0:2333" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index ff72a8cb593..992370bbbcb 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -34,7 +34,7 @@ spec: - name: rathole image: {{ .Values.global.registry }}/openmined/grid-rathole:{{ .Values.global.version }} imagePullPolicy: Always - resources: {{ include "common.resources.set" (dict "resources" .Values.rathole.resources "preset" .Values.rathole.resourcesPreset) | nindent 12 }} + resources: {{ include "common.resources.set" (dict "resources" .Values.rathole.resources "preset" .Values.rathole.resourcesPreset) | nindent 12 }} env: - name: SERVICE_NAME value: "rathole" @@ -65,25 +65,34 @@ spec: failureThreshold: 30 livenessProbe: httpGet: - path: /ping?probe=livenessProbe + path: /?probe=livenessProbe port: api-port periodSeconds: 15 timeoutSeconds: 5 failureThreshold: 3 volumeMounts: - - name: rathole-data - mountPath: /data + - name: rathole-config + mountPath: /app/data/myserver.toml + subPath: myserver.toml readOnly: false - # TODO: Mount the .toml and nginx.conf files + terminationGracePeriodSeconds: 5 + volumes: + - name: rathole-config + configMap: + name: rathole-config + # TODO: Mount the .toml and nginx.conf files # Add any additional container configuration here # such as environment variables, volumes, etc. volumeClaimTemplates: - metadata: name: rathole-data + labels: + {{- include "common.volumeLabels" . | nindent 8 }} + app.kubernetes.io/component: rathole spec: - accessModes: [ "ReadWriteOnce" ] + accessModes: + - ReadWriteOnce resources: requests: - storage: {{ .Values.rathole.volumeSize | quote }} - storageClassName: {{ .Values.rathole.storageClassName | quote }} \ No newline at end of file + storage: 10Mi \ No newline at end of file From eee0fc1502ed598b385af2389d8709983c133b58 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 1 May 2024 13:16:37 +0530 Subject: [PATCH 023/309] refactor and combine client and server.sh --- .../templates/rathole/rathole-configmap.yaml | 9 ++++++++- packages/grid/rathole/rathole.dockerfile | 3 +-- packages/grid/rathole/start-client.sh | 13 ------------ packages/grid/rathole/start-server.sh | 13 ------------ packages/grid/rathole/start.sh | 20 +++++++++++++++++++ 5 files changed, 29 insertions(+), 29 deletions(-) delete mode 100755 packages/grid/rathole/start-client.sh delete mode 100755 packages/grid/rathole/start-server.sh create mode 100755 packages/grid/rathole/start.sh diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 02aca907e6f..7405235e0f0 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -6,6 +6,13 @@ metadata: {{- include "common.labels" . 
| nindent 4 }} app.kubernetes.io/component: rathole data: - myserver.toml: | + {{- if eq .Values.rathole.mode "server" }} + server.toml: | [server] bind_addr = "0.0.0.0:2333" + {{- end }} + + {{- if eq .Values.rathole.mode "client" }} + client.toml: | + [client] + remote_addr = "" diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 91f3a5d11e9..370b4233757 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -17,8 +17,7 @@ ENV APP_LOG_LEVEL="info" COPY --from=build /rathole/target/release/rathole /app/rathole RUN apt update && apt install -y netcat-openbsd vim WORKDIR /app -COPY ./start-client.sh /app/start-client.sh -COPY ./start-server.sh /app/start-server.sh +COPY ./start.sh /app/start.sh COPY ./client.toml /app/client.toml COPY ./server.toml /app/server.toml COPY ./nginx.conf /etc/nginx/conf.d/default.conf diff --git a/packages/grid/rathole/start-client.sh b/packages/grid/rathole/start-client.sh deleted file mode 100755 index 9bbb7d9097d..00000000000 --- a/packages/grid/rathole/start-client.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash - -APP_MODULE=server.main:app -LOG_LEVEL=${LOG_LEVEL:-info} -HOST=${HOST:-0.0.0.0} -PORT=${APP_PORT:-5555} -RELOAD="--reload" -DEBUG_CMD="" - -apt update && apt install -y nginx -nginx & -exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & -/app/rathole client.toml diff --git a/packages/grid/rathole/start-server.sh b/packages/grid/rathole/start-server.sh deleted file mode 100755 index 2ab8d52c7ed..00000000000 --- a/packages/grid/rathole/start-server.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash - -APP_MODULE=server.main:app -LOG_LEVEL=${LOG_LEVEL:-info} -HOST=${HOST:-0.0.0.0} -PORT=${APP_PORT:-5555} -RELOAD="--reload" -DEBUG_CMD="" - -apt update && apt install -y nginx -nginx & -exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & -/app/rathole server.toml diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh new file mode 100755 index 00000000000..45815250e93 --- /dev/null +++ b/packages/grid/rathole/start.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +APP_MODULE=server.main:app +LOG_LEVEL=${LOG_LEVEL:-info} +HOST=${HOST:-0.0.0.0} +PORT=${APP_PORT:-5555} +RELOAD="--reload" +DEBUG_CMD="" +RATHOLE_MODE=${RATHOLE_MODE:-server} + +apt update && apt install -y nginx +nginx & exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & + +if [[ $RATHOLE_MODE == "server" ]]; then + /app/rathole server.toml +elif [[ $RATHOLE_MODE = "client" ]]; then + /app/rathole client.toml +else + echo "RATHOLE_MODE is set to an invalid value. Exiting." 
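+    # NOTE: nothing exits here despite the "Exiting." message, so an invalid
+    # RATHOLE_MODE still falls through to the config reload loop below.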
+fi From 0eb42a97bf0c4e16a161ef98a6f0a52193f256a0 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 1 May 2024 13:39:08 +0530 Subject: [PATCH 024/309] fix rathole configmap --- packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 7405235e0f0..4b5fbea3633 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -16,3 +16,4 @@ data: client.toml: | [client] remote_addr = "" + {{- end }} From dc31f6be048aca2e5165c323cef85c2a26bab5d8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 1 May 2024 14:07:43 +0530 Subject: [PATCH 025/309] mount volume based on rathole mode if server or client --- .../syft/templates/rathole/rathole-statefulset.yaml | 13 +++++++++++-- packages/grid/helm/syft/values.yaml | 2 +- packages/grid/rathole/rathole.dockerfile | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index 992370bbbcb..860ea1b1881 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -71,10 +71,19 @@ spec: timeoutSeconds: 5 failureThreshold: 3 volumeMounts: + {{- if -eq .Values.rathole.mode "server" }} - name: rathole-config - mountPath: /app/data/myserver.toml - subPath: myserver.toml + mountPath: /app/data/server.toml + subPath: server.toml readOnly: false + {{- end }} + + {{- if -eq .Values.rathole.mode "client" }} + - name: rathole-config + mountPath: /app/data/client.toml + subPath: client.toml + readOnly: false + {{- end }} terminationGracePeriodSeconds: 5 volumes: - name: rathole-config diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 1396616272e..c59512e478b 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -232,7 +232,7 @@ rathole: ratholePort: 2333 appPort: 5555 - mode: "client" + mode: "server" devMode: "false" appLogLevel: "info" diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 370b4233757..0d8bebc3590 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -25,7 +25,7 @@ COPY ./requirements.txt /app/requirements.txt COPY ./server/ /app/server/ RUN pip install --user -r requirements.txt -CMD ["sh", "-c", "/app/start-$MODE.sh"] +CMD ["sh", "-c", "/app/start.sh"] EXPOSE 2333/udp EXPOSE 2333 From 00cc225d6808e4ea93369051c8ccea54be1b11a3 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 1 May 2024 14:53:48 +0530 Subject: [PATCH 026/309] mount server.toml and client.toml based on rathole mode --- .../helm/syft/templates/rathole/rathole-configmap.yaml | 4 ++++ .../helm/syft/templates/rathole/rathole-statefulset.yaml | 8 ++++---- packages/grid/rathole/client.toml | 6 ------ packages/grid/rathole/rathole.dockerfile | 2 -- packages/grid/rathole/server.toml | 6 ------ 5 files changed, 8 insertions(+), 18 deletions(-) delete mode 100644 packages/grid/rathole/client.toml delete mode 100644 packages/grid/rathole/server.toml diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml 
b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 4b5fbea3633..3e843fbc1cc 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -10,6 +10,10 @@ data: server.toml: | [server] bind_addr = "0.0.0.0:2333" + + [server.services.domain] + token = "domain-specific-rathole-secret" + bind_addr = "0.0.0.0:8001" {{- end }} {{- if eq .Values.rathole.mode "client" }} diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index 860ea1b1881..e44b5c4a442 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -71,16 +71,16 @@ spec: timeoutSeconds: 5 failureThreshold: 3 volumeMounts: - {{- if -eq .Values.rathole.mode "server" }} + {{- if eq .Values.rathole.mode "server" }} - name: rathole-config - mountPath: /app/data/server.toml + mountPath: /app/server.toml subPath: server.toml readOnly: false {{- end }} - {{- if -eq .Values.rathole.mode "client" }} + {{- if eq .Values.rathole.mode "client" }} - name: rathole-config - mountPath: /app/data/client.toml + mountPath: /app/client.toml subPath: client.toml readOnly: false {{- end }} diff --git a/packages/grid/rathole/client.toml b/packages/grid/rathole/client.toml deleted file mode 100644 index 0d878798c95..00000000000 --- a/packages/grid/rathole/client.toml +++ /dev/null @@ -1,6 +0,0 @@ -[client] -remote_addr = "20.197.23.137:2333" # public IP and port of gateway - -[client.services.domain] -token = "domain-specific-rathole-secret" -local_addr = "localhost:8000" # nginx proxy diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 0d8bebc3590..1d1b82785af 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -18,8 +18,6 @@ COPY --from=build /rathole/target/release/rathole /app/rathole RUN apt update && apt install -y netcat-openbsd vim WORKDIR /app COPY ./start.sh /app/start.sh -COPY ./client.toml /app/client.toml -COPY ./server.toml /app/server.toml COPY ./nginx.conf /etc/nginx/conf.d/default.conf COPY ./requirements.txt /app/requirements.txt COPY ./server/ /app/server/ diff --git a/packages/grid/rathole/server.toml b/packages/grid/rathole/server.toml deleted file mode 100644 index 8145491e7cc..00000000000 --- a/packages/grid/rathole/server.toml +++ /dev/null @@ -1,6 +0,0 @@ -[server] -bind_addr = "0.0.0.0:2333" # public open port - -[server.services.domain] -token = "domain-specific-rathole-secret" -bind_addr = "0.0.0.0:8001" From 79bcc23c1f5cedff639c7712b63a1cc54ce77628 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 2 May 2024 20:15:21 +0530 Subject: [PATCH 027/309] update configmap to seperate path of conf loaded and used in rathole add a script to copy toml file from loaded path to used path cleanup start.sh --- packages/grid/devspace.yaml | 3 ++ .../syft/templates/proxy/proxy-configmap.yaml | 9 +++++ .../templates/proxy/proxy-deployment.yaml | 7 ++++ .../templates/rathole/rathole-service.yaml | 14 +------ .../rathole/rathole-statefulset.yaml | 37 ++++--------------- packages/grid/rathole/start.sh | 23 ++++++------ 6 files changed, 39 insertions(+), 54 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index ae0a17f7ec1..3c22db1fd91 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ 
-133,6 +133,9 @@ profiles: path: images.seaweedfs - op: remove path: dev.seaweedfs + - op: replace + path: deployments.syft.helm.values.rathole.mode + value: "server" - name: gcp patches: diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 1989f399161..2654d49c8a1 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -72,4 +72,13 @@ data: providers: file: filename: /etc/traefik/dynamic.yml +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: proxy-config-dynamic + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: proxy +data: {} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml index 6adb42f6c9c..f06051b9e0b 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml @@ -45,6 +45,10 @@ spec: - mountPath: /etc/traefik name: traefik-conf readOnly: false + volumeMounts: + - mountPath: /etc/traefik/dynamic + name: traefik-conf-dynamic + readOnly: false startupProbe: null livenessProbe: httpGet: @@ -59,3 +63,6 @@ spec: - configMap: name: proxy-config name: traefik-conf + - configMap: + name: proxy-config-dynamic + name: traefik-conf-dynamic \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml index 777e530262b..5d6184e4795 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml @@ -7,19 +7,6 @@ metadata: app.kubernetes.io/component: rathole spec: type: ClusterIP - selector: - {{- include "common.selectorLabels" . | nindent 4 }} - app.kubernetes.io/component: rathole - ports: - - name: nginx - protocol: TCP - port: 80 - targetPort: 80 - - name: api - protocol: TCP - port: 5555 - targetPort: 5555 - type: NodePort selector: {{- include "common.selectorLabels" . 
| nindent 4 }} app.kubernetes.io/component: rathole @@ -28,3 +15,4 @@ spec: protocol: TCP port: 2333 targetPort: 2333 + diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index e44b5c4a442..0f07516e352 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -54,45 +54,22 @@ spec: ports: - name: rathole-port containerPort: 2333 - - name: api-port - containerPort: 5555 - - name: nginx-port - containerPort: 80 - startupProbe: - httpGet: - path: /?probe=startupProbe - port: api-port - failureThreshold: 30 - livenessProbe: - httpGet: - path: /?probe=livenessProbe - port: api-port - periodSeconds: 15 - timeoutSeconds: 5 - failureThreshold: 3 + startupProbe: null + livenessProbe: null volumeMounts: - {{- if eq .Values.rathole.mode "server" }} - - name: rathole-config - mountPath: /app/server.toml - subPath: server.toml + - name: mount-config + mountPath: /conf/ readOnly: false - {{- end }} - - {{- if eq .Values.rathole.mode "client" }} - name: rathole-config - mountPath: /app/client.toml - subPath: client.toml + mountPath: /app/conf/ readOnly: false - {{- end }} terminationGracePeriodSeconds: 5 volumes: - name: rathole-config + emptyDir: {} + - name: mount-config configMap: name: rathole-config - # TODO: Mount the .toml and nginx.conf files - - # Add any additional container configuration here - # such as environment variables, volumes, etc. volumeClaimTemplates: - metadata: name: rathole-data diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index 45815250e93..907972bb50a 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -1,20 +1,21 @@ #!/usr/bin/env bash - -APP_MODULE=server.main:app -LOG_LEVEL=${LOG_LEVEL:-info} -HOST=${HOST:-0.0.0.0} -PORT=${APP_PORT:-5555} -RELOAD="--reload" -DEBUG_CMD="" RATHOLE_MODE=${RATHOLE_MODE:-server} -apt update && apt install -y nginx -nginx & exec python -m $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $APP_PORT --log-level $LOG_LEVEL "$APP_MODULE" & +cp -L -r -f /conf/* conf/ +#!/bin/bash if [[ $RATHOLE_MODE == "server" ]]; then - /app/rathole server.toml + /app/rathole conf/server.toml & elif [[ $RATHOLE_MODE = "client" ]]; then - /app/rathole client.toml + /app/rathole conf/client.toml & else echo "RATHOLE_MODE is set to an invalid value. Exiting." fi + +while true +do + # Execute your script here + cp -L -r -f /conf/* conf/ + # Sleep for 10 seconds + sleep 10 +done \ No newline at end of file From c693d3ff2d7ad91eeeb02bfd350e2d0a8348ca3f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 11:29:00 +0530 Subject: [PATCH 028/309] add some comment to start script --- packages/grid/rathole/start.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index 907972bb50a..e948973decf 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -3,7 +3,6 @@ RATHOLE_MODE=${RATHOLE_MODE:-server} cp -L -r -f /conf/* conf/ -#!/bin/bash if [[ $RATHOLE_MODE == "server" ]]; then /app/rathole conf/server.toml & elif [[ $RATHOLE_MODE = "client" ]]; then @@ -12,6 +11,7 @@ else echo "RATHOLE_MODE is set to an invalid value. Exiting." 
fi +# reload config every 10 seconds while true do # Execute your script here From 3d7759e43b9d1bf6027560204e06386e2b31fe84 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sun, 5 May 2024 14:50:41 +0530 Subject: [PATCH 029/309] rathole: remove fastapi and nginx ports mapping from traefik --- .../syft/templates/proxy/proxy-configmap.yaml | 3 +- .../templates/proxy/proxy-deployment.yaml | 17 +++++----- .../templates/rathole/rathole-service.yaml | 1 - packages/grid/traefik/docker/dynamic.yml | 31 ------------------- 4 files changed, 9 insertions(+), 43 deletions(-) diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 2654d49c8a1..6e7dd9c0fe0 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -81,4 +81,5 @@ metadata: labels: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: proxy -data: {} \ No newline at end of file +data: + rathole-dynamic.yml: | diff --git a/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml index f06051b9e0b..db5bef8a813 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml @@ -45,10 +45,6 @@ spec: - mountPath: /etc/traefik name: traefik-conf readOnly: false - volumeMounts: - - mountPath: /etc/traefik/dynamic - name: traefik-conf-dynamic - readOnly: false startupProbe: null livenessProbe: httpGet: @@ -60,9 +56,10 @@ spec: readinessProbe: null terminationGracePeriodSeconds: 5 volumes: - - configMap: - name: proxy-config - name: traefik-conf - - configMap: - name: proxy-config-dynamic - name: traefik-conf-dynamic \ No newline at end of file + - name: traefik-conf + projected: + sources: + - configMap: + name: proxy-config + - configMap: + name: proxy-config-dynamic \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml index 5d6184e4795..d9050f0d693 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml @@ -15,4 +15,3 @@ spec: protocol: TCP port: 2333 targetPort: 2333 - diff --git a/packages/grid/traefik/docker/dynamic.yml b/packages/grid/traefik/docker/dynamic.yml index 8ab8bbf8317..61e68e7ad03 100644 --- a/packages/grid/traefik/docker/dynamic.yml +++ b/packages/grid/traefik/docker/dynamic.yml @@ -20,14 +20,6 @@ http: loadBalancer: servers: - url: "http://seaweedfs:4001" - rathole: - loadBalancer: - servers: - - url: "http://rathole:5555" - ratholeforward: - loadBalancer: - servers: - - url: "http://rathole:8000" routers: frontend: rule: "PathPrefix(`/`)" @@ -52,20 +44,6 @@ http: middlewares: - "blob-storage-url" - "blob-storage-host" - rathole: - rules: "PathPrefix(`/rathole`)" - entryPoints: - - web - service: "rathole" - middlewares: - - "rathole-url" - ratholeforward: - rules: "HostRegexp(`{subdomain:[a-z]+}.local.rathole`)" - entryPoints: - - web - service: "ratholeforward" - # middlewares: - # - "rathole-redirect" ping: rule: "PathPrefix(`/ping`)" entryPoints: @@ -83,12 +61,3 @@ http: stripprefix: prefixes: /blob forceslash: true - rathole-url: - stripprefix: - prefixes: /rathole - forceslash: true - rathole-redirect: - redirectregex: - regex: "^.+" - replacement: "http://rathole:8000$${1}" - 
permanent: true From a1a379b954646e8247b0691d2e688510d9be3b01 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 8 May 2024 15:19:28 +0530 Subject: [PATCH 030/309] add toml_w and tomli packages - add methods to extract configmap by name --- packages/grid/devspace.yaml | 2 ++ .../syft/templates/proxy/proxy-configmap.yaml | 12 ++++++++++++ .../syft/templates/rathole/rathole-service.yaml | 4 ++-- packages/grid/helm/syft/values.yaml | 1 + packages/grid/helm/values.dev.yaml | 3 +++ packages/syft/setup.cfg | 2 ++ packages/syft/src/syft/custom_worker/k8s.py | 15 +++++++++++++++ 7 files changed, 37 insertions(+), 2 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 3c22db1fd91..214ea0a23b4 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -122,6 +122,8 @@ dev: labelSelector: app.kubernetes.io/name: syft app.kubernetes.io/component: rathole + ports: + - port: "2333" # rathole profiles: - name: gateway diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 6e7dd9c0fe0..8853fae80b2 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -83,3 +83,15 @@ metadata: app.kubernetes.io/component: proxy data: rathole-dynamic.yml: | + # http: + # services: + # rathole_domain_1: + # loadBalancer: + # servers: + # - url: "http://rathole-0.rathole.syft.svc.cluster.local:2333" + # routers: + # rathole_domain_1: + # rule: "Host('domain1.domain.syft.local')" + # entryPoints: + # - "web" + # service: "rathole_domain_1" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml index d9050f0d693..01fa305ac77 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml @@ -6,12 +6,12 @@ metadata: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: rathole spec: - type: ClusterIP + clusterIP: None selector: {{- include "common.selectorLabels" . 
| nindent 4 }} app.kubernetes.io/component: rathole ports: - name: rathole - protocol: TCP port: 2333 targetPort: 2333 + protocol: TCP diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index c59512e478b..6e0bdfbd32d 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -229,6 +229,7 @@ extraResources: [] rathole: # Extra environment vars env: null + enabled: true ratholePort: 2333 appPort: 5555 diff --git a/packages/grid/helm/values.dev.yaml b/packages/grid/helm/values.dev.yaml index 0951f7e906c..6d0bd129028 100644 --- a/packages/grid/helm/values.dev.yaml +++ b/packages/grid/helm/values.dev.yaml @@ -42,3 +42,6 @@ frontend: proxy: resourcesPreset: null resources: null + +rathole: + enabled: "true" diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 54e059b73ac..a6ceb8b4bc1 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -68,6 +68,8 @@ syft = PyYAML==6.0.1 azure-storage-blob==12.19.1 ipywidgets==8.1.2 + tomli==2.0.1 # Later for python 3.11 > we can just use tomlib that comes with python + tomli_w==1.0.0 install_requires = %(syft)s diff --git a/packages/syft/src/syft/custom_worker/k8s.py b/packages/syft/src/syft/custom_worker/k8s.py index cb4b5765e62..60557c86afb 100644 --- a/packages/syft/src/syft/custom_worker/k8s.py +++ b/packages/syft/src/syft/custom_worker/k8s.py @@ -9,6 +9,7 @@ # third party import kr8s from kr8s.objects import APIObject +from kr8s.objects import ConfigMap from kr8s.objects import Pod from kr8s.objects import Secret from pydantic import BaseModel @@ -171,6 +172,20 @@ def b64encode_secret(data: str) -> str: """Convert the data to base64 encoded string for Secret.""" return base64.b64encode(data.encode()).decode() + @staticmethod + def get_configmap(client: kr8s.Api, name: str) -> ConfigMap | None: + config_map = client.get("configmaps", name) + return config_map[0] if config_map else None + + @staticmethod + def update_configmap( + config_map: ConfigMap, + patch: dict, + ) -> None: + existing_data = config_map.raw + existing_data.update(patch) + config_map.patch(patch=existing_data) + @staticmethod def create_dockerconfig_secret( secret_name: str, From 619e119f7b3390d0a4dd951bee3a309831374aa0 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 9 May 2024 12:50:40 +0530 Subject: [PATCH 031/309] add classes to handle CRUD ops for toml server and client files in rathole - add RatholeServer class to add client info to rathole server - add method to alter endpoints in traefik - add method to add config to rathole dynamic traefik config via configmaps --- .../backend/backend-service-account.yaml | 2 +- .../src/syft/service/network/node_peer.py | 1 + .../syft/src/syft/service/network/rathole.py | 173 +++++++++++++ .../src/syft/service/network/rathole_toml.py | 237 ++++++++++++++++++ 4 files changed, 412 insertions(+), 1 deletion(-) create mode 100644 packages/syft/src/syft/service/network/rathole.py create mode 100644 packages/syft/src/syft/service/network/rathole_toml.py diff --git a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml index a466d0c3fe4..7b542adfc0b 100644 --- a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml @@ -26,7 +26,7 @@ metadata: app.kubernetes.io/component: backend rules: - apiGroups: [""] - resources: ["pods", "configmaps", "secrets"] + 
resources: ["pods", "configmaps", "secrets", "service"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] - apiGroups: [""] resources: ["pods/log"] diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 70e6f9bfb40..453b7f3563d 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -66,6 +66,7 @@ class NodePeer(SyftObject): node_routes: list[NodeRouteType] = [] node_type: NodeType admin_email: str + rathole_token: str | None = None def existed_route( self, route: NodeRouteType | None = None, route_id: UID | None = None diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py new file mode 100644 index 00000000000..964c1dde27a --- /dev/null +++ b/packages/syft/src/syft/service/network/rathole.py @@ -0,0 +1,173 @@ +# stdlib +from typing import Self +from typing import cast + +# third party +import yaml + +# relative +from ...custom_worker.k8s import KubeUtils +from ...custom_worker.k8s import get_kr8s_client +from ...serde import serializable +from ...types.base import SyftBaseModel +from .node_peer import NodePeer +from .rathole_toml import RatholeServerToml +from .routes import HTTPNodeRoute + +RATHOLE_TOML_CONFIG_MAP = "rathole-config" +RATHOLE_PROXY_CONFIG_MAP = "rathole-proxy-config" +RATHOLE_DEFAULT_BIND_ADDRESS = "http://0.0.0.0:2333" +PROXY_CONFIG_MAP = "proxy-config" + + +@serializable() +class RatholeConfig(SyftBaseModel): + uuid: str + secret_token: str + local_addr_host: str + local_addr_port: int + server_name: str | None = None + + @property + def local_address(self) -> str: + return f"http://{self.local_addr_host}:{self.local_addr_port}" + + @classmethod + def from_peer(cls, peer: NodePeer) -> Self: + high_priority_route = peer.pick_highest_priority_route() + + if not isinstance(high_priority_route, HTTPNodeRoute): + raise ValueError("Rathole only supports HTTPNodeRoute") + + return cls( + uuid=peer.id, + secret_token=peer.rathole_token, + local_addr_host=high_priority_route.host_or_ip, + local_addr_port=high_priority_route.port, + server_name=peer.name, + ) + + +# class RatholeProxyConfigWriter: +# def get_config(self, *args, **kwargs): +# pass + +# def save_config(self, *args, **kwargs): +# pass + +# def add_service(url: str, service_name: str, port: int, hostname: str): +# pass + +# def delete_service(self, *args, **kwargs): +# pass + + +class RatholeService: + def __init__(self) -> None: + self.k8rs_client = get_kr8s_client() + + def add_client_to_server(self, peer: NodePeer) -> None: + """Add a client to the rathole server toml file.""" + + route = cast(HTTPNodeRoute, peer.pick_highest_priority_route()) + + config = RatholeConfig( + uuid=peer.id, + secret_token=peer.rathole_token, + local_addr_host="localhost", + local_addr_port=route.port, + server_name=peer.name, + ) + + # Get rathole toml config map + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + client_filename = RatholeServerToml.filename + + toml_str = rathole_config_map.data[client_filename] + + # Add the peer info to the toml file + rathole_toml = RatholeServerToml(toml_str) + rathole_toml.add_config(config=config) + + if not rathole_toml.get_bind_address(): + # First time adding a peer + rathole_toml.set_bind_address(RATHOLE_DEFAULT_BIND_ADDRESS) + + rathole_config_map.data[client_filename] = rathole_toml.toml_str + + # Update the 
rathole config map + KubeUtils.update_configmap( + client=self.k8rs_client, + name=RATHOLE_TOML_CONFIG_MAP, + data=rathole_config_map.data, + ) + + # Add the peer info to the proxy config map + self.add_port_to_proxy(config) + + def add_port_to_proxy(self, config: RatholeConfig, entrypoint: str = "web") -> None: + """Add a port to the rathole proxy config map.""" + + rathole_proxy_config_map = KubeUtils.get_configmap( + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + ) + + rathole_proxy = rathole_proxy_config_map.data["rathole-proxy.yml"] + + if not rathole_proxy: + rathole_proxy = {"http": {"routers": {}, "services": {}}} + + # TODO: config.port, this should be a random port + + rathole_proxy["http"]["services"][config.server_name] = { + "loadBalancer": { + "servers": [{"url": f"http://rathole:{config.local_addr_port}"}] + } + } + + rathole_proxy["http"]["routers"][config.server_name] = { + "rule": f"Host(`{config.server_name}.syft.local`)", + "service": config.server_name, + "entryPoints": [entrypoint], + } + + KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, rathole_proxy) + + def add_entrypoint(self, port: int, peer_name: str) -> None: + """Add an entrypoint to the traefik config map.""" + + proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + + data = proxy_config_map.data + + traefik_config_str = data["traefik.yml"] + + traefik_config = yaml.safe_load(traefik_config_str) + + traefik_config["entryPoints"][f"{peer_name}-entrypoint"] = { + "address": f":{port}" + } + + data["traefik.yml"] = yaml.safe_dump(traefik_config) + + KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, data) + + def remove_endpoint(self, peer_name: str) -> None: + """Remove an entrypoint from the traefik config map.""" + + proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + + data = proxy_config_map.data + + traefik_config_str = data["traefik.yml"] + + traefik_config = yaml.safe_load(traefik_config_str) + + del traefik_config["entryPoints"][f"{peer_name}-entrypoint"] + + data["traefik.yml"] = yaml.safe_dump(traefik_config) + + KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, data) diff --git a/packages/syft/src/syft/service/network/rathole_toml.py b/packages/syft/src/syft/service/network/rathole_toml.py new file mode 100644 index 00000000000..e50c830c5d8 --- /dev/null +++ b/packages/syft/src/syft/service/network/rathole_toml.py @@ -0,0 +1,237 @@ +# third party +import tomli +import tomli_w + +# relative +from .rathole import RatholeConfig + + +class TomlReaderWriter: + @staticmethod + def load(toml_str: str) -> dict: + return tomli.loads(toml_str) + + @staticmethod + def dump(toml_dict: str) -> str: + return tomli_w.dumps(toml_dict) + + +class RatholeBaseToml: + filename: str + + def __init__(self, toml_str: str) -> None: + self.toml_writer = TomlReaderWriter + self.toml_str = toml_str + + def read(self) -> dict: + return self.toml_writer.load(self.toml_str) + + def save(self, toml_dict: dict) -> None: + self.toml_str = self.toml_writer.dump(self.toml_str) + + def _validate(self) -> bool: + raise NotImplementedError + + @property + def is_valid(self) -> bool: + return self._validate() + + +class RatholeClientToml(RatholeBaseToml): + filename: str = "client.toml" + + def set_remote_addr(self, remote_host: str) -> None: + """Add a new remote address to the client toml file.""" + + toml = self.read() + + # Add the new remote address + if "client" not in toml: + toml["client"] = {} + + toml["client"]["remote_addr"] = 
remote_host + + if remote_host not in toml["client"]["remote"]: + toml["client"]["remote"].append(remote_host) + + self.save(toml) + + def add_config(self, config: RatholeConfig) -> None: + """Add a new config to the toml file.""" + + toml = self.read() + + # Add the new config + if "services" not in toml["client"]: + toml["client"]["services"] = {} + + if config.uuid not in toml["client"]["services"]: + toml["client"]["services"][config.uuid] = {} + + toml["client"]["services"][config.uuid] = { + "token": config.secret_token, + "local_addr": config.local_address, + } + + self.save(toml) + + def remove_config(self, uuid: str) -> None: + """Remove a config from the toml file.""" + + toml = self.read() + + # Remove the config + if "services" not in toml["client"]: + return + + if uuid not in toml["client"]["services"]: + return + + del toml["client"]["services"][uuid] + + self.save(toml) + + def update_config(self, config: RatholeConfig) -> None: + """Update a config in the toml file.""" + + toml = self.read() + + # Update the config + if "services" not in toml["client"]: + return + + if config.uuid not in toml["client"]["services"]: + return + + toml["client"]["services"][config.uuid] = { + "token": config.secret_token, + "local_addr": config.local_address, + } + + self.save(toml) + + def get_config(self, uuid: str) -> RatholeConfig | None: + """Get a config from the toml file.""" + + toml = self.read() + + # Get the config + if "services" not in toml["client"]: + return None + + if uuid not in toml["client"]["services"]: + return None + + service = toml["client"]["services"][uuid] + + return RatholeConfig( + uuid=uuid, + secret_token=service["token"], + local_addr_host=service["local_addr"].split(":")[0], + local_addr_port=service["local_addr"].split(":")[1], + ) + + def _validate(self) -> bool: + toml = self.read() + + if not toml["client"]["remote_addr"]: + return False + + for uuid, config in toml["client"]["services"].items(): + if not uuid: + return False + + if not config["token"] or not config["local_addr"]: + return False + + return True + + +class RatholeServerToml(RatholeBaseToml): + filename: str = "server.toml" + + def set_bind_address(self, bind_address: str) -> None: + """Set the bind address in the server toml file.""" + + toml = self.read() + + # Set the bind address + toml["server"]["bind_addr"] = bind_address + + self.save(toml) + + def get_bind_address(self) -> str: + """Get the bind address from the server toml file.""" + + toml = self.read() + + return toml["server"]["bind_addr"] + + def add_config(self, config: RatholeConfig) -> None: + """Add a new config to the toml file.""" + + toml = self.read() + + # Add the new config + if "services" not in toml["server"]: + toml["server"]["services"] = {} + + if config.uuid not in toml["server"]["services"]: + toml["server"]["services"][config.uuid] = {} + + toml["server"]["services"][config.uuid] = { + "token": config.secret_token, + "bind_addr": config.local_address, + } + + self.save(toml) + + def remove_config(self, uuid: str) -> None: + """Remove a config from the toml file.""" + + toml = self.read() + + # Remove the config + if "services" not in toml["server"]: + return + + if uuid not in toml["server"]["services"]: + return + + del toml["server"]["services"][uuid] + + self.save(toml) + + def update_config(self, config: RatholeConfig) -> None: + """Update a config in the toml file.""" + + toml = self.read() + + # Update the config + if "services" not in toml["server"]: + return + + if config.uuid not in 
toml["server"]["services"]: + return + + toml["server"]["services"][config.uuid] = { + "token": config.secret_token, + "bind_addr": config.local_address, + } + + self.save(toml) + + def _validate(self) -> bool: + toml = self.read() + + if not toml["server"]["bind_addr"]: + return False + + for uuid, config in toml["server"]["services"].items(): + if not uuid: + return False + + if not config["token"] or not config["bind_addr"]: + return False + + return True From 55272f8448cffff72b87ee7b5cac88d6ce2c728d Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 9 May 2024 15:13:36 +0530 Subject: [PATCH 032/309] add method to add host to client.toml in configmap - add methods to forward rathole port to proxy --- .../syft/src/syft/service/network/rathole.py | 82 ++++++++++++++++--- 1 file changed, 71 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py index 964c1dde27a..56d25af36a9 100644 --- a/packages/syft/src/syft/service/network/rathole.py +++ b/packages/syft/src/syft/service/network/rathole.py @@ -1,4 +1,5 @@ # stdlib +import secrets from typing import Self from typing import cast @@ -11,6 +12,7 @@ from ...serde import serializable from ...types.base import SyftBaseModel from .node_peer import NodePeer +from .rathole_toml import RatholeClientToml from .rathole_toml import RatholeServerToml from .routes import HTTPNodeRoute @@ -66,13 +68,13 @@ class RatholeService: def __init__(self) -> None: self.k8rs_client = get_kr8s_client() - def add_client_to_server(self, peer: NodePeer) -> None: - """Add a client to the rathole server toml file.""" + def add_host_to_server(self, peer: NodePeer) -> None: + """Add a host to the rathole server toml file.""" route = cast(HTTPNodeRoute, peer.pick_highest_priority_route()) config = RatholeConfig( - uuid=peer.id, + uuid=peer.id.to_string(), secret_token=peer.rathole_token, local_addr_host="localhost", local_addr_port=route.port, @@ -106,21 +108,81 @@ def add_client_to_server(self, peer: NodePeer) -> None: ) # Add the peer info to the proxy config map - self.add_port_to_proxy(config) + self.add_dynamic_addr_to_rathole(config) + + def get_random_port(self) -> int: + """Get a random port number.""" + return secrets.randbits(15) + + def add_host_to_client(self, peer: NodePeer) -> None: + """Add a host to the rathole client toml file.""" + + random_port = self.get_random_port() + + config = RatholeConfig( + uuid=peer.id.to_string(), + secret_token=peer.rathole_token, + local_addr_host="localhost", + local_addr_port=random_port, + server_name=peer.name, + ) + + # Get rathole toml config map + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + client_filename = RatholeClientToml.filename + + toml_str = rathole_config_map.data[client_filename] + + rathole_toml = RatholeClientToml(toml_str=toml_str) + + rathole_toml.add_config(config=config) + + self.add_entrypoint(port=random_port, peer_name=peer.name) - def add_port_to_proxy(self, config: RatholeConfig, entrypoint: str = "web") -> None: + self.forward_port_to_proxy(config=config, entrypoint=peer.name) + + def forward_port_to_proxy( + self, config: RatholeConfig, entrypoint: str = "web" + ) -> None: """Add a port to the rathole proxy config map.""" rathole_proxy_config_map = KubeUtils.get_configmap( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) - rathole_proxy = rathole_proxy_config_map.data["rathole-proxy.yml"] + rathole_proxy = 
rathole_proxy_config_map.data["rathole-dynamic.yml"] if not rathole_proxy: rathole_proxy = {"http": {"routers": {}, "services": {}}} - # TODO: config.port, this should be a random port + rathole_proxy["http"]["services"][config.server_name] = { + "loadBalancer": {"servers": [{"url": "http://proxy:8001"}]} + } + + rathole_proxy["http"]["routers"][config.server_name] = { + "rule": "PathPrefix(`/`)", + "service": config.server_name, + "entryPoints": [entrypoint], + } + + KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, rathole_proxy) + + def add_dynamic_addr_to_rathole( + self, config: RatholeConfig, entrypoint: str = "web" + ) -> None: + """Add a port to the rathole proxy config map.""" + + rathole_proxy_config_map = KubeUtils.get_configmap( + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + ) + + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] + + if not rathole_proxy: + rathole_proxy = {"http": {"routers": {}, "services": {}}} rathole_proxy["http"]["services"][config.server_name] = { "loadBalancer": { @@ -147,9 +209,7 @@ def add_entrypoint(self, port: int, peer_name: str) -> None: traefik_config = yaml.safe_load(traefik_config_str) - traefik_config["entryPoints"][f"{peer_name}-entrypoint"] = { - "address": f":{port}" - } + traefik_config["entryPoints"][f"{peer_name}"] = {"address": f":{port}"} data["traefik.yml"] = yaml.safe_dump(traefik_config) @@ -166,7 +226,7 @@ def remove_endpoint(self, peer_name: str) -> None: traefik_config = yaml.safe_load(traefik_config_str) - del traefik_config["entryPoints"][f"{peer_name}-entrypoint"] + del traefik_config["entryPoints"][f"{peer_name}"] data["traefik.yml"] = yaml.safe_dump(traefik_config) From 7b8e461cd66de72132ba416938733087d6196593 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 13 May 2024 13:12:42 +0530 Subject: [PATCH 033/309] move RatholeService to seperate file - fix bugs with saving toml - add resourceVersion --- packages/grid/default.env | 3 + .../backend/backend-statefulset.yaml | 8 + .../syft/templates/proxy/proxy-configmap.yaml | 13 +- .../templates/rathole/rathole-configmap.yaml | 1 + packages/grid/helm/syft/values.yaml | 4 +- .../syft/src/syft/service/network/rathole.py | 206 +----------------- .../syft/service/network/rathole_service.py | 206 ++++++++++++++++++ .../src/syft/service/network/rathole_toml.py | 8 +- 8 files changed, 233 insertions(+), 216 deletions(-) create mode 100644 packages/syft/src/syft/service/network/rathole_service.py diff --git a/packages/grid/default.env b/packages/grid/default.env index f6efe8aa463..906251ee865 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -110,3 +110,6 @@ ENABLE_SIGNUP=False # Enclave Attestation DOCKER_IMAGE_ENCLAVE_ATTESTATION=openmined/grid-enclave-attestation + +# Rathole Config +RATHOLE_PORT=2333 \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 8048f262e5e..f9b2dd42353 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -82,9 +82,17 @@ spec: {{- if .Values.node.debuggerEnabled }} - name: DEBUGGER_ENABLED value: "true" + {{- end }} + {{- if eq .Values.node.type "gateway" }} - name: ASSOCIATION_REQUEST_AUTO_APPROVAL value: {{ .Values.node.associationRequestAutoApproval | quote }} {{- end }} + {{- if .Values.rathole.enabled }} + - name: RATHOLE_PORT + value: {{ .Values.rathole.port | quote 
}} + - name: RATHOLE_ENABLED + value: "true" + {{- end }} # MongoDB - name: MONGO_PORT value: {{ .Values.mongo.port | quote }} diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 8853fae80b2..ee5d2316e99 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -2,6 +2,7 @@ apiVersion: v1 kind: ConfigMap metadata: name: proxy-config + resourceVersion: "" labels: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: proxy @@ -83,15 +84,3 @@ metadata: app.kubernetes.io/component: proxy data: rathole-dynamic.yml: | - # http: - # services: - # rathole_domain_1: - # loadBalancer: - # servers: - # - url: "http://rathole-0.rathole.syft.svc.cluster.local:2333" - # routers: - # rathole_domain_1: - # rule: "Host('domain1.domain.syft.local')" - # entryPoints: - # - "web" - # service: "rathole_domain_1" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 3e843fbc1cc..46b141d68d9 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -2,6 +2,7 @@ apiVersion: v1 kind: ConfigMap metadata: name: rathole-config + resourceVersion: "" labels: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: rathole diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 6d2bf1761ac..ee6db586bb6 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -227,9 +227,9 @@ rathole: env: null enabled: true - ratholePort: 2333 + port: 2333 appPort: 5555 - mode: "server" + mode: "client" devMode: "false" appLogLevel: "info" diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py index 56d25af36a9..a90e6f34030 100644 --- a/packages/syft/src/syft/service/network/rathole.py +++ b/packages/syft/src/syft/service/network/rathole.py @@ -1,25 +1,15 @@ # stdlib -import secrets from typing import Self -from typing import cast - -# third party -import yaml # relative -from ...custom_worker.k8s import KubeUtils -from ...custom_worker.k8s import get_kr8s_client -from ...serde import serializable +from ...serde.serializable import serializable from ...types.base import SyftBaseModel +from ...util.util import get_env from .node_peer import NodePeer -from .rathole_toml import RatholeClientToml -from .rathole_toml import RatholeServerToml -from .routes import HTTPNodeRoute -RATHOLE_TOML_CONFIG_MAP = "rathole-config" -RATHOLE_PROXY_CONFIG_MAP = "rathole-proxy-config" -RATHOLE_DEFAULT_BIND_ADDRESS = "http://0.0.0.0:2333" -PROXY_CONFIG_MAP = "proxy-config" + +def get_rathole_port() -> int: + return int(get_env("RATHOLE_PORT", "2333")) @serializable() @@ -36,6 +26,9 @@ def local_address(self) -> str: @classmethod def from_peer(cls, peer: NodePeer) -> Self: + # relative + from .routes import HTTPNodeRoute + high_priority_route = peer.pick_highest_priority_route() if not isinstance(high_priority_route, HTTPNodeRoute): @@ -48,186 +41,3 @@ def from_peer(cls, peer: NodePeer) -> Self: local_addr_port=high_priority_route.port, server_name=peer.name, ) - - -# class RatholeProxyConfigWriter: -# def get_config(self, *args, **kwargs): -# pass - -# def save_config(self, *args, **kwargs): -# pass - -# def add_service(url: str, 
service_name: str, port: int, hostname: str): -# pass - -# def delete_service(self, *args, **kwargs): -# pass - - -class RatholeService: - def __init__(self) -> None: - self.k8rs_client = get_kr8s_client() - - def add_host_to_server(self, peer: NodePeer) -> None: - """Add a host to the rathole server toml file.""" - - route = cast(HTTPNodeRoute, peer.pick_highest_priority_route()) - - config = RatholeConfig( - uuid=peer.id.to_string(), - secret_token=peer.rathole_token, - local_addr_host="localhost", - local_addr_port=route.port, - server_name=peer.name, - ) - - # Get rathole toml config map - rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP - ) - - client_filename = RatholeServerToml.filename - - toml_str = rathole_config_map.data[client_filename] - - # Add the peer info to the toml file - rathole_toml = RatholeServerToml(toml_str) - rathole_toml.add_config(config=config) - - if not rathole_toml.get_bind_address(): - # First time adding a peer - rathole_toml.set_bind_address(RATHOLE_DEFAULT_BIND_ADDRESS) - - rathole_config_map.data[client_filename] = rathole_toml.toml_str - - # Update the rathole config map - KubeUtils.update_configmap( - client=self.k8rs_client, - name=RATHOLE_TOML_CONFIG_MAP, - data=rathole_config_map.data, - ) - - # Add the peer info to the proxy config map - self.add_dynamic_addr_to_rathole(config) - - def get_random_port(self) -> int: - """Get a random port number.""" - return secrets.randbits(15) - - def add_host_to_client(self, peer: NodePeer) -> None: - """Add a host to the rathole client toml file.""" - - random_port = self.get_random_port() - - config = RatholeConfig( - uuid=peer.id.to_string(), - secret_token=peer.rathole_token, - local_addr_host="localhost", - local_addr_port=random_port, - server_name=peer.name, - ) - - # Get rathole toml config map - rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP - ) - - client_filename = RatholeClientToml.filename - - toml_str = rathole_config_map.data[client_filename] - - rathole_toml = RatholeClientToml(toml_str=toml_str) - - rathole_toml.add_config(config=config) - - self.add_entrypoint(port=random_port, peer_name=peer.name) - - self.forward_port_to_proxy(config=config, entrypoint=peer.name) - - def forward_port_to_proxy( - self, config: RatholeConfig, entrypoint: str = "web" - ) -> None: - """Add a port to the rathole proxy config map.""" - - rathole_proxy_config_map = KubeUtils.get_configmap( - self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP - ) - - rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] - - if not rathole_proxy: - rathole_proxy = {"http": {"routers": {}, "services": {}}} - - rathole_proxy["http"]["services"][config.server_name] = { - "loadBalancer": {"servers": [{"url": "http://proxy:8001"}]} - } - - rathole_proxy["http"]["routers"][config.server_name] = { - "rule": "PathPrefix(`/`)", - "service": config.server_name, - "entryPoints": [entrypoint], - } - - KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, rathole_proxy) - - def add_dynamic_addr_to_rathole( - self, config: RatholeConfig, entrypoint: str = "web" - ) -> None: - """Add a port to the rathole proxy config map.""" - - rathole_proxy_config_map = KubeUtils.get_configmap( - self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP - ) - - rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] - - if not rathole_proxy: - rathole_proxy = {"http": {"routers": {}, "services": {}}} - - 
rathole_proxy["http"]["services"][config.server_name] = { - "loadBalancer": { - "servers": [{"url": f"http://rathole:{config.local_addr_port}"}] - } - } - - rathole_proxy["http"]["routers"][config.server_name] = { - "rule": f"Host(`{config.server_name}.syft.local`)", - "service": config.server_name, - "entryPoints": [entrypoint], - } - - KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, rathole_proxy) - - def add_entrypoint(self, port: int, peer_name: str) -> None: - """Add an entrypoint to the traefik config map.""" - - proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) - - data = proxy_config_map.data - - traefik_config_str = data["traefik.yml"] - - traefik_config = yaml.safe_load(traefik_config_str) - - traefik_config["entryPoints"][f"{peer_name}"] = {"address": f":{port}"} - - data["traefik.yml"] = yaml.safe_dump(traefik_config) - - KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, data) - - def remove_endpoint(self, peer_name: str) -> None: - """Remove an entrypoint from the traefik config map.""" - - proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) - - data = proxy_config_map.data - - traefik_config_str = data["traefik.yml"] - - traefik_config = yaml.safe_load(traefik_config_str) - - del traefik_config["entryPoints"][f"{peer_name}"] - - data["traefik.yml"] = yaml.safe_dump(traefik_config) - - KubeUtils.update_configmap(self.k8rs_client, PROXY_CONFIG_MAP, data) diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py new file mode 100644 index 00000000000..53ea27f5c38 --- /dev/null +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -0,0 +1,206 @@ +# stdlib +import secrets + +# third party +import yaml + +# relative +from ...custom_worker.k8s import KubeUtils +from ...custom_worker.k8s import get_kr8s_client +from .node_peer import NodePeer +from .rathole import RatholeConfig +from .rathole import get_rathole_port +from .rathole_toml import RatholeClientToml +from .rathole_toml import RatholeServerToml + +RATHOLE_TOML_CONFIG_MAP = "rathole-config" +RATHOLE_PROXY_CONFIG_MAP = "proxy-config-dynamic" +PROXY_CONFIG_MAP = "proxy-config" + + +class RatholeService: + def __init__(self) -> None: + self.k8rs_client = get_kr8s_client() + + def add_host_to_server(self, peer: NodePeer) -> None: + """Add a host to the rathole server toml file. + + Args: + peer (NodePeer): The peer to be added to the rathole server. 
+ + Returns: + None + """ + + random_port = self.get_random_port() + + config = RatholeConfig( + uuid=peer.id.to_string(), + secret_token=peer.rathole_token, + local_addr_host="localhost", + local_addr_port=random_port, + server_name=peer.name, + ) + + # Get rathole toml config map + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + client_filename = RatholeServerToml.filename + + toml_str = rathole_config_map.data[client_filename] + + # Add the peer info to the toml file + rathole_toml = RatholeServerToml(toml_str) + rathole_toml.add_config(config=config) + + # First time adding a peer + if not rathole_toml.get_rathole_listener_addr(): + bind_addr = f"http://localhost:{get_rathole_port()}" + rathole_toml.set_rathole_listener_addr(bind_addr) + + data = {client_filename: rathole_toml.toml_str} + + # Update the rathole config map + KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) + + # Add the peer info to the proxy config map + self.add_dynamic_addr_to_rathole(config) + + def get_random_port(self) -> int: + """Get a random port number.""" + return secrets.randbits(15) + + def add_host_to_client(self, peer: NodePeer) -> None: + """Add a host to the rathole client toml file.""" + + random_port = self.get_random_port() + + config = RatholeConfig( + uuid=peer.id.to_string(), + secret_token=peer.rathole_token, + local_addr_host="localhost", + local_addr_port=random_port, + server_name=peer.name, + ) + + # Get rathole toml config map + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + client_filename = RatholeClientToml.filename + + toml_str = rathole_config_map.data[client_filename] + + rathole_toml = RatholeClientToml(toml_str=toml_str) + + rathole_toml.add_config(config=config) + + data = {client_filename: rathole_toml.toml_str} + + # Update the rathole config map + KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) + + self.add_entrypoint(port=random_port, peer_name=peer.name) + + self.forward_port_to_proxy(config=config, entrypoint=peer.name) + + def forward_port_to_proxy( + self, config: RatholeConfig, entrypoint: str = "web" + ) -> None: + """Add a port to the rathole proxy config map.""" + + rathole_proxy_config_map = KubeUtils.get_configmap( + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + ) + + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] + + if not rathole_proxy: + rathole_proxy = {"http": {"routers": {}, "services": {}}} + else: + rathole_proxy = yaml.safe_load(rathole_proxy) + + rathole_proxy["http"]["services"][config.server_name] = { + "loadBalancer": {"servers": [{"url": "http://proxy:8001"}]} + } + + rathole_proxy["http"]["routers"][config.server_name] = { + "rule": "PathPrefix(`/`)", + "service": config.server_name, + "entryPoints": [entrypoint], + } + + KubeUtils.update_configmap( + config_map=rathole_proxy_config_map, + patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, + ) + + def add_dynamic_addr_to_rathole( + self, config: RatholeConfig, entrypoint: str = "web" + ) -> None: + """Add a port to the rathole proxy config map.""" + + rathole_proxy_config_map = KubeUtils.get_configmap( + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + ) + + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] + + if not rathole_proxy: + rathole_proxy = {"http": {"routers": {}, "services": {}}} + else: + rathole_proxy = yaml.safe_load(rathole_proxy) + + 
rathole_proxy["http"]["services"][config.server_name] = { + "loadBalancer": { + "servers": [{"url": f"http://rathole:{config.local_addr_port}"}] + } + } + + rathole_proxy["http"]["routers"][config.server_name] = { + "rule": f"Host(`{config.server_name}.syft.local`)", + "service": config.server_name, + "entryPoints": [entrypoint], + } + + KubeUtils.update_configmap( + config_map=rathole_proxy_config_map, + patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, + ) + + def add_entrypoint(self, port: int, peer_name: str) -> None: + """Add an entrypoint to the traefik config map.""" + + proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + + data = proxy_config_map.data + + traefik_config_str = data["traefik.yml"] + + traefik_config = yaml.safe_load(traefik_config_str) + + traefik_config["entryPoints"][f"{peer_name}"] = {"address": f":{port}"} + + data["traefik.yml"] = yaml.safe_dump(traefik_config) + + KubeUtils.update_configmap(config_map=proxy_config_map, patch={"data": data}) + + def remove_endpoint(self, peer_name: str) -> None: + """Remove an entrypoint from the traefik config map.""" + + proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + + data = proxy_config_map.data + + traefik_config_str = data["traefik.yml"] + + traefik_config = yaml.safe_load(traefik_config_str) + + del traefik_config["entryPoints"][f"{peer_name}"] + + data["traefik.yml"] = yaml.safe_dump(traefik_config) + + KubeUtils.update_configmap(config_map=proxy_config_map, patch={"data": data}) diff --git a/packages/syft/src/syft/service/network/rathole_toml.py b/packages/syft/src/syft/service/network/rathole_toml.py index e50c830c5d8..7ca69be6d14 100644 --- a/packages/syft/src/syft/service/network/rathole_toml.py +++ b/packages/syft/src/syft/service/network/rathole_toml.py @@ -27,7 +27,7 @@ def read(self) -> dict: return self.toml_writer.load(self.toml_str) def save(self, toml_dict: dict) -> None: - self.toml_str = self.toml_writer.dump(self.toml_str) + self.toml_str = self.toml_writer.dump(toml_dict) def _validate(self) -> bool: raise NotImplementedError @@ -150,17 +150,17 @@ def _validate(self) -> bool: class RatholeServerToml(RatholeBaseToml): filename: str = "server.toml" - def set_bind_address(self, bind_address: str) -> None: + def set_rathole_listener_addr(self, bind_addr: str) -> None: """Set the bind address in the server toml file.""" toml = self.read() # Set the bind address - toml["server"]["bind_addr"] = bind_address + toml["server"]["bind_addr"] = bind_addr self.save(toml) - def get_bind_address(self) -> str: + def get_rathole_listener_addr(self) -> str: """Get the bind address from the server toml file.""" toml = self.read() From 13e2dd015792fb282da522ef885e264050a020a4 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 13 May 2024 16:03:29 +0530 Subject: [PATCH 034/309] fix add host to client method to pass remote addr - integrate Rathole service to Network service --- .../service/network/association_request.py | 10 +++++++++- .../syft/service/network/network_service.py | 19 +++++++++++++++++++ .../syft/service/network/rathole_service.py | 16 ++++++++++------ 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 70c08a52e56..0cdb78375fb 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -1,5 +1,6 @@ 
# stdlib import secrets +from typing import cast # third party from result import Err @@ -72,14 +73,21 @@ def _run( except Exception as e: return Err(SyftError(message=str(e))) - network_stash = service_ctx.node.get_service(NetworkService).stash + network_service = cast( + NetworkService, service_ctx.node.get_service(NetworkService) + ) + + network_stash = network_service.stash result = network_stash.create_or_update_peer( service_ctx.node.verify_key, self.remote_peer ) + if result.is_err(): return Err(SyftError(message=str(result.err()))) + network_service.rathole_service.add_host_to_server(self.remote_peer) + # this way they can match up who we are with who they think we are # Sending a signed messages for the peer to verify self_node_peer = self.self_node_route.validate_with_context(context=service_ctx) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index fd937f8491f..dfb7c8cdce2 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from hashlib import sha256 import secrets from typing import Any @@ -48,6 +49,8 @@ from ..warnings import CRUDWarning from .association_request import AssociationRequestChange from .node_peer import NodePeer +from .rathole import get_rathole_port +from .rathole_service import RatholeService from .routes import HTTPNodeRoute from .routes import NodeRoute from .routes import NodeRouteType @@ -140,6 +143,7 @@ class NetworkService(AbstractService): def __init__(self, store: DocumentStore) -> None: self.store = store self.stash = NetworkStash(store=store) + self.rathole_service = RatholeService() # TODO: Check with MADHAVA, can we even allow guest user to introduce routes to # domain nodes? 
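The hunk below has NetworkService mint a rathole token during credential exchange and hand the peer's public address to RatholeService.add_host_to_client. A minimal, self-contained sketch of that token/address step, using only the stdlib calls visible in the diff (the helper names and the example hostname here are illustrative, not part of the patch):

import os
import secrets
from hashlib import sha256


def generate_rathole_token() -> str:
    # 16 random bytes hashed to a 64-character hex string,
    # mirroring NetworkService._generate_token in the hunk below
    return sha256(secrets.token_bytes(16)).hexdigest()


def rathole_remote_addr(host_or_ip: str) -> str:
    # peer's public rathole address: host plus RATHOLE_PORT,
    # which defaults to 2333 as set in packages/grid/default.env
    port = int(os.getenv("RATHOLE_PORT", "2333"))
    return f"{host_or_ip}:{port}"


# Illustrative usage (hostname is a placeholder):
#   generate_rathole_token()                   -> 64-char hex token
#   rathole_remote_addr("gateway.example.com") -> "gateway.example.com:2333"
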
@@ -172,6 +176,9 @@ def exchange_credentials_with( ) random_challenge = secrets.token_bytes(16) + rathole_token = self._generate_token() + self_node_peer.rathole_token = rathole_token + # ask the remote client to add this node (represented by `self_node_peer`) as a peer remote_res = remote_client.api.services.network.add_peer( peer=self_node_peer, @@ -195,12 +202,24 @@ def exchange_credentials_with( if result.is_err(): return SyftError(message=str(result.err())) + remote_addr = f"{remote_node_route.protocol}://{remote_node_route.host_or_ip}:{get_rathole_port()()}" + + self.rathole_service.add_host_to_client( + peer_name=self_node_peer.name, + peer_id=self_node_peer.id.to_string(), + rathole_token=self_node_peer.rathole_token, + remote_addr=remote_addr, + ) + return ( SyftSuccess(message="Routes Exchanged") if association_request_approved else remote_res ) + def _generate_token(self) -> str: + return sha256(secrets.token_bytes(16)).hexdigest() + @service_method(path="network.add_peer", name="add_peer", roles=GUEST_ROLE_LEVEL) def add_peer( self, diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 53ea27f5c38..c9556f69f4e 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -72,17 +72,19 @@ def get_random_port(self) -> int: """Get a random port number.""" return secrets.randbits(15) - def add_host_to_client(self, peer: NodePeer) -> None: + def add_host_to_client( + self, peer_name: str, peer_id: str, rathole_token: str, remote_addr: str + ) -> None: """Add a host to the rathole client toml file.""" random_port = self.get_random_port() config = RatholeConfig( - uuid=peer.id.to_string(), - secret_token=peer.rathole_token, + uuid=peer_id, + secret_token=rathole_token, local_addr_host="localhost", local_addr_port=random_port, - server_name=peer.name, + server_name=peer_name, ) # Get rathole toml config map @@ -98,14 +100,16 @@ def add_host_to_client(self, peer: NodePeer) -> None: rathole_toml.add_config(config=config) + rathole_toml.set_remote_addr(remote_addr) + data = {client_filename: rathole_toml.toml_str} # Update the rathole config map KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) - self.add_entrypoint(port=random_port, peer_name=peer.name) + self.add_entrypoint(port=random_port, peer_name=peer_name) - self.forward_port_to_proxy(config=config, entrypoint=peer.name) + self.forward_port_to_proxy(config=config, entrypoint=peer_name) def forward_port_to_proxy( self, config: RatholeConfig, entrypoint: str = "web" From ad838ac7e5ba4c21f612788e053801d8cf0600c0 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 15 May 2024 14:50:10 +0530 Subject: [PATCH 035/309] fix values.yml not being correctly propogated --- packages/grid/devspace.yaml | 11 +++++++---- packages/grid/helm/syft/values.yaml | 4 ---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index dc999c80063..cd777d5ef44 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -76,15 +76,16 @@ deployments: releaseName: syft-dev chart: name: ./helm/syft + # anything that does not need devspace $env vars should go in values.dev.yaml + valuesFiles: + - ./helm/syft/values.yaml + - ./helm/values.dev.yaml values: global: registry: ${CONTAINER_REGISTRY} version: dev-${DEVSPACE_TIMESTAMP} node: type: domain # required for the gateway profile - # 
anything that does not need devspace $env vars should go in values.dev.yaml - valuesFiles: - - ./helm/values.dev.yaml dev: mongo: @@ -139,9 +140,11 @@ profiles: path: images.seaweedfs - op: remove path: dev.seaweedfs + + # Patch mode to server - op: replace path: deployments.syft.helm.values.rathole.mode - value: "server" + value: server # Port Re-Mapping # Mongo diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index ee6db586bb6..f03f81d619c 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -228,10 +228,6 @@ rathole: enabled: true port: 2333 - appPort: 5555 - mode: "client" - devMode: "false" - appLogLevel: "info" # Pod labels & annotations podLabels: null From 01fe5e82268a8fffff2b545854e56b727e01f9bb Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 15 May 2024 19:50:27 +0530 Subject: [PATCH 036/309] add mode to rathole in values.yml --- packages/grid/devspace.yaml | 2 ++ packages/grid/helm/syft/values.yaml | 1 + 2 files changed, 3 insertions(+) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index cd777d5ef44..b52069a4d5a 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -86,6 +86,8 @@ deployments: version: dev-${DEVSPACE_TIMESTAMP} node: type: domain # required for the gateway profile + rathole: + mode: client dev: mongo: diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index f03f81d619c..7df55c978da 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -228,6 +228,7 @@ rathole: enabled: true port: 2333 + mode: client # Pod labels & annotations podLabels: null From f7be2806354217562db335e3650fbb4fae0613e3 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 15 May 2024 22:20:35 +0530 Subject: [PATCH 037/309] fix set remote addr method in Rathole service - skip remote ping check if rathole --- .../service/network/association_request.py | 55 ++++++++++--------- .../syft/service/network/network_service.py | 8 ++- .../src/syft/service/network/rathole_toml.py | 5 +- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 0cdb78375fb..ac64b0c57f2 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -43,35 +43,36 @@ def _run( service_ctx = context.to_service_ctx() - try: - remote_client: SyftClient = self.remote_peer.client_with_context( - context=service_ctx - ) - if remote_client.is_err(): - return SyftError( - message=f"Failed to create remote client for peer: " - f"{self.remote_peer.id}. Error: {remote_client.err()}" + if self.remote_peer.rathole_token is None: + try: + remote_client: SyftClient = self.remote_peer.client_with_context( + context=service_ctx ) - remote_client = remote_client.ok() - random_challenge = secrets.token_bytes(16) - remote_res = remote_client.api.services.network.ping( - challenge=random_challenge - ) - except Exception as e: - return SyftError(message="Remote Peer cannot ping peer:" + str(e)) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{self.remote_peer.id}. 
Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + random_challenge = secrets.token_bytes(16) + remote_res = remote_client.api.services.network.ping( + challenge=random_challenge + ) + except Exception as e: + return SyftError(message="Remote Peer cannot ping peer:" + str(e)) - if isinstance(remote_res, SyftError): - return Err(remote_res) + if isinstance(remote_res, SyftError): + return Err(remote_res) - challenge_signature = remote_res + challenge_signature = remote_res - # Verifying if the challenge is valid - try: - self.remote_peer.verify_key.verify_key.verify( - random_challenge, challenge_signature - ) - except Exception as e: - return Err(SyftError(message=str(e))) + # Verifying if the challenge is valid + try: + self.remote_peer.verify_key.verify_key.verify( + random_challenge, challenge_signature + ) + except Exception as e: + return Err(SyftError(message=str(e))) network_service = cast( NetworkService, service_ctx.node.get_service(NetworkService) @@ -79,6 +80,8 @@ def _run( network_stash = network_service.stash + network_service.rathole_service.add_host_to_server(self.remote_peer) + result = network_stash.create_or_update_peer( service_ctx.node.verify_key, self.remote_peer ) @@ -86,8 +89,6 @@ def _run( if result.is_err(): return Err(SyftError(message=str(result.err()))) - network_service.rathole_service.add_host_to_server(self.remote_peer) - # this way they can match up who we are with who they think we are # Sending a signed messages for the peer to verify self_node_peer = self.self_node_route.validate_with_context(context=service_ctx) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index dfb7c8cdce2..4976119f5a9 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -188,7 +188,9 @@ def exchange_credentials_with( ) if isinstance(remote_res, SyftError): - return remote_res + return SyftError( + message=f"returned error from add peer: {remote_res.message}" + ) association_request_approved = not isinstance(remote_res, Request) @@ -200,7 +202,9 @@ def exchange_credentials_with( remote_node_peer, ) if result.is_err(): - return SyftError(message=str(result.err())) + return SyftError( + message=f"Failed to update remote node peer: {str(result.err())}" + ) remote_addr = f"{remote_node_route.protocol}://{remote_node_route.host_or_ip}:{get_rathole_port()()}" diff --git a/packages/syft/src/syft/service/network/rathole_toml.py b/packages/syft/src/syft/service/network/rathole_toml.py index 7ca69be6d14..e5fe17b59e9 100644 --- a/packages/syft/src/syft/service/network/rathole_toml.py +++ b/packages/syft/src/syft/service/network/rathole_toml.py @@ -49,10 +49,7 @@ def set_remote_addr(self, remote_host: str) -> None: if "client" not in toml: toml["client"] = {} - toml["client"]["remote_addr"] = remote_host - - if remote_host not in toml["client"]["remote"]: - toml["client"]["remote"].append(remote_host) + toml["client"]["remote_addr"] = remote_host self.save(toml) From c5c3fa6243bf0caddf59c5c92ef85eba6f81b443 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 15 May 2024 22:45:09 +0530 Subject: [PATCH 038/309] rename RATHOLE_MODE to MODE in rathole/start.sh - removed extra braces from get_rathole_port --- packages/grid/rathole/start.sh | 8 ++++---- packages/syft/src/syft/service/network/network_service.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/grid/rathole/start.sh 
b/packages/grid/rathole/start.sh index e948973decf..d9527196ce9 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -1,14 +1,14 @@ #!/usr/bin/env bash -RATHOLE_MODE=${RATHOLE_MODE:-server} +MODE=${MODE:-server} cp -L -r -f /conf/* conf/ -if [[ $RATHOLE_MODE == "server" ]]; then +if [[ $MODE == "server" ]]; then /app/rathole conf/server.toml & -elif [[ $RATHOLE_MODE = "client" ]]; then +elif [[ $MODE = "client" ]]; then /app/rathole conf/client.toml & else - echo "RATHOLE_MODE is set to an invalid value. Exiting." + echo "RATHOLE MODE is set to an invalid value. Exiting." fi # reload config every 10 seconds diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4976119f5a9..2a8aa986846 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -206,7 +206,7 @@ def exchange_credentials_with( message=f"Failed to update remote node peer: {str(result.err())}" ) - remote_addr = f"{remote_node_route.protocol}://{remote_node_route.host_or_ip}:{get_rathole_port()()}" + remote_addr = f"{remote_node_route.protocol}://{remote_node_route.host_or_ip}:{get_rathole_port()}" self.rathole_service.add_host_to_client( peer_name=self_node_peer.name, From 401d97417266f6eccdb29feea4424d94c2d94c77 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 16 May 2024 00:04:09 +0530 Subject: [PATCH 039/309] add a retry if client.toml is invalid when no connections are setup --- .../syft/templates/rathole/rathole-configmap.yaml | 2 +- packages/grid/rathole/start.sh | 11 ++++++++++- packages/syft/src/syft/service/network/rathole.py | 2 +- .../syft/src/syft/service/network/rathole_service.py | 2 +- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 46b141d68d9..d9e004d3c5e 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -20,5 +20,5 @@ data: {{- if eq .Values.rathole.mode "client" }} client.toml: | [client] - remote_addr = "" + remote_addr = "0.0.0.0:2333" {{- end }} diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index d9527196ce9..4095f30f3aa 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -6,7 +6,16 @@ cp -L -r -f /conf/* conf/ if [[ $MODE == "server" ]]; then /app/rathole conf/server.toml & elif [[ $MODE = "client" ]]; then - /app/rathole conf/client.toml & + while true; do + /app/rathole conf/client.toml + status=$? + if [ $status -eq 0 ]; then + break + else + echo "Failed to load client.toml, retrying in 5 seconds..." + sleep 10 + fi + done & else echo "RATHOLE MODE is set to an invalid value. Exiting." 
fi diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py index a90e6f34030..e102134ada6 100644 --- a/packages/syft/src/syft/service/network/rathole.py +++ b/packages/syft/src/syft/service/network/rathole.py @@ -22,7 +22,7 @@ class RatholeConfig(SyftBaseModel): @property def local_address(self) -> str: - return f"http://{self.local_addr_host}:{self.local_addr_port}" + return f"{self.local_addr_host}:{self.local_addr_port}" @classmethod def from_peer(cls, peer: NodePeer) -> Self: diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index c9556f69f4e..250c56c21cb 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -57,7 +57,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: # First time adding a peer if not rathole_toml.get_rathole_listener_addr(): - bind_addr = f"http://localhost:{get_rathole_port()}" + bind_addr = f"localhost:{get_rathole_port()}" rathole_toml.set_rathole_listener_addr(bind_addr) data = {client_filename: rathole_toml.toml_str} From e4e55f38947d42829b5a02bce44c0afc3346503f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 16 May 2024 11:19:30 +0530 Subject: [PATCH 040/309] remove fastapi from grid/rathole --- packages/grid/rathole/domain.dockerfile | 9 - packages/grid/rathole/nginx.conf | 6 - packages/grid/rathole/requirements.txt | 5 - packages/grid/rathole/server/__init__.py | 0 packages/grid/rathole/server/main.py | 94 ------- packages/grid/rathole/server/models.py | 18 -- packages/grid/rathole/server/nginx_builder.py | 92 ------- packages/grid/rathole/server/toml_writer.py | 26 -- packages/grid/rathole/server/utils.py | 236 ------------------ 9 files changed, 486 deletions(-) delete mode 100644 packages/grid/rathole/domain.dockerfile delete mode 100644 packages/grid/rathole/nginx.conf delete mode 100644 packages/grid/rathole/requirements.txt delete mode 100644 packages/grid/rathole/server/__init__.py delete mode 100644 packages/grid/rathole/server/main.py delete mode 100644 packages/grid/rathole/server/models.py delete mode 100644 packages/grid/rathole/server/nginx_builder.py delete mode 100644 packages/grid/rathole/server/toml_writer.py delete mode 100644 packages/grid/rathole/server/utils.py diff --git a/packages/grid/rathole/domain.dockerfile b/packages/grid/rathole/domain.dockerfile deleted file mode 100644 index cdb657540e8..00000000000 --- a/packages/grid/rathole/domain.dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -ARG PYTHON_VERSION="3.12" -FROM python:${PYTHON_VERSION}-bookworm -RUN apt update && apt install -y netcat-openbsd vim -WORKDIR /app -CMD ["python3", "-m", "http.server", "8000"] -EXPOSE 8000 - -# docker build -f domain.dockerfile . 
-t domain -# docker run -it -p 8080:8000 domain diff --git a/packages/grid/rathole/nginx.conf b/packages/grid/rathole/nginx.conf deleted file mode 100644 index 447c482a9e4..00000000000 --- a/packages/grid/rathole/nginx.conf +++ /dev/null @@ -1,6 +0,0 @@ -server { - listen 8000; - location / { - proxy_pass http://test-domain-r:8001; - } -} \ No newline at end of file diff --git a/packages/grid/rathole/requirements.txt b/packages/grid/rathole/requirements.txt deleted file mode 100644 index 4b379d83e8e..00000000000 --- a/packages/grid/rathole/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -fastapi==0.110.0 -filelock==3.13.4 -loguru==0.7.2 -python-nginx -uvicorn[standard]==0.27.1 diff --git a/packages/grid/rathole/server/__init__.py b/packages/grid/rathole/server/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/packages/grid/rathole/server/main.py b/packages/grid/rathole/server/main.py deleted file mode 100644 index 201618c6432..00000000000 --- a/packages/grid/rathole/server/main.py +++ /dev/null @@ -1,94 +0,0 @@ -# stdlib -from enum import Enum -import os -import sys - -# third party -from fastapi import FastAPI -from fastapi import status -from loguru import logger -from server.models import RatholeConfig -from server.models import ResponseModel -from server.utils import RatholeClientToml -from server.utils import RatholeServerToml - -# Logging Configuration -log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() -logger.remove() -logger.add(sys.stderr, colorize=True, level=log_level) - -app = FastAPI(title="Rathole") - - -class RatholeMode(Enum): - CLIENT = "client" - SERVER = "server" - - -ServiceType = os.getenv("MODE", "client").lower() - - -RatholeTomlManager = ( - RatholeServerToml() - if ServiceType == RatholeMode.SERVER.value - else RatholeClientToml() -) - - -async def healthcheck() -> bool: - return True - - -@app.get( - "/", - response_model=ResponseModel, - status_code=status.HTTP_200_OK, -) -async def healthcheck_endpoint() -> ResponseModel: - res = await healthcheck() - if res: - return ResponseModel(message="OK") - else: - return ResponseModel(message="FAIL") - - -@app.post( - "/config/", - response_model=ResponseModel, - status_code=status.HTTP_201_CREATED, -) -async def add_config(config: RatholeConfig) -> ResponseModel: - RatholeTomlManager.add_config(config) - return ResponseModel(message="Config added successfully") - - -@app.delete( - "/config/{uuid}", - response_model=ResponseModel, - status_code=status.HTTP_200_OK, -) -async def remove_config(uuid: str) -> ResponseModel: - RatholeTomlManager.remove_config(uuid) - return ResponseModel(message="Config removed successfully") - - -@app.put( - "/config/{uuid}", - response_model=ResponseModel, - status_code=status.HTTP_200_OK, -) -async def update_config(config: RatholeConfig) -> ResponseModel: - RatholeTomlManager.update_config(config=config) - return ResponseModel(message="Config updated successfully") - - -@app.get( - "/config/{uuid}", - response_model=RatholeConfig | ResponseModel, - status_code=status.HTTP_201_CREATED, -) -async def get_config(uuid: str) -> RatholeConfig: - config = RatholeTomlManager.get_config(uuid) - if config is None: - return ResponseModel(message="Config not found") - return config diff --git a/packages/grid/rathole/server/models.py b/packages/grid/rathole/server/models.py deleted file mode 100644 index c5921738885..00000000000 --- a/packages/grid/rathole/server/models.py +++ /dev/null @@ -1,18 +0,0 @@ -# third party -from pydantic import BaseModel - - -class 
ResponseModel(BaseModel): - message: str - - -class RatholeConfig(BaseModel): - uuid: str - secret_token: str - local_addr_host: str - local_addr_port: int - server_name: str | None = None - - @property - def local_address(self) -> str: - return f"http://{self.local_addr_host}:{self.local_addr_port}" diff --git a/packages/grid/rathole/server/nginx_builder.py b/packages/grid/rathole/server/nginx_builder.py deleted file mode 100644 index 3d1bd14ce1f..00000000000 --- a/packages/grid/rathole/server/nginx_builder.py +++ /dev/null @@ -1,92 +0,0 @@ -# stdlib -from pathlib import Path - -# third party -from filelock import FileLock -import nginx -from nginx import Conf - - -class RatholeNginxConfigBuilder: - def __init__(self, filename: str | Path) -> None: - self.filename = Path(filename).absolute() - - if not self.filename.exists(): - self.filename.touch() - - self.lock = FileLock(f"{filename}.lock") - self.lock_timeout = 30 - - def read(self) -> Conf: - with self.lock.acquire(timeout=self.lock_timeout): - conf = nginx.loadf(self.filename) - - return conf - - def write(self, conf: Conf) -> None: - with self.lock.acquire(timeout=self.lock_timeout): - nginx.dumpf(conf, self.filename) - - def add_server( - self, - listen_port: int, - location: str, - proxy_pass: str, - server_name: str | None = None, - ) -> None: - n_config = self.read() - server_to_modify = self.find_server_with_listen_port(listen_port) - - if server_to_modify is not None: - server_to_modify.add( - nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)) - ) - if server_name is not None: - server_to_modify.add(nginx.Key("server_name", server_name)) - else: - server = nginx.Server( - nginx.Key("listen", listen_port), - nginx.Location(location, nginx.Key("proxy_pass", proxy_pass)), - ) - if server_name is not None: - server.add(nginx.Key("server_name", server_name)) - - n_config.add(server) - - self.write(n_config) - - def remove_server(self, listen_port: int) -> None: - conf = self.read() - for server in conf.servers: - for child in server.children: - if child.name == "listen" and int(child.value) == listen_port: - conf.remove(server) - break - self.write(conf) - - def find_server_with_listen_port(self, listen_port: int) -> nginx.Server | None: - conf = self.read() - for server in conf.servers: - for child in server.children: - if child.name == "listen" and int(child.value) == listen_port: - return server - return None - - def modify_proxy_for_port( - self, listen_port: int, location: str, proxy_pass: str - ) -> None: - conf = self.read() - server_to_modify = self.find_server_with_listen_port(listen_port) - - if server_to_modify is None: - raise ValueError(f"Server with listen port {listen_port} not found") - - for location in server_to_modify.locations: - if location.value != location: - continue - for key in location.keys: - if key.name == "proxy_pass": - key.value = proxy_pass - break - - self.write(conf) diff --git a/packages/grid/rathole/server/toml_writer.py b/packages/grid/rathole/server/toml_writer.py deleted file mode 100644 index a0e79aff627..00000000000 --- a/packages/grid/rathole/server/toml_writer.py +++ /dev/null @@ -1,26 +0,0 @@ -# stdlib -from pathlib import Path -import tomllib - -# third party -from filelock import FileLock - -FILE_LOCK_TIMEOUT = 30 - - -class TomlReaderWriter: - def __init__(self, lock: FileLock, filename: Path | str) -> None: - self.filename = Path(filename).absolute() - self.timeout = FILE_LOCK_TIMEOUT - self.lock = lock - - def write(self, toml_dict: dict) -> None: - with 
self.lock.acquire(timeout=self.timeout): - with open(str(self.filename), "wb") as fp: - tomllib.dump(toml_dict, fp) - - def read(self) -> dict: - with self.lock.acquire(timeout=self.timeout): - with open(str(self.filename), "rb") as fp: - toml = tomllib.load(fp) - return toml diff --git a/packages/grid/rathole/server/utils.py b/packages/grid/rathole/server/utils.py deleted file mode 100644 index 485e4ae7e23..00000000000 --- a/packages/grid/rathole/server/utils.py +++ /dev/null @@ -1,236 +0,0 @@ -# stdlib - -# third party -from filelock import FileLock - -# relative -from .models import RatholeConfig -from .nginx_builder import RatholeNginxConfigBuilder -from .toml_writer import TomlReaderWriter - -lock = FileLock("rathole.toml.lock") - - -class RatholeClientToml: - filename: str = "client.toml" - - def __init__(self) -> None: - self.client_toml = TomlReaderWriter(lock=lock, filename=self.filename) - self.nginx_mananger = RatholeNginxConfigBuilder("nginx.conf") - - def set_remote_addr(self, remote_host: str) -> None: - """Add a new remote address to the client toml file.""" - - toml = self.client_toml.read() - - # Add the new remote address - if "client" not in toml: - toml["client"] = {} - - toml["client"]["remote_addr"] = remote_host - - if remote_host not in toml["client"]["remote"]: - toml["client"]["remote"].append(remote_host) - - self.client_toml.write(toml_dict=toml) - - def add_config(self, config: RatholeConfig) -> None: - """Add a new config to the toml file.""" - - toml = self.client_toml.read() - - # Add the new config - if "services" not in toml["client"]: - toml["client"]["services"] = {} - - if config.uuid not in toml["client"]["services"]: - toml["client"]["services"][config.uuid] = {} - - toml["client"]["services"][config.uuid] = { - "token": config.secret_token, - "local_addr": config.local_address, - } - - self.client_toml.write(toml) - - self.nginx_mananger.add_server( - config.local_addr_port, location="/", proxy_pass="http://backend:80" - ) - - def remove_config(self, uuid: str) -> None: - """Remove a config from the toml file.""" - - toml = self.client_toml.read() - - # Remove the config - if "services" not in toml["client"]: - return - - if uuid not in toml["client"]["services"]: - return - - del toml["client"]["services"][uuid] - - self.client_toml.write(toml) - - def update_config(self, config: RatholeConfig) -> None: - """Update a config in the toml file.""" - - toml = self.client_toml.read() - - # Update the config - if "services" not in toml["client"]: - return - - if config.uuid not in toml["client"]["services"]: - return - - toml["client"]["services"][config.uuid] = { - "token": config.secret_token, - "local_addr": config.local_address, - } - - self.client_toml.write(toml) - - def get_config(self, uuid: str) -> RatholeConfig | None: - """Get a config from the toml file.""" - - toml = self.client_toml.read() - - # Get the config - if "services" not in toml["client"]: - return None - - if uuid not in toml["client"]["services"]: - return None - - service = toml["client"]["services"][uuid] - - return RatholeConfig( - uuid=uuid, - secret_token=service["token"], - local_addr_host=service["local_addr"].split(":")[0], - local_addr_port=service["local_addr"].split(":")[1], - ) - - def _validate(self) -> bool: - if not self.client_toml.filename.exists(): - return False - - toml = self.client_toml.read() - - if not toml["client"]["remote_addr"]: - return False - - for uuid, config in toml["client"]["services"].items(): - if not uuid: - return False - - if not 
config["token"] or not config["local_addr"]: - return False - - return True - - @property - def is_valid(self) -> bool: - return self._validate() - - -class RatholeServerToml: - filename: str = "server.toml" - - def __init__(self) -> None: - self.server_toml = TomlReaderWriter(lock=lock, filename=self.filename) - self.nginx_manager = RatholeNginxConfigBuilder("nginx.conf") - - def set_bind_address(self, bind_address: str) -> None: - """Set the bind address in the server toml file.""" - - toml = self.server_toml.read() - - # Set the bind address - toml["server"]["bind_addr"] = bind_address - - self.server_toml.write(toml) - - def add_config(self, config: RatholeConfig) -> None: - """Add a new config to the toml file.""" - - toml = self.server_toml.read() - - # Add the new config - if "services" not in toml["server"]: - toml["server"]["services"] = {} - - if config.uuid not in toml["server"]["services"]: - toml["server"]["services"][config.uuid] = {} - - toml["server"]["services"][config.uuid] = { - "token": config.secret_token, - "bind_addr": config.local_address, - } - - self.server_toml.write(toml) - - self.nginx_manager.add_server( - config.local_addr_port, - location="/", - proxy_pass=config.local_address, - server_name=f"{config.server_name}.local*", - ) - - def remove_config(self, uuid: str) -> None: - """Remove a config from the toml file.""" - - toml = self.server_toml.read() - - # Remove the config - if "services" not in toml["server"]: - return - - if uuid not in toml["server"]["services"]: - return - - del toml["server"]["services"][uuid] - - self.server_toml.write(toml) - - def update_config(self, config: RatholeConfig) -> None: - """Update a config in the toml file.""" - - toml = self.server_toml.read() - - # Update the config - if "services" not in toml["server"]: - return - - if config.uuid not in toml["server"]["services"]: - return - - toml["server"]["services"][config.uuid] = { - "token": config.secret_token, - "bind_addr": config.local_address, - } - - self.server_toml.write(toml) - - def _validate(self) -> bool: - if not self.server_toml.filename.exists(): - return False - - toml = self.server_toml.read() - - if not toml["server"]["bind_addr"]: - return False - - for uuid, config in toml["server"]["services"].items(): - if not uuid: - return False - - if not config["token"] or not config["bind_addr"]: - return False - - return True - - def is_valid(self) -> bool: - return self._validate() From dd9a037203b0118ed2ba51763780787aae89c2b2 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 16 May 2024 12:35:59 +0530 Subject: [PATCH 041/309] use rathole image to build rathole --- packages/grid/rathole/rathole.dockerfile | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 1d1b82785af..5f3a8762677 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -1,32 +1,23 @@ ARG RATHOLE_VERSION="0.5.0" ARG PYTHON_VERSION="3.12" -FROM rust as build -ARG RATHOLE_VERSION -ARG FEATURES -RUN apt update && apt install -y git -RUN git clone -b v${RATHOLE_VERSION} https://github.com/rapiz1/rathole - -WORKDIR /rathole -RUN cargo build --locked --release --features ${FEATURES:-default} +FROM rapiz1/rathole:v${RATHOLE_VERSION} as build FROM python:${PYTHON_VERSION}-bookworm ARG RATHOLE_VERSION ENV MODE="client" -ENV APP_LOG_LEVEL="info" -COPY --from=build /rathole/target/release/rathole /app/rathole RUN apt update && apt 
install -y netcat-openbsd vim +COPY --from=build /app/rathole /app/rathole + WORKDIR /app COPY ./start.sh /app/start.sh -COPY ./nginx.conf /etc/nginx/conf.d/default.conf -COPY ./requirements.txt /app/requirements.txt -COPY ./server/ /app/server/ -RUN pip install --user -r requirements.txt -CMD ["sh", "-c", "/app/start.sh"] EXPOSE 2333/udp EXPOSE 2333 +CMD ["sh", "-c", "/app/start.sh"] + + # build and run a fake domain to simulate a normal http container service # docker build -f domain.dockerfile . -t domain # docker run --name domain1 -it -d -p 8080:8000 domain From f9a3fcf7d143608476c7318b95923385e69ddc88 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 17 May 2024 10:28:06 +0530 Subject: [PATCH 042/309] configure rathole toml to use websockets - add a path in ingress to resolve to rathole service - fix remote addr url in client.toml --- packages/grid/helm/syft/templates/global/ingress.yaml | 11 +++++++++-- .../syft/templates/rathole/rathole-configmap.yaml | 10 ++++++++++ .../syft/src/syft/service/network/network_service.py | 8 ++++++-- packages/syft/src/syft/types/grid_url.py | 4 ++++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/packages/grid/helm/syft/templates/global/ingress.yaml b/packages/grid/helm/syft/templates/global/ingress.yaml index 677a66313a6..71bac8d6964 100644 --- a/packages/grid/helm/syft/templates/global/ingress.yaml +++ b/packages/grid/helm/syft/templates/global/ingress.yaml @@ -27,13 +27,20 @@ spec: - host: {{ .Values.ingress.hostname | quote }} http: paths: - - backend: + - path: / + pathType: Prefix + backend: service: name: proxy port: number: 80 - path: / + - path: /rathole pathType: Prefix + backend: + service: + name: rathole + port: + number: 2333 {{- if .Values.ingress.tls.enabled }} tls: - hosts: diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index d9e004d3c5e..cd5453f1ea2 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -12,6 +12,11 @@ data: [server] bind_addr = "0.0.0.0:2333" + [server.transport] + type = "websocket" + [server.transport.websocket] + tls = false + [server.services.domain] token = "domain-specific-rathole-secret" bind_addr = "0.0.0.0:8001" @@ -21,4 +26,9 @@ data: client.toml: | [client] remote_addr = "0.0.0.0:2333" + + [client.transport] + type = "websocket" + [client.transport.websocket] + tls = false {{- end }} diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 2a8aa986846..e0143541728 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -49,7 +49,6 @@ from ..warnings import CRUDWarning from .association_request import AssociationRequestChange from .node_peer import NodePeer -from .rathole import get_rathole_port from .rathole_service import RatholeService from .routes import HTTPNodeRoute from .routes import NodeRoute @@ -206,7 +205,12 @@ def exchange_credentials_with( message=f"Failed to update remote node peer: {str(result.err())}" ) - remote_addr = f"{remote_node_route.protocol}://{remote_node_route.host_or_ip}:{get_rathole_port()}" + remote_url = GridURL( + host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port + ) + rathole_remote_addr = remote_url.with_path("/rathole").as_container_host() + + remote_addr = 
rathole_remote_addr.url_no_protocol self.rathole_service.add_host_to_client( peer_name=self_node_peer.name, diff --git a/packages/syft/src/syft/types/grid_url.py b/packages/syft/src/syft/types/grid_url.py index 91cf53e46d7..9db8de440a8 100644 --- a/packages/syft/src/syft/types/grid_url.py +++ b/packages/syft/src/syft/types/grid_url.py @@ -135,6 +135,10 @@ def base_url(self) -> str: def base_url_no_port(self) -> str: return f"{self.protocol}://{self.host_or_ip}" + @property + def url_no_protocol(self) -> str: + return f"{self.host_or_ip}:{self.port}{self.path}" + @property def url_path(self) -> str: return f"{self.path}{self.query_string}" From 299dbac78e589cecb99ff923cafbe06f5bd3bda8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 21 May 2024 23:14:13 +0530 Subject: [PATCH 043/309] configure to use same port and path for both http and websocket - add a traefik rule to route requests with Header with websocket to rathole --- packages/grid/helm/syft/templates/global/ingress.yaml | 11 ++--------- .../helm/syft/templates/proxy/proxy-configmap.yaml | 11 ++++++++++- .../syft/src/syft/service/network/network_service.py | 2 +- .../syft/src/syft/service/network/rathole_service.py | 4 ++-- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/packages/grid/helm/syft/templates/global/ingress.yaml b/packages/grid/helm/syft/templates/global/ingress.yaml index 71bac8d6964..677a66313a6 100644 --- a/packages/grid/helm/syft/templates/global/ingress.yaml +++ b/packages/grid/helm/syft/templates/global/ingress.yaml @@ -27,20 +27,13 @@ spec: - host: {{ .Values.ingress.hostname | quote }} http: paths: - - path: / - pathType: Prefix - backend: + - backend: service: name: proxy port: number: 80 - - path: /rathole + path: / pathType: Prefix - backend: - service: - name: rathole - port: - number: 2333 {{- if .Values.ingress.tls.enabled }} tls: - hosts: diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index ee5d2316e99..9a3cacf23b2 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -2,7 +2,6 @@ apiVersion: v1 kind: ConfigMap metadata: name: proxy-config - resourceVersion: "" labels: {{- include "common.labels" . 
| nindent 4 }} app.kubernetes.io/component: proxy @@ -22,7 +21,16 @@ data: loadBalancer: servers: - url: "http://seaweedfs:8333" + rathole: + loadBalancer: + servers: + - url: "http://rathole:2333" routers: + rathole: + rule: "PathPrefix(`/`) && Headers(`Upgrade`, `websocket`)" + entryPoints: + - "web" + service: "rathole" frontend: rule: "PathPrefix(`/`)" entryPoints: @@ -84,3 +92,4 @@ metadata: app.kubernetes.io/component: proxy data: rathole-dynamic.yml: | + diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index e0143541728..daff8c411f3 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -208,7 +208,7 @@ def exchange_credentials_with( remote_url = GridURL( host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port ) - rathole_remote_addr = remote_url.with_path("/rathole").as_container_host() + rathole_remote_addr = remote_url.as_container_host() remote_addr = rathole_remote_addr.url_no_protocol diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 250c56c21cb..9be55c25fcf 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -82,8 +82,8 @@ def add_host_to_client( config = RatholeConfig( uuid=peer_id, secret_token=rathole_token, - local_addr_host="localhost", - local_addr_port=random_port, + local_addr_host="proxy", + local_addr_port=8001, server_name=peer_name, ) From f67774c31315b5a9e6d1222d33d688601fd29d0b Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 22 May 2024 10:42:48 +0530 Subject: [PATCH 044/309] fix proxy port set to the client toml --- packages/syft/src/syft/service/network/rathole_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 9be55c25fcf..c9287535a0a 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -83,7 +83,7 @@ def add_host_to_client( uuid=peer_id, secret_token=rathole_token, local_addr_host="proxy", - local_addr_port=8001, + local_addr_port=80, server_name=peer_name, ) From fd60a9b889cbf17ae9a68dcd8263f89fdf573ac1 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 22 May 2024 18:33:19 +1000 Subject: [PATCH 045/309] Added build step for rathole for arm64 platforms --- packages/grid/rathole/rathole.dockerfile | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 5f3a8762677..7166afd63dd 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -1,13 +1,20 @@ ARG RATHOLE_VERSION="0.5.0" ARG PYTHON_VERSION="3.12" -FROM rapiz1/rathole:v${RATHOLE_VERSION} as build +FROM rust as build +ARG RATHOLE_VERSION +ARG FEATURES +RUN apt update && apt install -y git +RUN git clone -b v${RATHOLE_VERSION} https://github.com/rapiz1/rathole + +WORKDIR /rathole +RUN cargo build --locked --release --features ${FEATURES:-default} FROM python:${PYTHON_VERSION}-bookworm ARG RATHOLE_VERSION ENV MODE="client" RUN apt update && apt install -y netcat-openbsd vim -COPY --from=build /app/rathole /app/rathole +COPY --from=build 
/rathole/target/release/rathole /app/rathole WORKDIR /app COPY ./start.sh /app/start.sh From 72e7daa10a3a187ccf75b4ccf8b381906b5911b4 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 22 May 2024 21:42:08 +0530 Subject: [PATCH 046/309] start traefik in watch mode with watch on directory instead of a single file - replace localhost with 0.0.0.0 during local addr bind in rathole client.toml - fix allow backend to access and perform crud on resource Services --- .../syft/templates/backend/backend-service-account.yaml | 2 +- .../grid/helm/syft/templates/proxy/proxy-configmap.yaml | 4 +++- packages/syft/src/syft/service/network/rathole_service.py | 8 +------- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml index 7b542adfc0b..76d70afee70 100644 --- a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml @@ -26,7 +26,7 @@ metadata: app.kubernetes.io/component: backend rules: - apiGroups: [""] - resources: ["pods", "configmaps", "secrets", "service"] + resources: ["pods", "configmaps", "secrets", "services"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] - apiGroups: [""] resources: ["pods/log"] diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 9a3cacf23b2..c7628bc2ec4 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -80,7 +80,9 @@ data: providers: file: - filename: /etc/traefik/dynamic.yml + directory: /etc/traefik/ + watch: true + --- apiVersion: v1 diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index c9287535a0a..a514d499d98 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -37,7 +37,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: config = RatholeConfig( uuid=peer.id.to_string(), secret_token=peer.rathole_token, - local_addr_host="localhost", + local_addr_host="0.0.0.0", local_addr_port=random_port, server_name=peer.name, ) @@ -77,8 +77,6 @@ def add_host_to_client( ) -> None: """Add a host to the rathole client toml file.""" - random_port = self.get_random_port() - config = RatholeConfig( uuid=peer_id, secret_token=rathole_token, @@ -107,10 +105,6 @@ def add_host_to_client( # Update the rathole config map KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) - self.add_entrypoint(port=random_port, peer_name=peer_name) - - self.forward_port_to_proxy(config=config, entrypoint=peer_name) - def forward_port_to_proxy( self, config: RatholeConfig, entrypoint: str = "web" ) -> None: From 031dc80d6ecc10c75e2fc67556f1de8f7d78223d Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 23 May 2024 10:43:13 +0200 Subject: [PATCH 047/309] do serialization via objectregistry --- .../syft/src/syft/capnp/recursive_serde.capnp | 2 + packages/syft/src/syft/node/node.py | 18 ++- .../syft/src/syft/protocol/data_protocol.py | 79 ++++------ .../src/syft/protocol/protocol_version.json | 24 +--- packages/syft/src/syft/serde/recursive.py | 58 +++++--- .../src/syft/service/action/action_store.py | 2 +- .../syft/src/syft/service/job/job_stash.py | 40 
+++++- packages/syft/src/syft/service/log/log.py | 32 ++++- .../syft/service/notifier/notifier_service.py | 3 +- .../syft/src/syft/store/kv_document_store.py | 4 +- packages/syft/src/syft/types/syft_object.py | 120 +++++----------- .../src/syft/types/syft_object_registry.py | 136 ++++++++++++++++++ packages/syft/src/syft/types/transforms.py | 2 +- packages/syft/src/syft/util/schema.py | 1 + .../tests/syft/transforms/transforms_test.py | 2 +- 15 files changed, 340 insertions(+), 183 deletions(-) create mode 100644 packages/syft/src/syft/types/syft_object_registry.py diff --git a/packages/syft/src/syft/capnp/recursive_serde.capnp b/packages/syft/src/syft/capnp/recursive_serde.capnp index 8f4b1b17953..c29ba57aae6 100644 --- a/packages/syft/src/syft/capnp/recursive_serde.capnp +++ b/packages/syft/src/syft/capnp/recursive_serde.capnp @@ -5,4 +5,6 @@ struct RecursiveSerde { fieldsData @1 :List(List(Data)); fullyQualifiedName @2 :Text; nonrecursiveBlob @3 :List(Data); + canonicalName @4 :Text; + version @5 :Int32; } diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 7861ee422e0..14632f914a0 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -405,10 +405,16 @@ def __init__( self.create_initial_settings(admin_email=root_email) - self.init_queue_manager(queue_config=self.queue_config) self.init_blob_storage(config=blob_storage_config) + # Migrate data before any operation on db + if migrate: + self.find_and_migrate_data() + + # first migrate, for backwards compatibility + self.init_queue_manager(queue_config=self.queue_config) + context = AuthedServiceContext( node=self, credentials=self.verify_key, @@ -419,9 +425,7 @@ def __init__( if background_tasks: self.run_peer_health_checks(context=context) - # Migrate data before any operation on db - if migrate: - self.find_and_migrate_data() + NodeRegistry.set_node_for(self.id, self) @@ -1634,7 +1638,11 @@ def create_admin_new( else: raise Exception(f"Could not create user: {result}") except Exception as e: - print("Unable to create new admin", e) + # import ipdb + # ipdb.set_trace() + import traceback + + print("Unable to create new admin", traceback.format_exc()) return None diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 79f0d680658..912537e1266 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -21,12 +21,12 @@ # relative from .. 
import __version__ -from ..serde.recursive import TYPE_BANK from ..service.response import SyftError from ..service.response import SyftException from ..service.response import SyftSuccess from ..types.dicttuple import DictTuple from ..types.syft_object import SyftBaseObject +from ..types.syft_object_registry import SyftObjectRegistry from ..util.util import get_dev_mode PROTOCOL_STATE_FILENAME = "protocol_version.json" @@ -208,22 +208,19 @@ def build_state(self, stop_key: str | None = None) -> dict: return state_dict return state_dict + @staticmethod + def obj_json(version, _hash, action="add"): + return { + "version": int(version), + "hash": _hash, + "action": action, + } + def diff_state(self, state: dict) -> tuple[dict, dict]: compare_dict: dict = defaultdict(dict) # what versions are in the latest code object_diff: dict = defaultdict(dict) # diff in latest code with saved json - for k in TYPE_BANK: - ( - nonrecursive, - serialize, - deserialize, - attribute_list, - exclude_attrs_list, - serde_overrides, - hash_exclude_attrs, - cls, - attribute_types, - version, - ) = TYPE_BANK[k] + for serde_properties in SyftObjectRegistry.__object_serialization_registry__.values(): + cls, version = serde_properties[7], serde_properties[9] if issubclass(cls, SyftBaseObject): canonical_name = cls.__canonical_name__ if canonical_name in IGNORE_TYPES or canonical_name.startswith( @@ -238,10 +235,8 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: if canonical_name not in state: # new object so its an add - object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = int(version) - object_diff[canonical_name][str(version)]["hash"] = hash_str - object_diff[canonical_name][str(version)]["action"] = "add" + obj_to_add = self.obj_json(int(version), hash_str) + object_diff[canonical_name][str(version)] = obj_to_add continue versions = state[canonical_name] @@ -255,23 +250,15 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: is_protocol_dev = versions[str(version)][1] == "dev" if is_protocol_dev: # force overwrite existing object so its an add - object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = int( - version - ) - object_diff[canonical_name][str(version)]["hash"] = hash_str - object_diff[canonical_name][str(version)]["action"] = "add" + obj_to_add = self.obj_json(int(version), hash_str) + object_diff[canonical_name][str(version)] = obj_to_add continue - - error_msg = ( - f"{canonical_name} for class {cls.__name__} fqn {cls} " - + f"version {version} hash has changed. " - + f"{hash_str} not in {versions.values()}. " - + "Is a unique __canonical_name__ for this subclass missing? " - + "If the class has changed you will need to define a new class with the changes, " - + "with same __canonical_name__ and bump the __version__ number." - + f"{cls.model_fields}" - ) + error_msg = f"""{canonical_name} for class {cls.__name__} fqn {cls}\ +version {version} hash has changed. {hash_str} not in {versions.values()}. \ +Is a unique __canonical_name__ for this subclass missing? +If the class has changed you will need to define a new class with the changes, \ +with same __canonical_name__ and bump the __version__ number. 
{cls.model_fields} +""" if get_dev_mode() or self.raise_exception: raise Exception(error_msg) @@ -280,10 +267,8 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: break else: # new object so its an add - object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = int(version) - object_diff[canonical_name][str(version)]["hash"] = hash_str - object_diff[canonical_name][str(version)]["action"] = "add" + obj_to_add = self.obj_json(int(version), hash_str) + object_diff[canonical_name][str(version)] = obj_to_add continue # now check for remove actions @@ -291,18 +276,14 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: for version, (hash_str, _) in state[canonical_name].items(): if canonical_name not in compare_dict: # missing so its a remove - object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = int(version) - object_diff[canonical_name][str(version)]["hash"] = hash_str - object_diff[canonical_name][str(version)]["action"] = "remove" + obj_to_remove = self.obj_json(int(version), hash_str, "remove") + object_diff[canonical_name][str(version)] = obj_to_remove continue versions = compare_dict[canonical_name] if str(version) not in versions.keys(): # missing so its a remove - object_diff[canonical_name][str(version)] = {} - object_diff[canonical_name][str(version)]["version"] = int(version) - object_diff[canonical_name][str(version)]["hash"] = hash_str - object_diff[canonical_name][str(version)]["action"] = "remove" + obj_to_remove = self.obj_json(int(version), hash_str, "remove") + object_diff[canonical_name][str(version)] = obj_to_remove continue return object_diff, compare_dict @@ -436,9 +417,9 @@ def validate_release(self) -> None: # Update older file path to newer file path latest_protocol_fp.rename(new_protocol_file_path) - protocol_history[latest_protocol]["release_name"] = ( - f"{current_syft_version}.json" - ) + protocol_history[latest_protocol][ + "release_name" + ] = f"{current_syft_version}.json" # Save history self.file_path.write_text(json.dumps(protocol_history, indent=2) + "\n") diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index e30f48dfd5a..9aff20c3113 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -27,6 +27,13 @@ "action": "add" } }, + "BlobRetrievalByURL": { + "5": { + "version": 5, + "hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3", + "action": "add" + } + }, "EnclaveMetadata": { "2": { "version": 2, @@ -47,11 +54,6 @@ } }, "JobItem": { - "4": { - "version": 4, - "hash": "6a7cc7c2bb4dd234c1508b0af4d3b403cd3b7b427578a775bf80dc36891923ed", - "action": "remove" - }, "5": { "version": 5, "hash": "82ee08442b09797ed7a3710c31de633bb308b1d2215f51b58a3e01a4c201055d", @@ -127,11 +129,6 @@ } }, "SyftLog": { - "3": { - "version": 3, - "hash": "8964d48238672e0e5d5db6b932cda4ee8eb77581949ab3f7a38a05b1efec13b7", - "action": "remove" - }, "4": { "version": 4, "hash": "ad6ef18ccd87fced669f3824d27ab423aaf52574b0cd4f720687aeaba77524e5", @@ -184,13 +181,6 @@ "hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b", "action": "remove" } - }, - "BlobRetrievalByURL": { - "5": { - "version": 5, - "hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3", - "action": "add" - } } } } diff --git a/packages/syft/src/syft/serde/recursive.py 
b/packages/syft/src/syft/serde/recursive.py index 02957e5f23d..c7f0fdcd02d 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -13,6 +13,7 @@ # syft absolute import syft as sy +from syft.types.syft_object_registry import SyftObjectRegistry # relative from ..util.util import get_fully_qualified_name @@ -157,7 +158,18 @@ def recursive_serde_register( version, ) + TYPE_BANK[fqn] = serde_attributes + if hasattr(cls, "__canonical_name__"): + canonical_name = cls.__canonical_name__ + version = cls.__version__ + else: + # TODO: refactor + canonical_name = fqn.split(".")[-1] + version = 1 + + SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] = serde_attributes + if isinstance(alias_fqn, tuple): for alias in alias_fqn: @@ -215,23 +227,30 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild is_type = True msg = recursive_scheme.new_message() - fqn = get_fully_qualified_name(self) - if fqn not in TYPE_BANK: + + # todo: rewrite and make sure every object has a canonical name and version + canonical_name = SyftObjectRegistry.get_canonical_name(self) + version = getattr(self, "__version__", 1) + + if not SyftObjectRegistry.has_serde_class("", canonical_name, version): # third party - raise Exception(f"{fqn} not in TYPE_BANK") - msg.fullyQualifiedName = fqn + raise Exception(f"{canonical_name} version {version} not in SyftObjectRegistry") + + msg.canonicalName = canonical_name + msg.version=version + ( nonrecursive, serialize, - deserialize, + _, attribute_list, exclude_attrs_list, serde_overrides, hash_exclude_attrs, - cls, - attribute_types, - version, - ) = TYPE_BANK[fqn] + _, + _, + _, + ) = SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] if nonrecursive or is_type: if serialize is None: @@ -318,8 +337,11 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: except Exception: # nosec pass - if proto.fullyQualifiedName not in TYPE_BANK: - raise Exception(f"{proto.fullyQualifiedName} not in TYPE_BANK") + canonical_name = proto.canonicalName + version = getattr(proto, "version", -1) + fqn = getattr(proto, "fullyQualifiedName", "") + if not SyftObjectRegistry.has_serde_class(fqn, canonical_name, version): + raise Exception(f"{canonical_name} version {version} not in SyftObjectRegistry") # TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the # module name space in sub-processes or threads even though they are loaded on start @@ -328,18 +350,18 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: # causes some errors so it seems like we want to get the local one where possible ( nonrecursive, - serialize, + _, deserialize, - attribute_list, - exclude_attrs_list, + _, + _, serde_overrides, - hash_exclude_attrs, + _, cls, - attribute_types, + _, version, - ) = TYPE_BANK[proto.fullyQualifiedName] + ) = SyftObjectRegistry.get_serde_properties(fqn, canonical_name, version) - if class_type == type(None): + if class_type == type(None) or fqn != "": # yes this looks stupid but it works and the opposite breaks class_type = cls diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 001aa7a4e0f..fae5921828e 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -314,7 +314,7 @@ def migrate_data( migrated_value = value.migrate_to(to_klass.__version__) except Exception as e: return Err( - 
f"Failed to migrate data to {to_klass} for qk: {key}. Exception: {e}" + f"Failed to migrate data to {to_klass} {to_klass.__version__} for qk: {key}. Exception: {e}" ) result = self.set( uid=key, diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index d7aa3aca00b..d86f7363904 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -4,7 +4,7 @@ from enum import Enum import random from string import Template -from typing import Any +from typing import Any, Callable # third party from pydantic import Field @@ -12,6 +12,8 @@ from result import Err from result import Ok from result import Result +from syft.types.syft_migration import migrate +from syft.types.transforms import drop, make_set_default from typing_extensions import Self # relative @@ -28,7 +30,7 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_2, SYFT_OBJECT_VERSION_4 from ...types.syft_object import SYFT_OBJECT_VERSION_5 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject @@ -73,6 +75,32 @@ def center_content(text: Any) -> str: return center_div +@serializable() +class JobV4(SyncableSyftObject): + __canonical_name__ = "JobItem" + __version__ = SYFT_OBJECT_VERSION_4 + + id: UID + node_uid: UID + result: Any | None = None + resolved: bool = False + status: JobStatus = JobStatus.CREATED + log_id: UID | None = None + parent_job_id: UID | None = None + n_iters: int | None = 0 + current_iter: int | None = None + creation_time: str | None = None + action: Action | None = None + job_pid: int | None = None + job_worker_id: UID | None = None + updated_at: DateTime | None = None + user_code_id: UID | None = None + + __attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"] + __repr_attrs__ = ["id", "result", "resolved", "progress", "creation_time"] + __exclude_sync_diff_attrs__ = ["action"] + + @serializable() class Job(SyncableSyftObject): __canonical_name__ = "JobItem" @@ -720,6 +748,14 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # return dependencies +@migrate(Job, JobV4) +def upgrade_job() -> list[Callable]: + return [make_set_default("requested_by", UID())] + +@migrate(JobV4, Job) +def downgrade_job() -> list[Callable]: + return [drop("requested_by")] + @serializable() class JobInfo(SyftObject): __canonical_name__ = "JobInfo" diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index ba6d0761918..48431952d90 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,15 +1,33 @@ # stdlib -from typing import Any +from typing import Any, Callable from typing import ClassVar +from syft.types.syft_migration import migrate +from syft.types.transforms import drop, make_set_default + # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext -from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SYFT_OBJECT_VERSION_3, SYFT_OBJECT_VERSION_4 from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID +@serializable() +class SyftLogV3(SyncableSyftObject): + __canonical_name__ = "SyftLog" + __version__ = 
SYFT_OBJECT_VERSION_3 + + __repr_attrs__ = ["stdout", "stderr"] + __exclude_sync_diff_attrs__: list[str] = [] + __private_sync_attr_mocks__: ClassVar[dict[str, Any]] = { + "stderr": "", + "stdout": "", + } + + stdout: str = "" + stderr: str = "" + @serializable() class SyftLog(SyncableSyftObject): __canonical_name__ = "SyftLog" @@ -40,3 +58,13 @@ def get_sync_dependencies( self, context: AuthedServiceContext, **kwargs: dict ) -> list[UID]: # type: ignore return [self.job_id] + +@migrate(SyftLogV3, SyftLog) +def upgrade_syftlog() -> list[Callable]: + # TODO: FIX + return [make_set_default("job_id", UID())] + +@migrate(SyftLog, SyftLogV3) +def downgrade_syftlog() -> list[Callable]: + return [drop("job_id")] + diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index aedb59b2e24..3913c533514 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -7,6 +7,7 @@ from result import Err from result import Ok from result import Result +import traceback # relative from ...abstract_node import AbstractNode @@ -277,7 +278,7 @@ def init_notifier( return Ok("Notifier initialized successfully") except Exception as e: - raise Exception(f"Error initializing notifier. \n {e}") + raise Exception(f"Error initializing notifier. \n {traceback.format_exc()}") # This is not a public API. # This method is used by other services to dispatch notifications internally diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index b594be92775..c21c66b42f6 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -674,7 +674,9 @@ def _migrate_data( try: migrated_value = value.migrate_to(to_klass.__version__, context) except Exception: - return Err(f"Failed to migrate data to {to_klass} for qk: {key}") + import traceback + print(traceback.format_exc()) + return Err(f"Failed to migrate data to {to_klass} for qk {to_klass.__version__}: {key}") qk = self.settings.store_key.with_obj(key) result = self._update( credentials, diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index a290e4ff080..922ae642da9 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -50,6 +50,7 @@ if TYPE_CHECKING: # relative from ..service.sync.diff_state import AttrDiff + from .syft_object_registry import SyftObjectRegistry IntStr = int | str AbstractSetIntStr = Set[IntStr] @@ -138,90 +139,6 @@ class Context(SyftBaseObject): pass -class SyftObjectRegistry: - __object_version_registry__: dict[ - str, type["SyftObject"] | type["SyftObjectRegistry"] - ] = {} - __object_transform_registry__: dict[str, Callable] = {} - - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - if hasattr(cls, "__canonical_name__") and hasattr(cls, "__version__"): - mapping_string = f"{cls.__canonical_name__}_{cls.__version__}" - - if ( - mapping_string in cls.__object_version_registry__ - and not autoreload_enabled() - ): - current_cls = cls.__object_version_registry__[mapping_string] - if cls == current_cls: - # same class so noop - return None - - # user code is reinitialized which means it might have a new address - # in memory so for that we can just skip - if "syft.user" in cls.__module__: - # this happens every time we reload the user code 
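# The duplicate-registration guard being removed in this hunk reappears below in the
# new RegisteredSyftObject.__init_subclass__. A minimal sketch of the intended
# behaviour, with illustrative names (`registry`, `register`) that are not part of
# the real API:
registry: dict[str, type] = {}

def register(cls: type) -> None:
    key = f"{getattr(cls, '__canonical_name__', cls.__name__)}_{getattr(cls, '__version__', 1)}"
    existing = registry.get(key)
    if existing is None:
        registry[key] = cls  # first registration for this name/version pair
    elif existing is cls or "syft.user" in cls.__module__:
        pass  # same class, or reloaded user code: safe to skip
    else:
        raise Exception(f"Duplicate mapping for {key} and {cls}")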
- return None - else: - # this shouldn't happen and is usually a mistake of reusing the - # same __canonical_name__ and __version__ in two classes - raise Exception(f"Duplicate mapping for {mapping_string} and {cls}") - else: - # only if the cls has not been registered do we want to register it - cls.__object_version_registry__[mapping_string] = cls - - @classmethod - def versioned_class( - cls, name: str, version: int - ) -> type["SyftObject"] | type["SyftObjectRegistry"] | None: - mapping_string = f"{name}_{version}" - if mapping_string not in cls.__object_version_registry__: - return None - return cls.__object_version_registry__[mapping_string] - - @classmethod - def add_transform( - cls, - klass_from: str, - version_from: int, - klass_to: str, - version_to: int, - method: Callable, - ) -> None: - mapping_string = f"{klass_from}_{version_from}_x_{klass_to}_{version_to}" - cls.__object_transform_registry__[mapping_string] = method - - @classmethod - def get_transform( - cls, type_from: type["SyftObject"], type_to: type["SyftObject"] - ) -> Callable: - for type_from_mro in type_from.mro(): - if issubclass(type_from_mro, SyftObject): - klass_from = type_from_mro.__canonical_name__ - version_from = type_from_mro.__version__ - else: - klass_from = type_from_mro.__name__ - version_from = None - for type_to_mro in type_to.mro(): - if issubclass(type_to_mro, SyftBaseObject): - klass_to = type_to_mro.__canonical_name__ - version_to = type_to_mro.__version__ - else: - klass_to = type_to_mro.__name__ - version_to = None - - mapping_string = ( - f"{klass_from}_{version_from}_x_{klass_to}_{version_to}" - ) - if mapping_string in cls.__object_transform_registry__: - return cls.__object_transform_registry__[mapping_string] - raise Exception( - f"No mapping found for: {type_from} to {type_to} in" - f"the registry: {cls.__object_transform_registry__.keys()}" - ) - - class SyftMigrationRegistry: __migration_version_registry__: dict[str, dict[int, str]] = {} __migration_transform_registry__: dict[str, dict[str, Callable]] = {} @@ -378,7 +295,37 @@ def get_migration_for_version( ] -class SyftObject(SyftBaseObject, SyftObjectRegistry, SyftMigrationRegistry): +class RegisteredSyftObject(): + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + from .syft_object_registry import SyftObjectRegistry as reg + + if hasattr(reg, "__canonical_name__") and hasattr(reg, "__version__"): + mapping_string = f"{reg.__canonical_name__}_{reg.__version__}" + + if ( + mapping_string in reg.__object_version_registry__ + and not autoreload_enabled() + ): + current_cls = reg.__object_version_registry__[mapping_string] + if reg == current_cls: + # same class so noop + return None + + # user code is reinitialized which means it might have a new address + # in memory so for that we can just skip + if "syft.user" in reg.__module__: + # this happens every time we reload the user code + return None + else: + # this shouldn't happen and is usually a mistake of reusing the + # same __canonical_name__ and __version__ in two classes + raise Exception(f"Duplicate mapping for {mapping_string} and {reg}") + else: + # only if the cls has not been registered do we want to register it + reg.__object_version_registry__[mapping_string] = reg + +class SyftObject(SyftBaseObject, RegisteredSyftObject, SyftMigrationRegistry): __canonical_name__ = "SyftObject" __version__ = SYFT_OBJECT_VERSION_2 @@ -530,6 +477,7 @@ def __getitem__(self, key: str | int) -> Any: return self.__dict__.__getitem__(key) # type: ignore 
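# The function-local `from .syft_object_registry import SyftObjectRegistry` imports
# added below (in _upgrade_version, to(), and StorableObjectType.to) appear to be
# deferred on purpose: importing at call time avoids a circular import, because the
# new syft_object_registry module in turn imports SyftObject / SyftBaseObject from
# this file inside get_transform().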
def _upgrade_version(self, latest: bool = True) -> "SyftObject": + from .syft_object_registry import SyftObjectRegistry constructor = SyftObjectRegistry.versioned_class( name=self.__canonical_name__, version=self.__version__ + 1 ) @@ -544,6 +492,7 @@ def _upgrade_version(self, latest: bool = True) -> "SyftObject": # transform from one supported type to another def to(self, projection: type, context: Context | None = None) -> Any: + from .syft_object_registry import SyftObjectRegistry # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) @@ -769,6 +718,7 @@ def short_uid(uid: UID | None) -> str | None: class StorableObjectType: def to(self, projection: type, context: Context | None = None) -> Any: # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? + from .syft_object_registry import SyftObjectRegistry transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py new file mode 100644 index 00000000000..40d5f308120 --- /dev/null +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -0,0 +1,136 @@ +# stdlib +from collections.abc import Callable +from typing import Any +from typing import TYPE_CHECKING + +from syft.util.util import get_fully_qualified_name + +SYFT_086_PROTOCOL_VERSION = '4' + +# third party + +# relative +if TYPE_CHECKING: + # relative + from syft.types.syft_object import SyftObject + + + + +class SyftObjectRegistry: + __object_version_registry__: dict[ + str, type["SyftObject"] | type["SyftObjectRegistry"] + ] = {} + __object_transform_registry__: dict[str, Callable] = {} + __object_serialization_registry__: dict[tuple[str, str]: tuple] = {} + + @classmethod + def get_canonical_name(cls, obj): + res = getattr(obj, "__canonical_name__", None) + if res is not None: + return res + else: + fqn = get_fully_qualified_name(obj) + return fqn.split(".")[-1] + + @classmethod + def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tuple: + from syft.serde.recursive import TYPE_BANK + if canonical_name != "" and canonical_name is not None: + return cls.__object_serialization_registry__[canonical_name, version] + else: + # this is for backward compatibility with 0.8.6 + try: + from syft.protocol.data_protocol import get_data_protocol + serde_props = TYPE_BANK[fqn] + klass = serde_props[7] + is_syftobject = hasattr(klass, "__canonical_name__") + if is_syftobject: + canonical_name = klass.__canonical_name__ + dp = get_data_protocol() + try: + version_mutations = dp.protocol_history[SYFT_086_PROTOCOL_VERSION]["object_versions"][canonical_name] + except Exception: + print(f"could not find {canonical_name} in protocol history") + raise + + version_086 = max([int(k) for k, v in version_mutations.items() if v["action"] == "add"]) + try: + res = cls.__object_serialization_registry__[canonical_name, version_086] + except Exception: + print(f"could not find {canonical_name} {version_086} in ObjectRegistry") + raise + return res + except Exception as e: + print(e) + import ipdb + ipdb.set_trace() + else: + # TODO, add refactoring for non syftobject versions + canonical_name = fqn.split(".")[-1] + version = 1 + return cls.__object_serialization_registry__[canonical_name, version] + + + @classmethod + def has_serde_class(cls, fqn:str, canonical_name: str, version: str) -> tuple: + 
from syft.serde.recursive import TYPE_BANK + if canonical_name != "" and canonical_name is not None: + return (canonical_name, version) in cls.__object_serialization_registry__ + else: + # this is for backward compatibility with 0.8.6 + return fqn in TYPE_BANK + + + + @classmethod + def add_transform( + cls, + klass_from: str, + version_from: int, + klass_to: str, + version_to: int, + method: Callable, + ) -> None: + mapping_string = f"{klass_from}_{version_from}_x_{klass_to}_{version_to}" + cls.__object_transform_registry__[mapping_string] = method + + @classmethod + def get_transform( + cls, type_from: type["SyftObject"], type_to: type["SyftObject"] + ) -> Callable: + from .syft_object import SyftBaseObject, SyftObject + for type_from_mro in type_from.mro(): + if issubclass(type_from_mro, SyftObject): + klass_from = type_from_mro.__canonical_name__ + version_from = type_from_mro.__version__ + else: + klass_from = type_from_mro.__name__ + version_from = None + for type_to_mro in type_to.mro(): + if issubclass(type_to_mro, SyftBaseObject): + klass_to = type_to_mro.__canonical_name__ + version_to = type_to_mro.__version__ + else: + klass_to = type_to_mro.__name__ + version_to = None + + mapping_string = ( + f"{klass_from}_{version_from}_x_{klass_to}_{version_to}" + ) + if mapping_string in SyftObjectRegistry.__object_transform_registry__: + return SyftObjectRegistry.__object_transform_registry__[mapping_string] + raise Exception( + f"No mapping found for: {type_from} to {type_to} in" + f"the registry: {SyftObjectRegistry.__object_transform_registry__.keys()}" + ) + + @classmethod + def versioned_class( + cls, name: str, version: int + ) -> type["SyftObject"] | type["SyftObjectRegistry"] | None: + from .syft_object_registry import SyftObjectRegistry + mapping_string = f"{name}_{version}" + if mapping_string not in SyftObjectRegistry.__object_version_registry__: + return None + return SyftObjectRegistry.__object_version_registry__[mapping_string] diff --git a/packages/syft/src/syft/types/transforms.py b/packages/syft/src/syft/types/transforms.py index 4b85edac53f..e1e444b2b42 100644 --- a/packages/syft/src/syft/types/transforms.py +++ b/packages/syft/src/syft/types/transforms.py @@ -14,7 +14,7 @@ from .grid_url import GridURL from .syft_object import Context from .syft_object import SyftBaseObject -from .syft_object import SyftObjectRegistry +from .syft_object_registry import SyftObjectRegistry from .uid import UID diff --git a/packages/syft/src/syft/util/schema.py b/packages/syft/src/syft/util/schema.py index 8ab54cbdea2..43f3d70c266 100644 --- a/packages/syft/src/syft/util/schema.py +++ b/packages/syft/src/syft/util/schema.py @@ -147,6 +147,7 @@ def resolve_references(json_mappings: dict[str, dict]) -> dict[str, dict]: def generate_json_schemas(output_path: str | None = None) -> None: + # TODO: should we also replace this with the SyftObjectRegistry? 
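# One possible shape for that TODO, sketched under the assumption that the schema
# generator only needs the registered classes: iterate the registry keyed by
# (canonical_name, version) instead of TYPE_BANK. The helper name `process_registry`
# is hypothetical; the class sits at index 7 of the serde tuple, as diff_state() in
# data_protocol.py already relies on.
def process_registry(registry: dict) -> dict:
    mappings = {}
    for (canonical_name, version), serde_props in registry.items():
        cls = serde_props[7]  # class object stored in the serde attributes tuple
        mappings[f"{canonical_name}_{version}"] = cls
    return mappings
# Hypothetical usage, assuming SyftObjectRegistry were imported in this module:
# json_mappings = process_registry(SyftObjectRegistry.__object_serialization_registry__)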
json_mappings = process_type_bank(sy.serde.recursive.TYPE_BANK) json_mappings = resolve_references(json_mappings) if not output_path: diff --git a/packages/syft/tests/syft/transforms/transforms_test.py b/packages/syft/tests/syft/transforms/transforms_test.py index d6555dc8657..465e876b464 100644 --- a/packages/syft/tests/syft/transforms/transforms_test.py +++ b/packages/syft/tests/syft/transforms/transforms_test.py @@ -8,7 +8,7 @@ # syft absolute from syft.types import transforms from syft.types.syft_object import SyftBaseObject -from syft.types.syft_object import SyftObjectRegistry +from syft.types.syft_object_registry import SyftObjectRegistry from syft.types.transforms import TransformContext from syft.types.transforms import validate_klass_and_version From 59ecc5d108d9ffef06ae8db42e557badf3ddd8c7 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 23 May 2024 17:25:38 +0530 Subject: [PATCH 048/309] update method to expose port on rathole service --- packages/syft/setup.cfg | 1 - packages/syft/src/syft/custom_worker/k8s.py | 6 ++ .../syft/service/network/rathole_service.py | 65 +++++++++++++------ 3 files changed, 50 insertions(+), 22 deletions(-) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 6f69310c34e..a64a91fc049 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -91,7 +91,6 @@ data_science = opendp==0.9.2 evaluate==0.4.1 recordlinkage==0.16 - dm-haiku==0.0.10 torch[cpu]==2.2.1 dev = diff --git a/packages/syft/src/syft/custom_worker/k8s.py b/packages/syft/src/syft/custom_worker/k8s.py index 60557c86afb..d9702e72f3f 100644 --- a/packages/syft/src/syft/custom_worker/k8s.py +++ b/packages/syft/src/syft/custom_worker/k8s.py @@ -12,6 +12,7 @@ from kr8s.objects import ConfigMap from kr8s.objects import Pod from kr8s.objects import Secret +from kr8s.objects import Service from pydantic import BaseModel from typing_extensions import Self @@ -177,6 +178,11 @@ def get_configmap(client: kr8s.Api, name: str) -> ConfigMap | None: config_map = client.get("configmaps", name) return config_map[0] if config_map else None + @staticmethod + def get_service(client: kr8s.Api, name: str) -> Service | None: + service = client.get("services", name) + return service[0] if service else None + @staticmethod def update_configmap( config_map: ConfigMap, diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index a514d499d98..0894544de44 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -1,12 +1,15 @@ # stdlib import secrets +from typing import cast # third party +from kr8s.objects import Service import yaml # relative from ...custom_worker.k8s import KubeUtils from ...custom_worker.k8s import get_kr8s_client +from ...types.uid import UID from .node_peer import NodePeer from .rathole import RatholeConfig from .rathole import get_rathole_port @@ -34,8 +37,10 @@ def add_host_to_server(self, peer: NodePeer) -> None: random_port = self.get_random_port() + peer_id = cast(UID, peer.id) + config = RatholeConfig( - uuid=peer.id.to_string(), + uuid=peer_id.to_string(), secret_token=peer.rathole_token, local_addr_host="0.0.0.0", local_addr_port=random_port, @@ -47,6 +52,9 @@ def add_host_to_server(self, peer: NodePeer) -> None: client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) + if rathole_config_map is None: + raise Exception("Rathole config map not found.") + client_filename = RatholeServerToml.filename 
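# Both add_host_to_server and add_host_to_client follow the same read-modify-write
# cycle against the rathole ConfigMap. A hedged sketch of that shared pattern, using
# only helpers already available in this module; the wrapper name
# `update_rathole_toml` is illustrative, not real API:
def update_rathole_toml(k8rs_client, filename: str, mutate) -> None:
    config_map = KubeUtils.get_configmap(client=k8rs_client, name=RATHOLE_TOML_CONFIG_MAP)
    if config_map is None:
        raise Exception("Rathole config map not found.")
    data = config_map.data
    data[filename] = mutate(data[filename])  # e.g. add or remove a service entry in the TOML
    KubeUtils.update_configmap(config_map=config_map, patch={"data": data})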
toml_str = rathole_config_map.data[client_filename] @@ -90,6 +98,9 @@ def add_host_to_client( client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP ) + if rathole_config_map is None: + raise Exception("Rathole config map not found.") + client_filename = RatholeClientToml.filename toml_str = rathole_config_map.data[client_filename] @@ -114,6 +125,9 @@ def forward_port_to_proxy( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) + if rathole_proxy_config_map is None: + raise Exception("Rathole proxy config map not found.") + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] if not rathole_proxy: @@ -145,6 +159,9 @@ def add_dynamic_addr_to_rathole( self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP ) + if rathole_proxy_config_map is None: + raise Exception("Rathole proxy config map not found.") + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] if not rathole_proxy: @@ -169,36 +186,42 @@ def add_dynamic_addr_to_rathole( patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, ) - def add_entrypoint(self, port: int, peer_name: str) -> None: - """Add an entrypoint to the traefik config map.""" + self.expose_port_on_rathole_service(config.server_name, config.local_addr_port) - proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + def expose_port_on_rathole_service(self, port_name: str, port: int) -> None: + """Expose a port on the rathole service.""" - data = proxy_config_map.data + rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") - traefik_config_str = data["traefik.yml"] + rathole_service = cast(Service, rathole_service) - traefik_config = yaml.safe_load(traefik_config_str) + config = rathole_service.raw - traefik_config["entryPoints"][f"{peer_name}"] = {"address": f":{port}"} - - data["traefik.yml"] = yaml.safe_dump(traefik_config) - - KubeUtils.update_configmap(config_map=proxy_config_map, patch={"data": data}) + config["spec"]["ports"].append( + { + "name": port_name, + "port": port, + "targetPort": port, + "protocol": "TCP", + } + ) - def remove_endpoint(self, peer_name: str) -> None: - """Remove an entrypoint from the traefik config map.""" + rathole_service.patch(config) - proxy_config_map = KubeUtils.get_configmap(self.k8rs_client, PROXY_CONFIG_MAP) + def remove_port_on_rathole_service(self, port_name: str) -> None: + """Remove a port from the rathole service.""" - data = proxy_config_map.data + rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") - traefik_config_str = data["traefik.yml"] + rathole_service = cast(Service, rathole_service) - traefik_config = yaml.safe_load(traefik_config_str) + config = rathole_service.raw - del traefik_config["entryPoints"][f"{peer_name}"] + ports = config["spec"]["ports"] - data["traefik.yml"] = yaml.safe_dump(traefik_config) + for port in ports: + if port["name"] == port_name: + ports.remove(port) + break - KubeUtils.update_configmap(config_map=proxy_config_map, patch={"data": data}) + rathole_service.patch(config) From 7c540961a09ed5dbae5573270d3f735c6434b6fa Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 23 May 2024 22:25:43 +0530 Subject: [PATCH 049/309] fix proxy rule for dynamically added router rules for rathole - set RUST_LOG level to trace for debugging --- packages/grid/rathole/rathole.dockerfile | 2 +- packages/grid/rathole/start.sh | 4 ++-- packages/syft/src/syft/service/network/rathole_service.py | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/packages/grid/rathole/rathole.dockerfile 
b/packages/grid/rathole/rathole.dockerfile index 7166afd63dd..42d147527c7 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -13,7 +13,7 @@ RUN cargo build --locked --release --features ${FEATURES:-default} FROM python:${PYTHON_VERSION}-bookworm ARG RATHOLE_VERSION ENV MODE="client" -RUN apt update && apt install -y netcat-openbsd vim +RUN apt update && apt install -y netcat-openbsd vim rsync COPY --from=build /rathole/target/release/rathole /app/rathole WORKDIR /app diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index 4095f30f3aa..0e708908836 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -4,10 +4,10 @@ MODE=${MODE:-server} cp -L -r -f /conf/* conf/ if [[ $MODE == "server" ]]; then - /app/rathole conf/server.toml & + RUST_LOG=trace /app/rathole conf/server.toml & elif [[ $MODE = "client" ]]; then while true; do - /app/rathole conf/client.toml + RUST_LOG=trace /app/rathole conf/client.toml status=$? if [ $status -eq 0 ]; then break diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 0894544de44..2051035dd6a 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -175,8 +175,13 @@ def add_dynamic_addr_to_rathole( } } + proxy_rule = ( + f"Host(`{config.server_name}.syft.local`) || " + f"HostHeader(`{config.server_name}.syft.local`) && PathPrefix(`/`)" + ) + rathole_proxy["http"]["routers"][config.server_name] = { - "rule": f"Host(`{config.server_name}.syft.local`)", + "rule": proxy_rule, "service": config.server_name, "entryPoints": [entrypoint], } From 644ae705216b1fc1a15276b56bd73fb7dfc1e50b Mon Sep 17 00:00:00 2001 From: teo Date: Fri, 24 May 2024 09:45:57 +0300 Subject: [PATCH 050/309] added tox test --- tox.ini | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tox.ini b/tox.ini index f3763f23d0c..d3f9c28509f 100644 --- a/tox.ini +++ b/tox.ini @@ -37,6 +37,8 @@ envlist = seaweedfs.test.unit backend.test.basecpu e2e.test.notebook + migration.prepare + migraion.test skipsdist = True @@ -1234,3 +1236,33 @@ allowlist_externals = commands = bash -c 'ulimit -n 4096 || true' pytest --disable-warnings + +[testenv:migration.prepare] +description = Migration Test +deps = + syft + nbmake +allowlist_externals = + bash + pytest +; changedir +; setenv +commands = + bash -c 'python -c "import syft as sy; print(sy.__version__)"' + pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/0-prepare-migration-data.ipynb -vvvv + +[testenv:migration.test] +description = Migration Test +; setenv +deps = + -e{toxinidir}/packages/syft[dev] + nbmake +; changedir = {toxinidir}/packages/syft +allowlist_externals = + bash + tox + pytest +commands = + bash -c 'python -c "import syft as sy; print(sy.__version__)"' + tox -e migration.prepare + pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/1-connect-and-migrate.ipynb -vvvv From ff8e2dac3f18128f36f1097e90c44965d08c5a74 Mon Sep 17 00:00:00 2001 From: teo Date: Fri, 24 May 2024 09:46:10 +0300 Subject: [PATCH 051/309] added versions to missing classes --- .../syft/src/syft/service/action/action_graph.py | 1 + .../syft/src/syft/types/syft_object_registry.py | 14 +++++++------- .../syft/tests/syft/stores/store_mocks_test.py | 6 +++++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git 
a/packages/syft/src/syft/service/action/action_graph.py b/packages/syft/src/syft/service/action/action_graph.py index 3a928da9f0c..6e69bbbb29b 100644 --- a/packages/syft/src/syft/service/action/action_graph.py +++ b/packages/syft/src/syft/service/action/action_graph.py @@ -356,6 +356,7 @@ class ActionGraphStore: @serializable() class InMemoryActionGraphStore(ActionGraphStore): __canonical_name__ = "InMemoryActionGraphStore" + __version__ = SYFT_OBJECT_VERSION_2 def __init__(self, store_config: StoreConfig, reset: bool = False): self.store_config: StoreConfig = store_config diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 40d5f308120..86693741c85 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -61,15 +61,15 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tu print(f"could not find {canonical_name} {version_086} in ObjectRegistry") raise return res + else: + # TODO, add refactoring for non syftobject versions + canonical_name = fqn.split(".")[-1] + version = 1 + return cls.__object_serialization_registry__[canonical_name, version] except Exception as e: print(e) - import ipdb - ipdb.set_trace() - else: - # TODO, add refactoring for non syftobject versions - canonical_name = fqn.split(".")[-1] - version = 1 - return cls.__object_serialization_registry__[canonical_name, version] + # import ipdb + # ipdb.set_trace() @classmethod diff --git a/packages/syft/tests/syft/stores/store_mocks_test.py b/packages/syft/tests/syft/stores/store_mocks_test.py index 39aa2700829..1fa49a61e77 100644 --- a/packages/syft/tests/syft/stores/store_mocks_test.py +++ b/packages/syft/tests/syft/stores/store_mocks_test.py @@ -7,7 +7,7 @@ from syft.store.document_store import PartitionSettings from syft.store.document_store import StoreConfig from syft.store.kv_document_store import KeyValueBackingStore -from syft.types.syft_object import SyftObject +from syft.types.syft_object import SYFT_OBJECT_VERSION_2, SyftObject from syft.types.uid import UID @@ -47,23 +47,27 @@ def __getitem__(self, key: Any) -> Any: @serializable() class MockObjectType(SyftObject): __canonical_name__ = "mock_type" + __version__ = SYFT_OBJECT_VERSION_2 @serializable() class MockStore(DocumentStore): __canonical_name__ = "MockStore" + __version__ = SYFT_OBJECT_VERSION_2 pass @serializable() class MockSyftObject(SyftObject): __canonical_name__ = f"MockSyftObject_{UID()}" + __version__ = SYFT_OBJECT_VERSION_2 data: Any @serializable() class MockStoreConfig(StoreConfig): __canonical_name__ = "MockStoreConfig" + __version__ = SYFT_OBJECT_VERSION_2 store_type: type[DocumentStore] = MockStore db_name: str = "testing" backing_store: type[KeyValueBackingStore] = MockKeyValueBackingStore From d0f1591da3c5e739bf993da34dfa1bd6673e5e4a Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 24 May 2024 17:40:08 +0200 Subject: [PATCH 052/309] fix migrations from 0.8.6 --- .gitignore | 3 ++ .../src/syft/protocol/protocol_version.json | 5 -- packages/syft/src/syft/serde/recursive.py | 7 ++- .../syft/src/syft/service/job/job_service.py | 1 + .../src/syft/service/output/output_service.py | 46 ++++++++++++++++++- .../syft/src/syft/store/document_store.py | 2 +- .../syft/src/syft/store/kv_document_store.py | 35 +++++++++++--- .../src/syft/types/syft_object_registry.py | 16 ++++--- 8 files changed, 92 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 
ae0b09e4342..5d59ff4aa58 100644 --- a/.gitignore +++ b/.gitignore @@ -71,6 +71,9 @@ js/node_modules/* #nohup nohup.out +# jupyter lsp +.virtual_documents + # notebook data notebooks/helm/scenario_data.jsonl diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 9aff20c3113..4e16d05ecde 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -61,11 +61,6 @@ } }, "ExecutionOutput": { - "1": { - "version": 1, - "hash": "c2337099eba14767ead75fcc1b1fa265c1898461ede0b5e7758a0e8d11d1757d", - "action": "remove" - }, "2": { "version": 2, "hash": "854fe9df5bcbb5c7e5b7c467bac423cd98c32f93d6876fea7b8eb6c08f6596da", diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index c7f0fdcd02d..41ed13c63d3 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -229,8 +229,11 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild msg = recursive_scheme.new_message() # todo: rewrite and make sure every object has a canonical name and version - canonical_name = SyftObjectRegistry.get_canonical_name(self) - version = getattr(self, "__version__", 1) + canonical_name = SyftObjectRegistry.get_canonical_name(self, is_type=is_type) + if is_type: + version = 1 + else: + version = getattr(self, "__version__", 1) if not SyftObjectRegistry.has_serde_class("", canonical_name, version): # third party diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 323dff99ae9..01e79558b16 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -69,6 +69,7 @@ def get(self, context: AuthedServiceContext, uid: UID) -> Job | SyftError: @service_method( path="job.get_all", name="get_all", + roles=DATA_SCIENTIST_ROLE_LEVEL ) def get_all(self, context: AuthedServiceContext) -> list[Job] | SyftError: res = self.stash.get_all(context.credentials) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 3d32a3e622a..c552a753852 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -1,11 +1,13 @@ # stdlib -from typing import ClassVar +from typing import Callable, ClassVar # third party from pydantic import model_validator from result import Err from result import Ok from result import Result +from syft.types.syft_migration import migrate +from syft.types.transforms import drop, make_set_default # relative from ...client.api import APIRegistry @@ -18,7 +20,7 @@ from ...store.document_store import QueryKeys from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SYFT_OBJECT_VERSION_2 from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID from ...util.telemetry import instrument @@ -37,6 +39,36 @@ OutputPolicyIdPartitionKey = PartitionKey(key="output_policy_id", type_=UID) +@serializable() +class ExecutionOutputV1(SyncableSyftObject): + __canonical_name__ = "ExecutionOutput" + __version__ = SYFT_OBJECT_VERSION_1 + + executing_user_verify_key: SyftVerifyKey + user_code_link: LinkedObject + output_ids: list[UID] | dict[str, UID] | 
None = None + job_link: LinkedObject | None = None + created_at: DateTime = DateTime.now() + input_ids: dict[str, UID] | None = None + + # Required for __attr_searchable__, set by model_validator + user_code_id: UID + + # Output policy is not a linked object because its saved on the usercode + output_policy_id: UID | None = None + + __attr_searchable__: ClassVar[list[str]] = [ + "user_code_id", + "created_at", + "output_policy_id", + ] + __repr_attrs__: ClassVar[list[str]] = [ + "created_at", + "user_code_id", + "output_ids", + ] + + @serializable() class ExecutionOutput(SyncableSyftObject): __canonical_name__ = "ExecutionOutput" @@ -243,6 +275,16 @@ def get_by_output_policy_id( ) + +@migrate(ExecutionOutputV1, ExecutionOutput) +def upgrade_execution_output() -> list[Callable]: + return [make_set_default("job_id", None)] + +@migrate(ExecutionOutput, ExecutionOutputV1) +def downgrade_execution_output() -> list[Callable]: + return [drop("job_id")] + + @instrument @serializable() class OutputService(AbstractService): diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index fea96e6d456..43e3a7d7e29 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -208,7 +208,7 @@ def all(self) -> tuple[QueryKey, ...] | list[QueryKey]: def from_obj(partition_keys: PartitionKeys, obj: SyftObject) -> QueryKeys: qks = [] for partition_key in partition_keys.all: - pk_key = partition_key.key + pk_key = partition_key.key # name of the attribute pk_type = partition_key.type_ pk_value = getattr(obj, pk_key) # object has a method for getting these types diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index c21c66b42f6..bab9e2251e0 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -26,7 +26,7 @@ from ..service.response import SyftSuccess from ..types.syft_object import SyftObject from ..types.uid import UID -from .document_store import BaseStash +from .document_store import BaseStash, PartitionKeys from .document_store import PartitionKey from .document_store import QueryKey from .document_store import QueryKeys @@ -416,6 +416,7 @@ def _update( obj: SyftObject, has_permission: bool = False, overwrite: bool = False, + allow_missing_keys=False, ) -> Result[SyftObject, str]: try: if qk.value not in self.data: @@ -428,9 +429,20 @@ def _update( _original_unique_keys = self.settings.unique_keys.with_obj( _original_obj ) - _original_searchable_keys = self.settings.searchable_keys.with_obj( - _original_obj - ) + if allow_missing_keys: + searchable_keys = PartitionKeys( + pks=[ + x + for x in self.settings.searchable_keys.all + if hasattr(_original_obj, x.key) + ] + ) + _original_searchable_keys = searchable_keys.with_obj(_original_obj) + + else: + _original_searchable_keys = self.settings.searchable_keys.with_obj( + _original_obj + ) store_query_key = self.settings.store_key.with_obj(_original_obj) @@ -470,6 +482,11 @@ def _update( return Err(f"Failed to update obj {obj}, you have no permission") except Exception as e: + import ipdb + + ipdb.set_trace() + import traceback + print(traceback.format_exc()) return Err(f"Failed to update obj {obj} with error: {e}") def _get_all_from_store( @@ -641,9 +658,9 @@ def _set_data_and_keys( ck_col[pk_value] = store_query_key.value self.unique_keys[pk_key] = ck_col - self.unique_keys[store_query_key.key][store_query_key.value] = ( + 
self.unique_keys[store_query_key.key][ store_query_key.value - ) + ] = store_query_key.value sqks = searchable_query_keys.all for qk in sqks: @@ -675,8 +692,11 @@ def _migrate_data( migrated_value = value.migrate_to(to_klass.__version__, context) except Exception: import traceback + print(traceback.format_exc()) - return Err(f"Failed to migrate data to {to_klass} for qk {to_klass.__version__}: {key}") + return Err( + f"Failed to migrate data to {to_klass} for qk {to_klass.__version__}: {key}" + ) qk = self.settings.store_key.with_obj(key) result = self._update( credentials, @@ -684,6 +704,7 @@ def _migrate_data( obj=migrated_value, has_permission=has_permission, overwrite=True, + allow_missing_keys=True ) if result.is_err(): diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 40d5f308120..0b89bf89267 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -25,7 +25,11 @@ class SyftObjectRegistry: __object_serialization_registry__: dict[tuple[str, str]: tuple] = {} @classmethod - def get_canonical_name(cls, obj): + def get_canonical_name(cls, obj, is_type: bool): + if is_type: + # TODO: this is different for builtin types, make more generic + return "ModelMetaclass" + res = getattr(obj, "__canonical_name__", None) if res is not None: return res @@ -61,15 +65,15 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tu print(f"could not find {canonical_name} {version_086} in ObjectRegistry") raise return res + else: + # TODO, add refactoring for non syftobject versions + canonical_name = fqn.split(".")[-1] + version = 1 + return cls.__object_serialization_registry__[canonical_name, version] except Exception as e: print(e) import ipdb ipdb.set_trace() - else: - # TODO, add refactoring for non syftobject versions - canonical_name = fqn.split(".")[-1] - version = 1 - return cls.__object_serialization_registry__[canonical_name, version] @classmethod From 57b35b0b0f3d7741c32684d402330ada9218f63d Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 27 May 2024 16:09:30 +0530 Subject: [PATCH 053/309] fix lint --- packages/syft/src/syft/service/network/network_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 447713a6cff..e3a1c827a67 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -287,7 +287,7 @@ def exchange_credentials_with( self.rathole_service.add_host_to_client( peer_name=self_node_peer.name, - peer_id=self_node_peer.id.to_string(), + peer_id=str(self_node_peer.id), rathole_token=self_node_peer.rathole_token, remote_addr=remote_addr, ) From 1fbd8beb9f0c64adefc4a16f123ee395c47f780f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 28 May 2024 18:30:06 +0530 Subject: [PATCH 054/309] move rathole token to http routes - update get and post to patch host headers if rathole token is present - fix stream endpoint to work with rathole --- packages/syft/src/syft/client/client.py | 66 +++++++++++++++---- packages/syft/src/syft/node/routes.py | 18 ++--- .../service/network/association_request.py | 4 +- .../syft/service/network/network_service.py | 37 +++++------ .../src/syft/service/network/node_peer.py | 1 - .../syft/service/network/rathole_service.py | 4 +- 
.../syft/src/syft/service/network/routes.py | 2 + packages/syft/src/syft/util/util.py | 5 ++ 8 files changed, 90 insertions(+), 47 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index ba4dfc38c80..858822771dc 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -4,6 +4,7 @@ # stdlib import base64 from collections.abc import Callable +from collections.abc import Iterator from copy import deepcopy from enum import Enum from getpass import getpass @@ -48,9 +49,11 @@ from ..service.user.user_service import UserService from ..types.grid_url import GridURL from ..types.syft_object import SYFT_OBJECT_VERSION_2 +from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.uid import UID from ..util.logger import debug from ..util.telemetry import instrument +from ..util.util import generate_token from ..util.util import prompt_warning_message from ..util.util import thread_ident from ..util.util import verify_tls @@ -68,11 +71,6 @@ from ..service.network.node_peer import NodePeer -# use to enable mitm proxy -# from syft.grid.connections.http_connection import HTTPConnection -# HTTPConnection.proxies = {"http": "http://127.0.0.1:8080"} - - def upgrade_tls(url: GridURL, response: Response) -> GridURL: try: if response.url.startswith("https://") and url.protocol == "http": @@ -117,6 +115,7 @@ def forward_message_to_proxy( API_PATH = "/api/v2" DEFAULT_PYGRID_PORT = 80 DEFAULT_PYGRID_ADDRESS = f"http://localhost:{DEFAULT_PYGRID_PORT}" +INTERNAL_PROXY_URL = "http://proxy:80" class Routes(Enum): @@ -129,15 +128,16 @@ class Routes(Enum): STREAM = f"{API_PATH}/stream" -@serializable(attrs=["proxy_target_uid", "url"]) +@serializable(attrs=["proxy_target_uid", "url", "rathole_token"]) class HTTPConnection(NodeConnection): __canonical_name__ = "HTTPConnection" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 url: GridURL proxy_target_uid: UID | None = None routes: type[Routes] = Routes session_cache: Session | None = None + rathole_token: str | None = None @field_validator("url", mode="before") @classmethod @@ -149,7 +149,11 @@ def make_url(cls, v: Any) -> Any: ) def with_proxy(self, proxy_target_uid: UID) -> Self: - return HTTPConnection(url=self.url, proxy_target_uid=proxy_target_uid) + return HTTPConnection( + url=self.url, + proxy_target_uid=proxy_target_uid, + rathole_token=self.rathole_token, + ) def stream_via(self, proxy_uid: UID, url_path: str) -> GridURL: # Update the presigned url path to @@ -182,10 +186,24 @@ def session(self) -> Session: self.session_cache = session return self.session_cache - def _make_get(self, path: str, params: dict | None = None) -> bytes: - url = self.url.with_path(path) + def _make_get( + self, path: str, params: dict | None = None, stream: bool = False + ) -> bytes | Iterator[Any]: + headers = {} + url = self.url + + if self.rathole_token: + url = GridURL.from_url(INTERNAL_PROXY_URL) + headers = {"Host": self.host_or_ip} + + url = url.with_path(path) response = self.session.get( - str(url), verify=verify_tls(), proxies={}, params=params + str(url), + verify=verify_tls(), + proxies={}, + params=params, + headers=headers, + stream=stream, ) if response.status_code != 200: raise requests.ConnectionError( @@ -195,6 +213,9 @@ def _make_get(self, path: str, params: dict | None = None) -> bytes: # upgrade to tls if available self.url = upgrade_tls(self.url, response) + if stream: + return response.iter_content(chunk_size=None) + return 
response.content def _make_post( @@ -203,9 +224,21 @@ def _make_post( json: dict[str, Any] | None = None, data: bytes | None = None, ) -> bytes: - url = self.url.with_path(path) + headers = {} + url = self.url + + if self.rathole_token: + url = GridURL.from_url(INTERNAL_PROXY_URL) + headers = {"Host": self.host_or_ip} + + url = url.with_path(path) response = self.session.post( - str(url), verify=verify_tls(), json=json, proxies={}, data=data + str(url), + verify=verify_tls(), + json=json, + proxies={}, + data=data, + headers=headers, ) if response.status_code != 200: raise requests.ConnectionError( @@ -683,7 +716,10 @@ def guest(self) -> Self: ) def exchange_route( - self, client: Self, protocol: SyftProtocol = SyftProtocol.HTTP + self, + client: Self, + protocol: SyftProtocol = SyftProtocol.HTTP, + reverse_tunnel: bool = False, ) -> SyftSuccess | SyftError: # relative from ..service.network.routes import connection_to_route @@ -694,6 +730,8 @@ def exchange_route( if client.metadata is None: return SyftError(f"client {client}'s metadata is None!") + self_node_route.rathole_token = generate_token() if reverse_tunnel else None + return self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 5b25774ff18..f32aa5c3bd9 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -18,6 +18,7 @@ # relative from ..abstract_node import AbstractNode +from ..client.connection import NodeConnection from ..protocol.data_protocol import PROTOCOL_TYPE from ..serde.deserialize import _deserialize as deserialize from ..serde.serialize import _serialize as serialize @@ -50,7 +51,7 @@ def make_routes(worker: Worker) -> APIRouter: async def get_body(request: Request) -> bytes: return await request.body() - def _blob_url(peer_uid: UID, presigned_url: str) -> str: + def _get_node_connection(peer_uid: UID) -> NodeConnection: # relative from ..service.network.node_peer import route_to_connection @@ -58,9 +59,7 @@ def _blob_url(peer_uid: UID, presigned_url: str) -> str: peer = network_service.stash.get_by_uid(worker.verify_key, peer_uid).ok() peer_node_route = peer.pick_highest_priority_route() connection = route_to_connection(route=peer_node_route) - url = connection.to_blob_route(presigned_url) - - return str(url) + return connection @router.get("/stream/{peer_uid}/{url_path}/", name="stream") async def stream(peer_uid: str, url_path: str) -> StreamingResponse: @@ -71,17 +70,14 @@ async def stream(peer_uid: str, url_path: str) -> StreamingResponse: peer_uid_parsed = UID.from_string(peer_uid) - url = _blob_url(peer_uid=peer_uid_parsed, presigned_url=url_path_parsed) - try: - resp = requests.get(url=url, stream=True) # nosec - resp.raise_for_status() + peer_connection = _get_node_connection(peer_uid_parsed) + url = peer_connection.to_blob_route(url_path_parsed) + stream_response = peer_connection._make_get(url.path, stream=True) except requests.RequestException: raise HTTPException(404, "Failed to retrieve data from domain.") - return StreamingResponse( - resp.iter_content(chunk_size=None), media_type="text/event-stream" - ) + return StreamingResponse(stream_response, media_type="text/event-stream") @router.get( "/", diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index e963aec638e..ede16865685 100644 --- 
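# --- Editorial annotation (illustrative sketch, not part of the patch) ---------
# What the rathole_token branch in _make_get/_make_post above amounts to: when a
# route carries a rathole token the client stops dialling the peer directly and
# instead sends the request to the in-cluster proxy (INTERNAL_PROXY_URL), keeping
# the peer's hostname only in the Host header so the proxy can route it into the
# reverse tunnel. Roughly equivalent, using plain requests (the hostname below is
# hypothetical; the /api/v2/api_call path is assumed from Routes.ROUTE_API_CALL):

import requests

response = requests.post(
    "http://proxy:80/api/v2/api_call",         # INTERNAL_PROXY_URL + API call route
    data=b"<serialized SignedSyftAPICall>",    # payload is unchanged by the tunnel
    headers={"Host": "my-domain.syft.local"},  # hypothetical tunnelled peer host
)
# --------------------------------------------------------------------------------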
a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -44,7 +44,9 @@ def _run( service_ctx = context.to_service_ctx() - if self.remote_peer.rathole_token is None: + highest_route = self.remote_peer.pick_highest_priority_route() + + if highest_route.rathole_token is None: try: remote_client: SyftClient = self.remote_peer.client_with_context( context=service_ctx diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index e3a1c827a67..ec5d1c361b5 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -1,7 +1,6 @@ # stdlib from collections.abc import Callable from enum import Enum -from hashlib import sha256 import secrets from typing import Any @@ -186,9 +185,6 @@ def exchange_credentials_with( ) remote_node_peer = NodePeer.from_client(remote_client) - rathole_token = self._generate_token() - self_node_peer.rathole_token = rathole_token - # ask the remote client to add this node (represented by `self_node_peer`) as a peer # check locally if the remote node already exists as a peer existing_peer_result = self.stash.get_by_uid( @@ -278,19 +274,20 @@ def exchange_credentials_with( if result.is_err(): return SyftError(message="Failed to update route information.") - remote_url = GridURL( - host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port - ) - rathole_remote_addr = remote_url.as_container_host() + if self_node_peer.rathole_token: + remote_url = GridURL( + host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port + ) + rathole_remote_addr = remote_url.as_container_host() - remote_addr = rathole_remote_addr.url_no_protocol + remote_addr = rathole_remote_addr.url_no_protocol - self.rathole_service.add_host_to_client( - peer_name=self_node_peer.name, - peer_id=str(self_node_peer.id), - rathole_token=self_node_peer.rathole_token, - remote_addr=remote_addr, - ) + self.rathole_service.add_host_to_client( + peer_name=self_node_peer.name, + peer_id=str(self_node_peer.id), + rathole_token=self_node_peer.rathole_token, + remote_addr=remote_addr, + ) return ( SyftSuccess(message="Routes Exchanged") @@ -298,9 +295,6 @@ def exchange_credentials_with( else remote_res ) - def _generate_token(self) -> str: - return sha256(secrets.token_bytes(16)).hexdigest() - @service_method(path="network.add_peer", name="add_peer", roles=GUEST_ROLE_LEVEL) def add_peer( self, @@ -940,6 +934,7 @@ def from_grid_url(context: TransformContext) -> TransformContext: context.output["private"] = False context.output["proxy_target_uid"] = context.obj.proxy_target_uid context.output["priority"] = 1 + context.output["rathole_token"] = context.obj.rathole_token return context @@ -976,7 +971,11 @@ def node_route_to_http_connection( url = GridURL( protocol=obj.protocol, host_or_ip=obj.host_or_ip, port=obj.port ).as_container_host() - return HTTPConnection(url=url, proxy_target_uid=obj.proxy_target_uid) + return HTTPConnection( + url=url, + proxy_target_uid=obj.proxy_target_uid, + rathole_token=obj.rathole_token, + ) @transform(NodeMetadataV3, NodePeer) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index a6cb6d3eea3..35292dd89dd 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -85,7 +85,6 @@ class NodePeer(SyftObject): 
ping_status: NodePeerConnectionStatus | None = None ping_status_message: str | None = None pinged_timestamp: DateTime | None = None - rathole_token: str | None = None def existed_route( self, route: NodeRouteType | None = None, route_id: UID | None = None diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 2051035dd6a..67d4f78eddd 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -35,13 +35,15 @@ def add_host_to_server(self, peer: NodePeer) -> None: None """ + rathole_route = peer.pick_highest_priority_route() + random_port = self.get_random_port() peer_id = cast(UID, peer.id) config = RatholeConfig( uuid=peer_id.to_string(), - secret_token=peer.rathole_token, + secret_token=rathole_route.rathole_token, local_addr_host="0.0.0.0", local_addr_port=random_port, server_name=peer.name, diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index f3fa9b1ad1a..9e49f82c689 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -95,6 +95,7 @@ class HTTPNodeRoute(SyftObject, NodeRoute): port: int = 80 proxy_target_uid: UID | None = None priority: int = 1 + rathole_token: str | None = None def __eq__(self, other: Any) -> bool: if not isinstance(other, HTTPNodeRoute): @@ -107,6 +108,7 @@ def __hash__(self) -> int: + hash(self.port) + hash(self.protocol) + hash(self.proxy_target_uid) + + hash(self.rathole_token) ) def __str__(self) -> str: diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index b0affa2b1a0..a93c2ba8fe1 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -21,6 +21,7 @@ import platform import random import re +import secrets from secrets import randbelow import socket import sys @@ -919,3 +920,7 @@ def get_queue_address(port: int) -> str: def get_dev_mode() -> bool: return str_to_bool(os.getenv("DEV_MODE", "False")) + + +def generate_token() -> str: + return hashlib.sha256(secrets.token_bytes(16)).hexdigest() From 9c7983cbcacfb4352d0fe0f479825079c2d00b74 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 29 May 2024 00:49:25 +0530 Subject: [PATCH 055/309] fix passing host name in case of rathole connection - pass reverse tunnel via connect_to_gateway - add method to get rathole route --- packages/syft/src/syft/client/client.py | 21 +++++++++++----- .../syft/src/syft/client/domain_client.py | 7 +++++- .../src/syft/protocol/protocol_version.json | 24 +++++++++++++++++++ .../service/network/association_request.py | 4 ++-- .../syft/service/network/network_service.py | 12 ++++++++-- .../src/syft/service/network/node_peer.py | 6 +++++ .../syft/service/network/rathole_service.py | 2 +- .../syft/src/syft/service/network/routes.py | 3 ++- 8 files changed, 66 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 858822771dc..28e975efa03 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -53,7 +53,6 @@ from ..types.uid import UID from ..util.logger import debug from ..util.telemetry import instrument -from ..util.util import generate_token from ..util.util import prompt_warning_message from ..util.util import thread_ident from ..util.util import verify_tls @@ -194,7 +193,7 @@ def _make_get( if 
self.rathole_token: url = GridURL.from_url(INTERNAL_PROXY_URL) - headers = {"Host": self.host_or_ip} + headers = {"Host": self.url.host_or_ip} url = url.with_path(path) response = self.session.get( @@ -229,7 +228,7 @@ def _make_post( if self.rathole_token: url = GridURL.from_url(INTERNAL_PROXY_URL) - headers = {"Host": self.host_or_ip} + headers = {"Host": self.url.host_or_ip} url = url.with_path(path) response = self.session.post( @@ -336,9 +335,20 @@ def register(self, new_user: UserCreate) -> SyftSigningKey: def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError: msg_bytes: bytes = _serialize(obj=signed_call, to_bytes=True) + + headers = {} + + if self.rathole_token: + api_url = GridURL.from_url(INTERNAL_PROXY_URL) + api_url = api_url.with_path(self.routes.ROUTE_API_CALL.value) + headers = {"Host": self.url.host_or_ip} + else: + api_url = self.api_url + response = requests.post( # nosec - url=str(self.api_url), + url=api_url, data=msg_bytes, + headers=headers, ) if response.status_code != 200: @@ -730,12 +740,11 @@ def exchange_route( if client.metadata is None: return SyftError(f"client {client}'s metadata is None!") - self_node_route.rathole_token = generate_token() if reverse_tunnel else None - return self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, remote_node_verify_key=client.metadata.to(NodeMetadataV3).verify_key, + reverse_tunnel=reverse_tunnel, ) else: raise ValueError( diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 8f1e7cb9dc8..6428fd7b851 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -288,6 +288,7 @@ def connect_to_gateway( email: str | None = None, password: str | None = None, protocol: str | SyftProtocol = SyftProtocol.HTTP, + reverse_tunnel: bool = False, ) -> SyftSuccess | SyftError | None: if isinstance(protocol, str): protocol = SyftProtocol(protocol) @@ -305,7 +306,11 @@ def connect_to_gateway( if isinstance(client, SyftError): return client - res = self.exchange_route(client, protocol=protocol) + res = self.exchange_route( + client, + protocol=protocol, + reverse_tunnel=reverse_tunnel, + ) if isinstance(res, SyftSuccess): if self.metadata: return SyftSuccess( diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index b4050ab030f..04e7b652bcb 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -198,6 +198,30 @@ "hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995", "action": "add" } + }, + "HTTPConnection": { + "2": { + "version": 2, + "hash": "68409295f8916ceb22a8cf4abf89f5e4bcff0d75dc37e16ede37250ada28df59", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "5e363abe2875beec89a3f4f4f5c53e15f9893fb98e5da71e2fa6c0f619883b1f", + "action": "add" + } + }, + "HTTPNodeRoute": { + "2": { + "version": 2, + "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "89ace8067c392b802fe23a99446a8ae464a9dad0b49d8b2c3871b631451acec4", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index ede16865685..2590b3b42fe 100644 --- a/packages/syft/src/syft/service/network/association_request.py 
+++ b/packages/syft/src/syft/service/network/association_request.py @@ -44,9 +44,9 @@ def _run( service_ctx = context.to_service_ctx() - highest_route = self.remote_peer.pick_highest_priority_route() + rathole_route = self.remote_peer.get_rathole_route() - if highest_route.rathole_token is None: + if rathole_route.rathole_token is None: try: remote_client: SyftClient = self.remote_peer.client_with_context( context=service_ctx diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index ec5d1c361b5..de560747ac6 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -30,6 +30,7 @@ from ...types.transforms import transform_method from ...types.uid import UID from ...util.telemetry import instrument +from ...util.util import generate_token from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey @@ -166,6 +167,7 @@ def exchange_credentials_with( self_node_route: NodeRoute, remote_node_route: NodeRoute, remote_node_verify_key: SyftVerifyKey, + reverse_tunnel: bool = False, ) -> Request | SyftSuccess | SyftError: """ Exchange Route With Another Node. If there is a pending association request, return it @@ -174,6 +176,11 @@ def exchange_credentials_with( # Step 1: Validate the Route self_node_peer = self_node_route.validate_with_context(context=context) + if reverse_tunnel: + _rathole_route = self_node_peer.node_routes[-1] + _rathole_route.rathole_token = generate_token() + _rathole_route.host_or_ip = f"{self_node_peer.name}.syft.local" + if isinstance(self_node_peer, SyftError): return self_node_peer @@ -274,7 +281,8 @@ def exchange_credentials_with( if result.is_err(): return SyftError(message="Failed to update route information.") - if self_node_peer.rathole_token: + if reverse_tunnel: + rathole_route = self_node_peer.get_rathole_route() remote_url = GridURL( host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port ) @@ -285,7 +293,7 @@ def exchange_credentials_with( self.rathole_service.add_host_to_client( peer_name=self_node_peer.name, peer_id=str(self_node_peer.id), - rathole_token=self_node_peer.rathole_token, + rathole_token=rathole_route.rathole_token, remote_addr=remote_addr, ) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 35292dd89dd..e6ac045e9f1 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -269,6 +269,12 @@ def pick_highest_priority_route(self) -> NodeRoute: highest_priority_route = route return highest_priority_route + def get_rathole_route(self) -> NodeRoute | None: + for route in self.node_routes: + if hasattr(route, "rathole_token") and route.rathole_token: + return route + return None + def delete_route( self, route: NodeRouteType | None = None, route_id: UID | None = None ) -> SyftError | None: diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 67d4f78eddd..ad5e783c7dd 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -35,7 +35,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: None """ - rathole_route = peer.pick_highest_priority_route() + rathole_route = peer.get_rathole_route() 
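# --- Editorial annotation (illustrative sketch, not part of the patch) ---------
# End-to-end, the reverse-tunnel flow added in this series is driven from the
# domain side. A hedged usage sketch (URLs, ports and credentials are hypothetical;
# the connect_to_gateway/exchange_route signatures are the ones introduced above):

import syft as sy

domain_client = sy.login(
    url="http://localhost:8081", email="info@openmined.org", password="changethis"
)
gateway_client = sy.login_as_guest(url="http://localhost:8080")

# reverse_tunnel=True makes exchange_credentials_with mint a rathole token for the
# domain's route and rewrite its advertised host to "<node-name>.syft.local", so
# gateway -> domain traffic goes through the gateway proxy and the rathole tunnel
# rather than a direct connection to the domain.
domain_client.connect_to_gateway(gateway_client, reverse_tunnel=True)
# --------------------------------------------------------------------------------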
random_port = self.get_random_port() diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 9e49f82c689..1d9ec116467 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -19,6 +19,7 @@ from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.uid import UID @@ -87,7 +88,7 @@ def validate_with_context( @serializable() class HTTPNodeRoute(SyftObject, NodeRoute): __canonical_name__ = "HTTPNodeRoute" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 host_or_ip: str private: bool = False From 2b7172e20a9025bdf240feb02bb0a54e03d1698d Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Sun, 26 May 2024 22:35:23 +0200 Subject: [PATCH 056/309] add widget --- .../syft/assets/css/tabulator_pysyft.min.css | 1655 ++++++++++++++++- .../syft/src/syft/assets/jinja/table.jinja2 | 6 +- packages/syft/src/syft/assets/js/table.js | 76 +- .../syft/src/syft/service/sync/diff_state.py | 7 + .../src/syft/service/sync/resolve_widget.py | 160 ++ .../components/tabulator_template.py | 21 +- 6 files changed, 1913 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/assets/css/tabulator_pysyft.min.css b/packages/syft/src/syft/assets/css/tabulator_pysyft.min.css index f474df40562..fde1ee7edc8 100644 --- a/packages/syft/src/syft/assets/css/tabulator_pysyft.min.css +++ b/packages/syft/src/syft/assets/css/tabulator_pysyft.min.css @@ -1,6 +1,1651 @@ -:root{--tabulator-background-color:#fff;--tabulator-border-color:rgba(0,0,0,.12);--tabulator-text-size:16px;--tabulator-header-background-color:#f5f5f5;--tabulator-header-text-color:#555;--tabulator-header-border-color:rgba(0,0,0,.12);--tabulator-header-separator-color:rgba(0,0,0,.12);--tabulator-header-margin:4px;--tabulator-sort-arrow-hover:#555;--tabulator-sort-arrow-active:#666;--tabulator-sort-arrow-inactive:#bbb;--tabulator-column-resize-guide-color:#999;--tabulator-row-background-color:#fff;--tabulator-row-alt-background-color:#f8f8f8;--tabulator-row-border-color:rgba(0,0,0,.12);--tabulator-row-text-color:#333;--tabulator-row-hover-background:#e1f5fe;--tabulator-row-selected-background:#17161d;--tabulator-row-selected-background-hover:#17161d;--tabulator-edit-box-color:#17161d;--tabulator-error-color:#d00;--tabulator-footer-background-color:transparent;--tabulator-footer-text-color:#555;--tabulator-footer-border-color:rgba(0,0,0,.12);--tabulator-footer-separator-color:rgba(0,0,0,.12);--tabulator-footer-active-color:#17161d;--tabulator-spreadsheet-active-tab-color:#fff;--tabulator-range-border-color:#17161d;--tabulator-range-handle-color:#17161d;--tabulator-range-header-selected-background:var( - --tabulator-range-border-color - 
);--tabulator-range-header-selected-text-color:#fff;--tabulator-range-header-highlight-background:colors-gray-timberwolf;--tabulator-range-header-text-highlight-background:#fff;--tabulator-pagination-button-background:#fff;--tabulator-pagination-button-background-hover:#06c;--tabulator-pagination-button-color:#999;--tabulator-pagination-button-color-hover:#fff;--tabulator-pagination-button-color-active:#000;--tabulator-cell-padding:15px}body.vscode-dark,body[data-jp-theme-light=false]{--tabulator-background-color:#080808;--tabulator-border-color:#666;--tabulator-text-size:16px;--tabulator-header-background-color:#212121;--tabulator-header-text-color:#555;--tabulator-header-border-color:#666;--tabulator-header-separator-color:#666;--tabulator-header-margin:4px;--tabulator-sort-arrow-hover:#fff;--tabulator-sort-arrow-active:#e6e6e6;--tabulator-sort-arrow-inactive:#666;--tabulator-column-resize-guide-color:#999;--tabulator-row-background-color:#080808;--tabulator-row-alt-background-color:#212121;--tabulator-row-border-color:#666;--tabulator-row-text-color:#f8f8f8;--tabulator-row-hover-background:#333;--tabulator-row-selected-background:#241e1e;--tabulator-row-selected-background-hover:#333;--tabulator-edit-box-color:#333;--tabulator-error-color:#d00;--tabulator-footer-background-color:transparent;--tabulator-footer-text-color:#555;--tabulator-footer-border-color:rgba(0,0,0,.12);--tabulator-footer-separator-color:rgba(0,0,0,.12);--tabulator-footer-active-color:#17161d;--tabulator-spreadsheet-active-tab-color:#fff;--tabulator-range-border-color:#17161d;--tabulator-range-handle-color:var(--tabulator-range-border-color);--tabulator-range-header-selected-background:var( - --tabulator-range-border-color - );--tabulator-range-header-selected-text-color:#fff;--tabulator-range-header-highlight-background:#d6d6d6;--tabulator-range-header-text-highlight-background:#fff;--tabulator-pagination-button-background:#212121;--tabulator-pagination-button-background-hover:#555;--tabulator-pagination-button-color:#999;--tabulator-pagination-button-color-hover:#fff;--tabulator-pagination-button-color-active:#fff;--tabulator-cell-padding:15px}.tabulator{border:1px solid var(--tabulator-border-color);font-size:var(--tabulator-text-size);overflow:hidden;position:relative;text-align:left;-webkit-transform:translateZ(0);-moz-transform:translateZ(0);-ms-transform:translateZ(0);-o-transform:translateZ(0);transform:translateZ(0)}.tabulator[tabulator-layout=fitDataFill] .tabulator-tableholder .tabulator-table{min-width:100%}.tabulator[tabulator-layout=fitDataTable]{display:inline-block}.tabulator.tabulator-block-select,.tabulator.tabulator-ranges .tabulator-cell:not(.tabulator-editing){user-select:none}.tabulator .tabulator-header{background-color:var(--tabulator-header-background-color);border-bottom:1px solid var(--tabulator-header-separator-color);box-sizing:border-box;color:var(--tabulator-header-text-color);font-weight:700;outline:none;overflow:hidden;position:relative;-moz-user-select:none;-khtml-user-select:none;-webkit-user-select:none;-o-user-select:none;white-space:nowrap;width:100%}.tabulator .tabulator-header.tabulator-header-hidden{display:none}.tabulator .tabulator-header .tabulator-header-contents{overflow:hidden;position:relative}.tabulator .tabulator-header .tabulator-header-contents .tabulator-headers{display:inline-block}.tabulator .tabulator-header .tabulator-col{background:var(--tabulator-header-background-color);border-right:1px solid 
var(--tabulator-header-border-color);box-sizing:border-box;display:inline-flex;flex-direction:column;justify-content:flex-start;overflow:hidden;position:relative;text-align:left;vertical-align:bottom}.tabulator .tabulator-header .tabulator-col.tabulator-moving{background:hsl(var(--tabulator-header-background-color),calc(var(--tabulator-header-background-color) - 5%))!important;border:1px solid var(--tabulator-header-separator-color);pointer-events:none;position:absolute}.tabulator .tabulator-header .tabulator-col.tabulator-range-highlight{background-color:var(--tabulator-range-header-highlight-background);color:var(--tabulator-range-header-text-highlight-background)}.tabulator .tabulator-header .tabulator-col.tabulator-range-selected{background-color:var(--tabulator-range-header-selected-background);color:var(--tabulator-range-header-selected-text-color)}.tabulator .tabulator-header .tabulator-col .tabulator-col-content{box-sizing:border-box;padding:4px;position:relative}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-header-popup-button{padding:0 8px}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-header-popup-button:hover{cursor:pointer;opacity:.6}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title-holder{position:relative}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title{box-sizing:border-box;overflow:hidden;text-overflow:ellipsis;vertical-align:bottom;white-space:nowrap;width:100%}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title.tabulator-col-title-wrap{text-overflow:clip;white-space:normal}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title .tabulator-title-editor{background:#fff;border:1px solid #999;box-sizing:border-box;padding:1px;width:100%}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title .tabulator-header-popup-button+.tabulator-title-editor{width:calc(100% - 22px)}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter{align-items:center;bottom:0;display:flex;position:absolute;right:4px;top:0}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter .tabulator-arrow{border-bottom:6px solid var(--tabulator-sort-arrow-inactive);border-left:6px solid transparent;border-right:6px solid transparent;height:0;width:0}.tabulator .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols{border-top:1px solid var(--tabulator-header-border-color);display:flex;margin-right:-1px;overflow:hidden;position:relative}.tabulator .tabulator-header .tabulator-col .tabulator-header-filter{box-sizing:border-box;margin-top:2px;position:relative;text-align:center;width:100%}.tabulator .tabulator-header .tabulator-col .tabulator-header-filter textarea{height:auto!important}.tabulator .tabulator-header .tabulator-col .tabulator-header-filter svg{margin-top:3px}.tabulator .tabulator-header .tabulator-col .tabulator-header-filter input::-ms-clear{height:0;width:0}.tabulator .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title{padding-right:25px}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-header .tabulator-col.tabulator-sortable.tabulator-col-sorter-element:hover{background-color:hsl(var(--tabulator-header-background-color),calc(var(--tabulator-header-background-color) - 10%))!important;cursor:pointer}}.tabulator .tabulator-header 
.tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter{color:var(--tabulator-sort-arrow-inactive)}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover{border-bottom:6px solid var(--tabulator-sort-arrow-hover);cursor:pointer}}.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow{border-bottom:6px solid var(--tabulator-sort-arrow-inactive);border-top:none}.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter{color:var(--tabulator-sort-arrow-active)}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover{border-bottom:6px solid var(--tabulator-sort-arrow-hover);cursor:pointer}}.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow{border-bottom:6px solid var(--tabulator-sort-arrow-active);border-top:none}.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter{color:var(--tabulator-sort-arrow-active)}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover{border-top:6px solid var(--tabulator-sort-arrow-hover);cursor:pointer}}.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow{border-bottom:none;border-top:6px solid var(--tabulator-sort-arrow-active);color:var(--tabulator-sort-arrow-active)}.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical .tabulator-col-content .tabulator-col-title{align-items:center;display:flex;justify-content:center;text-orientation:mixed;writing-mode:vertical-rl}.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-col-vertical-flip .tabulator-col-title{transform:rotate(180deg)}.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable .tabulator-col-title{padding-right:0;padding-top:20px}.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable.tabulator-col-vertical-flip .tabulator-col-title{padding-bottom:20px;padding-right:0}.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable .tabulator-col-sorter{bottom:auto;justify-content:center;left:0;right:0;top:4px}.tabulator .tabulator-header .tabulator-frozen{left:0;position:sticky;z-index:11}.tabulator .tabulator-header .tabulator-frozen.tabulator-frozen-left{border-right:2px solid var(--tabulator-row-border-color)}.tabulator .tabulator-header .tabulator-frozen.tabulator-frozen-right{border-left:2px solid var(--tabulator-row-border-color)}.tabulator .tabulator-header .tabulator-calcs-holder{border-bottom:1px solid var(--tabulator-header-border-color);border-top:1px solid var(--tabulator-row-border-color);box-sizing:border-box;display:inline-block}.tabulator .tabulator-header .tabulator-calcs-holder,.tabulator .tabulator-header .tabulator-calcs-holder 
.tabulator-row{background:hsl(var(--tabulator-header-background-color),calc(var(--tabulator-header-background-color) + 5%))!important}.tabulator .tabulator-header .tabulator-calcs-holder .tabulator-row .tabulator-col-resize-handle{display:none}.tabulator .tabulator-header .tabulator-frozen-rows-holder{display:inline-block}.tabulator .tabulator-tableholder{-webkit-overflow-scrolling:touch;overflow:auto;position:relative;white-space:nowrap;width:100%}.tabulator .tabulator-tableholder:focus{outline:none}.tabulator .tabulator-tableholder .tabulator-placeholder{align-items:center;box-sizing:border-box;display:flex;justify-content:center;min-width:100%;width:100%}.tabulator .tabulator-tableholder .tabulator-placeholder[tabulator-render-mode=virtual]{min-height:100%}.tabulator .tabulator-tableholder .tabulator-placeholder .tabulator-placeholder-contents{color:#ccc;display:inline-block;font-size:20px;font-weight:700;padding:10px;text-align:center;white-space:normal}.tabulator .tabulator-tableholder .tabulator-table{background-color:var(--tabulator-row-background-color);color:var(--tabulator-row-text-color);display:inline-block;overflow:visible;position:relative;white-space:nowrap}.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs{background:hsl(var(--tabulator-row-atl-background-color),calc(var(--tabulator-row-alt-background-color) - 5%))!important;font-weight:700}.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs.tabulator-calcs-top{border-bottom:2px solid var(--tabulator-row-border-color)}.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs.tabulator-calcs-bottom{border-top:2px solid var(--tabulator-row-border-color)}.tabulator .tabulator-tableholder .tabulator-range-overlay{inset:0;pointer-events:none;position:absolute;z-index:10}.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range{border:1px solid var(--tabulator-range-border-color);box-sizing:border-box;position:absolute}.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range.tabulator-range-active:after{background-color:var(--tabulator-range-handle-color);border-radius:999px;bottom:-3px;content:"";height:6px;position:absolute;right:-3px;width:6px}.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range-cell-active{border:2px solid var(--tabulator-range-border-color);box-sizing:border-box;position:absolute}.tabulator .tabulator-footer{border-top:1px solid var(--tabulator-footer-separator-color);color:var(--tabulator-footer-text-color);font-weight:700;user-select:none;-moz-user-select:none;-khtml-user-select:none;-webkit-user-select:none;-o-user-select:none;white-space:nowrap}.tabulator .tabulator-footer .tabulator-footer-contents{align-items:center;display:flex;flex-direction:row;justify-content:space-between;padding:5px 10px}.tabulator .tabulator-footer .tabulator-footer-contents:empty{display:none}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs{margin-top:-5px;overflow-x:auto}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab{border:1px solid var(--tabulator-border-color);border-bottom-left-radius:5px;border-bottom-right-radius:5px;border-top:none;display:inline-block;font-size:.9em;padding:5px}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab:hover{cursor:pointer;opacity:.7}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs 
.tabulator-spreadsheet-tab.tabulator-spreadsheet-tab-active{background:var(--tabulator-spreadsheet-active-tab-color)}.tabulator .tabulator-footer .tabulator-calcs-holder{border-bottom:1px solid var(--tabulator-row-border-color);border-top:1px solid var(--tabulator-row-border-color);box-sizing:border-box;overflow:hidden;text-align:left;width:100%}.tabulator .tabulator-footer .tabulator-calcs-holder .tabulator-row{display:inline-block}.tabulator .tabulator-footer .tabulator-calcs-holder .tabulator-row .tabulator-col-resize-handle{display:none}.tabulator .tabulator-footer .tabulator-calcs-holder:only-child{border-bottom:none;margin-bottom:-5px}.tabulator .tabulator-footer>*+.tabulator-page-counter{margin-left:10px}.tabulator .tabulator-footer .tabulator-page-counter{font-weight:400}.tabulator .tabulator-footer .tabulator-paginator{color:var(--tabulator-footer-text-color);flex:1;font-family:inherit;font-size:inherit;font-weight:inherit;text-align:right}.tabulator .tabulator-footer .tabulator-page-size{border:1px solid var(--tabulator-footer-border-color);border-radius:3px;display:inline-block;margin:0 5px;padding:2px 5px}.tabulator .tabulator-footer .tabulator-pages{margin:0 7px}.tabulator .tabulator-footer .tabulator-page{background:hsla(0,0%,100%,.2);border:1px solid var(--tabulator-footer-border-color);border-radius:3px;display:inline-block;margin:0 2px;padding:2px 5px}.tabulator .tabulator-footer .tabulator-page.active{color:var(--tabulator-footer-active-color)}.tabulator .tabulator-footer .tabulator-page:disabled{opacity:.5}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-footer .tabulator-page:not(disabled):hover{background:rgba(0,0,0,.2);color:#fff;cursor:pointer}}.tabulator .tabulator-col-resize-handle{display:inline-block;margin-left:-3px;margin-right:-3px;position:relative;vertical-align:middle;width:6px;z-index:11}@media (hover:hover) and (pointer:fine){.tabulator .tabulator-col-resize-handle:hover{cursor:ew-resize}}.tabulator .tabulator-col-resize-handle:last-of-type{margin-right:0;width:3px}.tabulator .tabulator-col-resize-guide{height:100%;margin-left:-.5px;top:0;width:4px}.tabulator .tabulator-col-resize-guide,.tabulator .tabulator-row-resize-guide{background-color:var(--tabulator-column-resize-guide-color);opacity:.5;position:absolute}.tabulator .tabulator-row-resize-guide{height:4px;left:0;margin-top:-.5px;width:100%}.tabulator .tabulator-alert{align-items:center;background:rgba(0,0,0,.4);display:flex;height:100%;left:0;position:absolute;text-align:center;top:0;width:100%;z-index:100}.tabulator .tabulator-alert .tabulator-alert-msg{background:#fff;border-radius:10px;display:inline-block;font-size:16px;font-weight:700;margin:0 auto;padding:10px 20px}.tabulator .tabulator-alert .tabulator-alert-msg.tabulator-alert-state-msg{border:4px solid #333;color:#000}.tabulator .tabulator-alert .tabulator-alert-msg.tabulator-alert-state-error{border:4px solid #d00;color:#590000}.tabulator-row{background-color:var(--tabulator-row-background-color);box-sizing:border-box;min-height:calc(var(--tabulator-text-size) + var(--tabulator-header-margin)*2);position:relative}.tabulator-row.tabulator-row-even{background-color:var(--tabulator-row-alt-background-color)}@media (hover:hover) and (pointer:fine){.tabulator-row.tabulator-selectable:hover{background-color:var(--tabulator-row-hover-background);cursor:pointer}}.tabulator-row.tabulator-selected{background-color:var(--tabulator-row-selected-background)}@media (hover:hover) and 
(pointer:fine){.tabulator-row.tabulator-selected:hover{background-color:var(--tabulator-row-selected-background-hover);cursor:pointer}}.tabulator-row.tabulator-row-moving{background:#fff;border:1px solid #000}.tabulator-row.tabulator-moving{border-bottom:1px solid var(--tabulator-row-border-color);border-top:1px solid var(--tabulator-row-border-color);pointer-events:none;position:absolute;z-index:15}.tabulator-row.tabulator-range-highlight .tabulator-cell.tabulator-range-row-header{background-color:var(--tabulator-range-header-highlight-background);color:var(--tabulator-range-header-text-highlight-background)}.tabulator-row.tabulator-range-highlight.tabulator-range-selected .tabulator-cell.tabulator-range-row-header,.tabulator-row.tabulator-range-selected .tabulator-cell.tabulator-range-row-header{background-color:var(--tabulator-range-header-selected-background);color:var(--tabulator-range-header-selected-text-color)}.tabulator-row .tabulator-row-resize-handle{bottom:0;height:5px;left:0;position:absolute;right:0}.tabulator-row .tabulator-row-resize-handle.prev{bottom:auto;top:0}@media (hover:hover) and (pointer:fine){.tabulator-row .tabulator-row-resize-handle:hover{cursor:ns-resize}}.tabulator-row .tabulator-responsive-collapse{border-bottom:1px solid var(--tabulator-row-border-color);border-top:1px solid var(--tabulator-row-border-color);box-sizing:border-box;padding:5px}.tabulator-row .tabulator-responsive-collapse:empty{display:none}.tabulator-row .tabulator-responsive-collapse table{font-size:var(--tabulator-text-size)}.tabulator-row .tabulator-responsive-collapse table tr td{position:relative}.tabulator-row .tabulator-responsive-collapse table tr td:first-of-type{padding-right:10px}.tabulator-row .tabulator-cell{border-right:1px solid var(--tabulator-row-border-color);box-sizing:border-box;display:inline-block;outline:none;overflow:hidden;padding:4px;position:relative;text-overflow:ellipsis;vertical-align:middle;white-space:nowrap}.tabulator-row .tabulator-cell.tabulator-row-header{border-bottom:1px solid var(--tabulator-row-border-color)}.tabulator-row .tabulator-cell.tabulator-frozen{background-color:inherit;display:inline-block;left:0;position:sticky;z-index:11}.tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-left{border-right:2px solid var(--tabulator-row-border-color)}.tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-right{border-left:2px solid var(--tabulator-row-border-color)}.tabulator-row .tabulator-cell.tabulator-editing{border:1px solid var(--tabulator-edit-box-color);outline:none;padding:0}.tabulator-row .tabulator-cell.tabulator-editing input,.tabulator-row .tabulator-cell.tabulator-editing select{background:transparent;border:1px;outline:none}.tabulator-row .tabulator-cell.tabulator-validation-fail{border:1px solid var(--tabulator-error-color)}.tabulator-row .tabulator-cell.tabulator-validation-fail input,.tabulator-row .tabulator-cell.tabulator-validation-fail select{background:transparent;border:1px;color:var(--tabulator-error-color)}.tabulator-row .tabulator-cell.tabulator-row-handle{align-items:center;display:inline-flex;justify-content:center;-moz-user-select:none;-khtml-user-select:none;-webkit-user-select:none;-o-user-select:none}.tabulator-row .tabulator-cell.tabulator-row-handle .tabulator-row-handle-box{width:80%}.tabulator-row .tabulator-cell.tabulator-row-handle .tabulator-row-handle-box .tabulator-row-handle-bar{background:#666;height:3px;margin-top:2px;width:100%}.tabulator-row 
.tabulator-cell.tabulator-range-selected:not(.tabulator-range-only-cell-selected):not(.tabulator-range-row-header){background-color:var(--tabulator-row-selected-background)}.tabulator-row .tabulator-cell .tabulator-data-tree-branch-empty{display:inline-block;width:7px}.tabulator-row .tabulator-cell .tabulator-data-tree-branch{border-bottom:2px solid var(--tabulator-row-border-color);border-bottom-left-radius:1px;border-left:2px solid var(--tabulator-row-border-color);display:inline-block;height:9px;margin-right:5px;margin-top:-9px;vertical-align:middle;width:7px}.tabulator-row .tabulator-cell .tabulator-data-tree-control{align-items:center;background:rgba(0,0,0,.1);border:1px solid var(--tabulator-row-text-color);border-radius:2px;display:inline-flex;height:11px;justify-content:center;margin-right:5px;overflow:hidden;vertical-align:middle;width:11px}@media (hover:hover) and (pointer:fine){.tabulator-row .tabulator-cell .tabulator-data-tree-control:hover{background:rgba(0,0,0,.2);cursor:pointer}}.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse{background:transparent;display:inline-block;height:7px;position:relative;width:1px}.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after{background:var(--tabulator-row-text-color);content:"";height:1px;left:-3px;position:absolute;top:3px;width:7px}.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand{background:var(--tabulator-row-text-color);display:inline-block;height:7px;position:relative;width:1px}.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand:after{background:var(--tabulator-row-text-color);content:"";height:1px;left:-3px;position:absolute;top:3px;width:7px}.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle{align-items:center;background:#666;border-radius:20px;color:var(--tabulator-row-background-color);display:inline-flex;font-size:1.1em;font-weight:700;height:15px;justify-content:center;-moz-user-select:none;-khtml-user-select:none;-webkit-user-select:none;-o-user-select:none;width:15px}@media (hover:hover) and (pointer:fine){.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle:hover{cursor:pointer;opacity:.7}}.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle.open .tabulator-responsive-collapse-toggle-close{display:initial}.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle.open .tabulator-responsive-collapse-toggle-open{display:none}.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle svg{stroke:var(--tabulator-row-background-color)}.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle .tabulator-responsive-collapse-toggle-close{display:none}.tabulator-row .tabulator-cell .tabulator-traffic-light{border-radius:14px;display:inline-block;height:14px;width:14px}.tabulator-row.tabulator-group{background:#ccc;border-bottom:1px solid #999;border-right:1px solid var(--tabulator-row-border-color);border-top:1px solid #999;box-sizing:border-box;font-weight:700;min-width:100%;padding:5px 5px 5px 10px}@media (hover:hover) and (pointer:fine){.tabulator-row.tabulator-group:hover{background-color:rgba(0,0,0,.1);cursor:pointer}}.tabulator-row.tabulator-group.tabulator-group-visible .tabulator-arrow{border-bottom:0;border-left:6px solid transparent;border-right:6px solid transparent;border-top:6px solid 
var(--tabulator-sort-arrow-active);margin-right:10px}.tabulator-row.tabulator-group.tabulator-group-level-1{padding-left:30px}.tabulator-row.tabulator-group.tabulator-group-level-2{padding-left:50px}.tabulator-row.tabulator-group.tabulator-group-level-3{padding-left:70px}.tabulator-row.tabulator-group.tabulator-group-level-4{padding-left:90px}.tabulator-row.tabulator-group.tabulator-group-level-5{padding-left:110px}.tabulator-row.tabulator-group .tabulator-group-toggle{display:inline-block}.tabulator-row.tabulator-group .tabulator-arrow{border-bottom:6px solid transparent;border-left:6px solid var(--tabulator-sort-arrow-active);border-right:0;border-top:6px solid transparent;display:inline-block;height:0;margin-right:16px;vertical-align:middle;width:0}.tabulator-row.tabulator-group span{color:#d00}.tabulator-toggle{background:#dcdcdc;border:1px solid #ccc;box-sizing:border-box;display:flex;flex-direction:row}.tabulator-toggle.tabulator-toggle-on{background:#1c6cc2}.tabulator-toggle .tabulator-toggle-switch{background:#fff;border:1px solid #ccc;box-sizing:border-box}.tabulator-popup-container{-webkit-overflow-scrolling:touch;background:var(--tabulator-row-background-color);border:1px solid var(--tabulator-row-border-color);box-shadow:0 0 5px 0 rgba(0,0,0,.2);box-sizing:border-box;display:inline-block;font-size:var(--tabulator-text-size);overflow-y:auto;position:absolute;z-index:10000}.tabulator-popup{border-radius:3px;padding:5px}.tabulator-tooltip{border-radius:2px;box-shadow:none;font-size:12px;max-width:min(500px,100%);padding:3px 5px;pointer-events:none}.tabulator-menu .tabulator-menu-item{box-sizing:border-box;padding:5px 10px;position:relative;user-select:none}.tabulator-menu .tabulator-menu-item.tabulator-menu-item-disabled{opacity:.5}@media (hover:hover) and (pointer:fine){.tabulator-menu .tabulator-menu-item:not(.tabulator-menu-item-disabled):hover{background:var(--tabulator-row-alt-background-color);cursor:pointer}}.tabulator-menu .tabulator-menu-item.tabulator-menu-item-submenu{padding-right:25px}.tabulator-menu .tabulator-menu-item.tabulator-menu-item-submenu:after{border-color:var(--tabulator-row-border-color);border-style:solid;border-width:1px 1px 0 0;content:"";display:inline-block;height:7px;position:absolute;right:10px;top:calc(5px + .4em);transform:rotate(45deg);vertical-align:top;width:7px}.tabulator-menu .tabulator-menu-separator{border-top:1px solid var(--tabulator-row-border-color)}.tabulator-edit-list{-webkit-overflow-scrolling:touch;font-size:var(--tabulator-text-size);max-height:200px;overflow-y:auto}.tabulator-edit-list .tabulator-edit-list-item{color:var(--tabulator-row-text-color);outline:none;padding:4px}.tabulator-edit-list .tabulator-edit-list-item.active{background:var(--tabulator-edit-box-color);color:var(--tabulator-row-background-color)}.tabulator-edit-list .tabulator-edit-list-item.active.focused{outline:1px solid rgba(var(--tabulator-row-background-color),.5)}.tabulator-edit-list .tabulator-edit-list-item.focused{outline:1px solid var(--tabulator-edit-box-color)}@media (hover:hover) and (pointer:fine){.tabulator-edit-list .tabulator-edit-list-item:hover{background:var(--tabulator-edit-box-color);color:var(--tabulator-row-background-color);cursor:pointer}}.tabulator-edit-list .tabulator-edit-list-placeholder{color:var(--tabulator-row-text-color);padding:4px;text-align:center}.tabulator-edit-list .tabulator-edit-list-group{border-bottom:1px solid var(--tabulator-row-border-color);color:var(--tabulator-row-text-color);font-weight:700;padding:6px 4px 
4px}.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-2,.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-2{padding-left:12px}.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-3,.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-3{padding-left:20px}.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-4,.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-4{padding-left:28px}.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-5,.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-5{padding-left:36px}.tabulator.tabulator-ltr{direction:ltr}.tabulator.tabulator-rtl{direction:rtl;text-align:initial}.tabulator.tabulator-rtl .tabulator-header .tabulator-col{border-left:1px solid var(--tabulator-header-border-color);border-right:initial;text-align:initial}.tabulator.tabulator-rtl .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols{margin-left:-1px;margin-right:0}.tabulator.tabulator-rtl .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title{padding-left:25px;padding-right:0}.tabulator.tabulator-rtl .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter{left:8px;right:auto}.tabulator.tabulator-rtl .tabulator-tableholder .tabulator-range-overlay .tabulator-range.tabulator-range-active:after{background-color:var(--tabulator-range-handle-color);border-radius:999px;bottom:-3px;content:"";height:6px;left:-3px;position:absolute;right:auto;width:6px}.tabulator.tabulator-rtl .tabulator-row .tabulator-cell{border-left:1px solid var(--tabulator-row-border-color);border-right:initial}.tabulator.tabulator-rtl .tabulator-row .tabulator-cell .tabulator-data-tree-branch{border-bottom-left-radius:0;border-bottom-right-radius:1px;border-left:initial;border-right:2px solid var(--tabulator-row-border-color);margin-left:5px;margin-right:0}.tabulator.tabulator-rtl .tabulator-row .tabulator-cell .tabulator-data-tree-control{margin-left:5px;margin-right:0}.tabulator.tabulator-rtl .tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-left{border-left:2px solid var(--tabulator-row-border-color)}.tabulator.tabulator-rtl .tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-right{border-right:2px solid var(--tabulator-row-border-color)}.tabulator.tabulator-rtl .tabulator-row .tabulator-col-resize-handle:last-of-type{margin-left:0;margin-right:-3px;width:3px}.tabulator.tabulator-rtl .tabulator-footer .tabulator-calcs-holder{text-align:initial}.tabulator-print-fullscreen{bottom:0;left:0;position:absolute;right:0;top:0;z-index:10000}body.tabulator-print-fullscreen-hide>:not(.tabulator-print-fullscreen){display:none!important}.tabulator-print-table{border-collapse:collapse}.tabulator-print-table .tabulator-data-tree-branch{border-bottom:2px solid var(--tabulator-row-border-color);border-bottom-left-radius:1px;border-left:2px solid var(--tabulator-row-border-color);display:inline-block;height:9px;margin-right:5px;margin-top:-9px;vertical-align:middle;width:7px}.tabulator-print-table .tabulator-print-table-group{background:#ccc;border-bottom:1px solid #999;border-right:1px solid var(--tabulator-row-border-color);border-top:1px solid #999;box-sizing:border-box;font-weight:700;min-width:100%;padding:5px 5px 5px 10px}@media (hover:hover) and (pointer:fine){.tabulator-print-table 
.tabulator-print-table-group:hover{background-color:rgba(0,0,0,.1);cursor:pointer}}.tabulator-print-table .tabulator-print-table-group.tabulator-group-visible .tabulator-arrow{border-bottom:0;border-left:6px solid transparent;border-right:6px solid transparent;border-top:6px solid var(--tabulator-sort-arrow-active);margin-right:10px}.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-1 td{padding-left:30px!important}.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-2 td{padding-left:50px!important}.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-3 td{padding-left:70px!important}.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-4 td{padding-left:90px!important}.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-5 td{padding-left:110px!important}.tabulator-print-table .tabulator-print-table-group .tabulator-group-toggle{display:inline-block}.tabulator-print-table .tabulator-print-table-group .tabulator-arrow{border-bottom:6px solid transparent;border-left:6px solid var(--tabulator-sort-arrow-active);border-right:0;border-top:6px solid transparent;display:inline-block;height:0;margin-right:16px;vertical-align:middle;width:0}.tabulator-print-table .tabulator-print-table-group span{color:#d00}.tabulator-print-table .tabulator-data-tree-control{align-items:center;background:rgba(0,0,0,.1);border:1px solid var(--tabulator-row-text-color);border-radius:2px;display:inline-flex;height:11px;justify-content:center;margin-right:5px;overflow:hidden;vertical-align:middle;width:11px}@media (hover:hover) and (pointer:fine){.tabulator-print-table .tabulator-data-tree-control:hover{background:rgba(0,0,0,.2);cursor:pointer}}.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse{background:transparent;display:inline-block;height:7px;position:relative;width:1px}.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after{background:var(--tabulator-row-text-color);content:"";height:1px;left:-3px;position:absolute;top:3px;width:7px}.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand{background:var(--tabulator-row-text-color);display:inline-block;height:7px;position:relative;width:1px}.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand:after{background:var(--tabulator-row-text-color);content:"";height:1px;left:-3px;position:absolute;top:3px;width:7px}.tabulator{background-color:var(--tabulator-background-color);max-width:100%;width:100%}.tabulator .tabulator-header{color:inherit}.tabulator .tabulator-header .tabulator-col{border-top:none}.tabulator .tabulator-header .tabulator-col:first-of-type{border-left:none}.tabulator .tabulator-header .tabulator-col:last-of-type{border-right:none}.tabulator .tabulator-header .tabulator-col:not(first-of-type),.tabulator .tabulator-header .tabulator-col:not(last-of-type){border-right:1px solid var(--tabulator-header-border-color)}.tabulator .tabulator-header .tabulator-col .tabulator-col-content{padding:var(--tabulator-cell-padding)}.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter{right:-10px}.tabulator .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols{border-top:1px solid var(--tabulator-border-color)}.tabulator .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title{padding-right:10px}.tabulator .tabulator-header 
.tabulator-calcs-holder{border-bottom:1px solid var(--tabulator-header-separator-color);width:100%}.tabulator .tabulator-header .tabulator-frozen-rows-holder{min-width:600%}.tabulator .tabulator-header .tabulator-frozen-rows-holder:empty{display:none}.tabulator .tabulator-header .tabulator-frozen .tabulator-frozen-left,.tabulator .tabulator-header .tabulator-frozen .tabulator-frozen-right{background:inherit}.tabulator .tabulator-tableholder .tabulator-table{color:inherit}.tabulator .tabulator-footer{background-color:var(--tabulator-footer-background-color);color:inherit}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab{font-weight:400;padding:8px 12px}.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab.tabulator-spreadsheet-tab-active{color:var(--tabulator-footer-active-color)}.tabulator .tabulator-footer .tabulator-paginator{color:inherit}.tabulator .tabulator-footer .tabulator-page{background:var(--tabulator-pagination-button-background);border-radius:0;border-right:none;color:var(--tabulator-pagination-button-color);margin:5px 0 0;padding:8px 12px}.tabulator .tabulator-footer .tabulator-page:first-of-type,.tabulator .tabulator-footer .tabulator-page[data-page=next]{border-bottom-left-radius:4px;border-top-left-radius:4px}.tabulator .tabulator-footer .tabulator-page:last-of-type,.tabulator .tabulator-footer .tabulator-page[data-page=prev]{border:1px solid var(--tabulator-footer-border-color);border-bottom-right-radius:4px;border-top-right-radius:4px}.tabulator .tabulator-footer .tabulator-page:not(disabled):hover{background:var(--tabulator-pagination-button-background-hover);color:var(--tabulator-pagination-button-color-hover)}.tabulator .tabulator-footer .tabulator-page.active,.tabulator .tabulator-footer .tabulator-page[data-page=first] :not(disabled):not(:hover),.tabulator .tabulator-footer .tabulator-page[data-page=last] :not(disabled):not(:hover),.tabulator .tabulator-footer .tabulator-page[data-page=next] :not(disabled):not(:hover),.tabulator .tabulator-footer .tabulator-page[data-page=prev] :not(disabled):not(:hover){color:var(--tabulator-pagination-button-color-active)}.tabulator.striped .tabulator-row:nth-child(2n){background-color:var(--tabulator-row-alt-background-color)}.tabulator.striped .tabulator-row:nth-child(2n).tabulator-selected{background-color:var(--tabulator-row-selected-background)!important}@media (hover:hover) and (pointer:fine){.tabulator.striped .tabulator-row:nth-child(2n).tabulator-selectable:hover{background-color:var(--tabulator-row-hover-background);cursor:pointer}.tabulator.striped .tabulator-row:nth-child(2n).tabulator-selected:hover{background-color:var(--tabulator-row-selected-background-hover)!important;cursor:pointer}}.tabulator-row{border-bottom:1px solid var(--tabulator-row-border-color);min-height:calc(var(--tabulator-text-size) + var(--tabulator-cell-padding)*2)}.tabulator-row.tabulator-row-even{background-color:var(--tabulator-row-background-color)}.tabulator-row .tabulator-cell{padding:var(--tabulator-cell-padding)}.tabulator-row .tabulator-cell:last-of-type{border-right:none}.tabulator-row .tabulator-cell.tabulator-row-header{background:var(--tabulator-header-background-color);border-bottom:none;border-right:1px solid var(--tabulator-border-color)}.tabulator-row .tabulator-cell .tabulator-data-tree-control{border:1px solid #ccc}.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after,.tabulator-row .tabulator-cell 
.tabulator-data-tree-control .tabulator-data-tree-control-expand,.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand:after{background:#ccc}.tabulator-row.tabulator-group{background:#fafafa}.tabulator-row.tabulator-group span{color:#666;margin-left:10px}.tabulator-edit-select-list{background:var(--tabulator-header-background-color)}.tabulator-edit-select-list .tabulator-edit-select-list-item{color:inherit}.tabulator-edit-select-list .tabulator-edit-select-list-item.active{color:var(--tabulator-header-background-color)}.tabulator-edit-select-list .tabulator-edit-select-list-item.active.focused{outline:1px solid rgba(var(--tabulator-header-background-color),.5)}@media (hover:hover) and (pointer:fine){.tabulator-edit-select-list .tabulator-edit-select-list-item:hover{color:var(--tabulator-header-background-color)}}.tabulator-edit-select-list .tabulator-edit-select-list-group,.tabulator-edit-select-list .tabulator-edit-select-list-notice{color:inherit}.tabulator.tabulator-rtl .tabulator-header .tabulator-col{border-left:none;border-right:none}.tabulator-print-table .tabulator-print-table-group{background:#fafafa}.tabulator-print-table .tabulator-print-table-group span{color:#666;margin-left:10px}.tabulator-print-table .tabulator-data-tree-control{border:1px solid #ccc}.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after,.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand,.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand:after{background:#ccc} +:root { + --tabulator-background-color: #fff; + --tabulator-border-color: rgba(0, 0, 0, .12); + --tabulator-text-size: 16px; + --tabulator-header-background-color: #f5f5f5; + --tabulator-header-text-color: #555; + --tabulator-header-border-color: rgba(0, 0, 0, .12); + --tabulator-header-separator-color: rgba(0, 0, 0, .12); + --tabulator-header-margin: 4px; + --tabulator-sort-arrow-hover: #555; + --tabulator-sort-arrow-active: #666; + --tabulator-sort-arrow-inactive: #bbb; + --tabulator-column-resize-guide-color: #999; + --tabulator-row-background-color: #fff; + --tabulator-row-alt-background-color: #f8f8f8; + --tabulator-row-border-color: rgba(0, 0, 0, .12); + --tabulator-row-text-color: #333; + --tabulator-row-hover-background: #e1f5fe; + --tabulator-row-selected-background: #ace5ff; + --tabulator-row-selected-background-hover: #9bcfe8; + --tabulator-edit-box-color: #17161d; + --tabulator-error-color: #d00; + --tabulator-footer-background-color: transparent; + --tabulator-footer-text-color: #555; + --tabulator-footer-border-color: rgba(0, 0, 0, .12); + --tabulator-footer-separator-color: rgba(0, 0, 0, .12); + --tabulator-footer-active-color: #17161d; + --tabulator-spreadsheet-active-tab-color: #fff; + --tabulator-range-border-color: #17161d; + --tabulator-range-handle-color: #17161d; + --tabulator-range-header-selected-background: var(--tabulator-range-border-color); + --tabulator-range-header-selected-text-color: #fff; + --tabulator-range-header-highlight-background: colors-gray-timberwolf; + --tabulator-range-header-text-highlight-background: #fff; + --tabulator-pagination-button-background: #fff; + --tabulator-pagination-button-background-hover: #06c; + --tabulator-pagination-button-color: #999; + --tabulator-pagination-button-color-hover: #fff; + --tabulator-pagination-button-color-active: #000; + --tabulator-cell-padding: 15px +} + +body.vscode-dark, +body[data-jp-theme-light=false] { + 
--tabulator-background-color: #080808; + --tabulator-border-color: #666; + --tabulator-text-size: 16px; + --tabulator-header-background-color: #212121; + --tabulator-header-text-color: #555; + --tabulator-header-border-color: #666; + --tabulator-header-separator-color: #666; + --tabulator-header-margin: 4px; + --tabulator-sort-arrow-hover: #fff; + --tabulator-sort-arrow-active: #e6e6e6; + --tabulator-sort-arrow-inactive: #666; + --tabulator-column-resize-guide-color: #999; + --tabulator-row-background-color: #080808; + --tabulator-row-alt-background-color: #212121; + --tabulator-row-border-color: #666; + --tabulator-row-text-color: #f8f8f8; + --tabulator-row-hover-background: #333; + --tabulator-row-selected-background: #3d355d; + --tabulator-row-selected-background-hover: #483f69; + --tabulator-edit-box-color: #333; + --tabulator-error-color: #d00; + --tabulator-footer-background-color: transparent; + --tabulator-footer-text-color: #555; + --tabulator-footer-border-color: rgba(0, 0, 0, .12); + --tabulator-footer-separator-color: rgba(0, 0, 0, .12); + --tabulator-footer-active-color: #17161d; + --tabulator-spreadsheet-active-tab-color: #fff; + --tabulator-range-border-color: #17161d; + --tabulator-range-handle-color: var(--tabulator-range-border-color); + --tabulator-range-header-selected-background: var(--tabulator-range-border-color); + --tabulator-range-header-selected-text-color: #fff; + --tabulator-range-header-highlight-background: #d6d6d6; + --tabulator-range-header-text-highlight-background: #fff; + --tabulator-pagination-button-background: #212121; + --tabulator-pagination-button-background-hover: #555; + --tabulator-pagination-button-color: #999; + --tabulator-pagination-button-color-hover: #fff; + --tabulator-pagination-button-color-active: #fff; + --tabulator-cell-padding: 15px +} + +.tabulator { + border: 1px solid var(--tabulator-border-color); + font-size: var(--tabulator-text-size); + overflow: hidden; + position: relative; + text-align: left; + -webkit-transform: translateZ(0); + -moz-transform: translateZ(0); + -ms-transform: translateZ(0); + -o-transform: translateZ(0); + transform: translateZ(0) +} + +.tabulator[tabulator-layout=fitDataFill] .tabulator-tableholder .tabulator-table { + min-width: 100% +} + +.tabulator[tabulator-layout=fitDataTable] { + display: inline-block +} + +.tabulator.tabulator-block-select, +.tabulator.tabulator-ranges .tabulator-cell:not(.tabulator-editing) { + user-select: none +} + +.tabulator .tabulator-header { + background-color: var(--tabulator-header-background-color); + border-bottom: 1px solid var(--tabulator-header-separator-color); + box-sizing: border-box; + color: var(--tabulator-header-text-color); + font-weight: 700; + outline: none; + overflow: hidden; + position: relative; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + white-space: nowrap; + width: 100% +} + +.tabulator .tabulator-header.tabulator-header-hidden { + display: none +} + +.tabulator .tabulator-header .tabulator-header-contents { + overflow: hidden; + position: relative +} + +.tabulator .tabulator-header .tabulator-header-contents .tabulator-headers { + display: inline-block +} + +.tabulator .tabulator-header .tabulator-col { + background: var(--tabulator-header-background-color); + border-right: 1px solid var(--tabulator-header-border-color); + box-sizing: border-box; + display: inline-flex; + flex-direction: column; + justify-content: flex-start; + overflow: hidden; + position: relative; + text-align: 
left; + vertical-align: bottom +} + +.tabulator .tabulator-header .tabulator-col.tabulator-moving { + background: hsl(var(--tabulator-header-background-color), calc(var(--tabulator-header-background-color) - 5%)) !important; + border: 1px solid var(--tabulator-header-separator-color); + pointer-events: none; + position: absolute +} + +.tabulator .tabulator-header .tabulator-col.tabulator-range-highlight { + background-color: var(--tabulator-range-header-highlight-background); + color: var(--tabulator-range-header-text-highlight-background) +} + +.tabulator .tabulator-header .tabulator-col.tabulator-range-selected { + background-color: var(--tabulator-range-header-selected-background); + color: var(--tabulator-range-header-selected-text-color) +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content { + box-sizing: border-box; + padding: 4px; + position: relative +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-header-popup-button { + padding: 0 8px +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-header-popup-button:hover { + cursor: pointer; + opacity: .6 +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title-holder { + position: relative +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title { + box-sizing: border-box; + overflow: hidden; + text-overflow: ellipsis; + vertical-align: bottom; + white-space: nowrap; + width: 100% +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title.tabulator-col-title-wrap { + text-overflow: clip; + white-space: normal +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title .tabulator-title-editor { + background: #fff; + border: 1px solid #999; + box-sizing: border-box; + padding: 1px; + width: 100% +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-title .tabulator-header-popup-button+.tabulator-title-editor { + width: calc(100% - 22px) +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter { + align-items: center; + bottom: 0; + display: flex; + position: absolute; + right: 4px; + top: 0 +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter .tabulator-arrow { + border-bottom: 6px solid var(--tabulator-sort-arrow-inactive); + border-left: 6px solid transparent; + border-right: 6px solid transparent; + height: 0; + width: 0 +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols { + border-top: 1px solid var(--tabulator-header-border-color); + display: flex; + margin-right: -1px; + overflow: hidden; + position: relative +} + +.tabulator .tabulator-header .tabulator-col .tabulator-header-filter { + box-sizing: border-box; + margin-top: 2px; + position: relative; + text-align: center; + width: 100% +} + +.tabulator .tabulator-header .tabulator-col .tabulator-header-filter textarea { + height: auto !important +} + +.tabulator .tabulator-header .tabulator-col .tabulator-header-filter svg { + margin-top: 3px +} + +.tabulator .tabulator-header .tabulator-col .tabulator-header-filter input::-ms-clear { + height: 0; + width: 0 +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title { + padding-right: 25px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-header 
.tabulator-col.tabulator-sortable.tabulator-col-sorter-element:hover { + background-color: hsl(var(--tabulator-header-background-color), calc(var(--tabulator-header-background-color) - 10%)) !important; + cursor: pointer + } +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter { + color: var(--tabulator-sort-arrow-inactive) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover { + border-bottom: 6px solid var(--tabulator-sort-arrow-hover); + cursor: pointer + } +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=none] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow { + border-bottom: 6px solid var(--tabulator-sort-arrow-inactive); + border-top: none +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter { + color: var(--tabulator-sort-arrow-active) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover { + border-bottom: 6px solid var(--tabulator-sort-arrow-hover); + cursor: pointer + } +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=ascending] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow { + border-bottom: 6px solid var(--tabulator-sort-arrow-active); + border-top: none +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter { + color: var(--tabulator-sort-arrow-active) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter.tabulator-col-sorter-element .tabulator-arrow:hover { + border-top: 6px solid var(--tabulator-sort-arrow-hover); + cursor: pointer + } +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable[aria-sort=descending] .tabulator-col-content .tabulator-col-sorter .tabulator-arrow { + border-bottom: none; + border-top: 6px solid var(--tabulator-sort-arrow-active); + color: var(--tabulator-sort-arrow-active) +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical .tabulator-col-content .tabulator-col-title { + align-items: center; + display: flex; + justify-content: center; + text-orientation: mixed; + writing-mode: vertical-rl +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-col-vertical-flip .tabulator-col-title { + transform: rotate(180deg) +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable .tabulator-col-title { + padding-right: 0; + padding-top: 20px +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable.tabulator-col-vertical-flip .tabulator-col-title { + padding-bottom: 20px; + padding-right: 0 +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-vertical.tabulator-sortable .tabulator-col-sorter { + bottom: auto; + justify-content: center; + left: 0; + right: 0; + top: 4px +} + +.tabulator .tabulator-header .tabulator-frozen { + left: 0; + position: sticky; + z-index: 11 +} + +.tabulator .tabulator-header 
.tabulator-frozen.tabulator-frozen-left { + border-right: 2px solid var(--tabulator-row-border-color) +} + +.tabulator .tabulator-header .tabulator-frozen.tabulator-frozen-right { + border-left: 2px solid var(--tabulator-row-border-color) +} + +.tabulator .tabulator-header .tabulator-calcs-holder { + border-bottom: 1px solid var(--tabulator-header-border-color); + border-top: 1px solid var(--tabulator-row-border-color); + box-sizing: border-box; + display: inline-block +} + +.tabulator .tabulator-header .tabulator-calcs-holder, +.tabulator .tabulator-header .tabulator-calcs-holder .tabulator-row { + background: hsl(var(--tabulator-header-background-color), calc(var(--tabulator-header-background-color) + 5%)) !important +} + +.tabulator .tabulator-header .tabulator-calcs-holder .tabulator-row .tabulator-col-resize-handle { + display: none +} + +.tabulator .tabulator-header .tabulator-frozen-rows-holder { + display: inline-block +} + +.tabulator .tabulator-tableholder { + -webkit-overflow-scrolling: touch; + overflow: auto; + position: relative; + white-space: nowrap; + width: 100% +} + +.tabulator .tabulator-tableholder:focus { + outline: none +} + +.tabulator .tabulator-tableholder .tabulator-placeholder { + align-items: center; + box-sizing: border-box; + display: flex; + justify-content: center; + min-width: 100%; + width: 100% +} + +.tabulator .tabulator-tableholder .tabulator-placeholder[tabulator-render-mode=virtual] { + min-height: 100% +} + +.tabulator .tabulator-tableholder .tabulator-placeholder .tabulator-placeholder-contents { + color: #ccc; + display: inline-block; + font-size: 20px; + font-weight: 700; + padding: 10px; + text-align: center; + white-space: normal +} + +.tabulator .tabulator-tableholder .tabulator-table { + background-color: var(--tabulator-row-background-color); + color: var(--tabulator-row-text-color); + display: inline-block; + overflow: visible; + position: relative; + white-space: nowrap +} + +.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs { + background: hsl(var(--tabulator-row-atl-background-color), calc(var(--tabulator-row-alt-background-color) - 5%)) !important; + font-weight: 700 +} + +.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs.tabulator-calcs-top { + border-bottom: 2px solid var(--tabulator-row-border-color) +} + +.tabulator .tabulator-tableholder .tabulator-table .tabulator-row.tabulator-calcs.tabulator-calcs-bottom { + border-top: 2px solid var(--tabulator-row-border-color) +} + +.tabulator .tabulator-tableholder .tabulator-range-overlay { + inset: 0; + pointer-events: none; + position: absolute; + z-index: 10 +} + +.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range { + border: 1px solid var(--tabulator-range-border-color); + box-sizing: border-box; + position: absolute +} + +.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range.tabulator-range-active:after { + background-color: var(--tabulator-range-handle-color); + border-radius: 999px; + bottom: -3px; + content: ""; + height: 6px; + position: absolute; + right: -3px; + width: 6px +} + +.tabulator .tabulator-tableholder .tabulator-range-overlay .tabulator-range-cell-active { + border: 2px solid var(--tabulator-range-border-color); + box-sizing: border-box; + position: absolute +} + +.tabulator .tabulator-footer { + color: var(--tabulator-footer-text-color); + font-weight: 700; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; 
+ -o-user-select: none; + white-space: nowrap +} + +.tabulator .tabulator-footer .tabulator-footer-contents { + align-items: center; + display: flex; + flex-direction: row; + justify-content: space-between; + padding: 5px 10px +} + +.tabulator .tabulator-footer .tabulator-footer-contents:empty { + display: none +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs { + margin-top: -5px; + overflow-x: auto +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab { + border: 1px solid var(--tabulator-border-color); + border-bottom-left-radius: 5px; + border-bottom-right-radius: 5px; + border-top: none; + display: inline-block; + font-size: .9em; + padding: 5px +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab:hover { + cursor: pointer; + opacity: .7 +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab.tabulator-spreadsheet-tab-active { + background: var(--tabulator-spreadsheet-active-tab-color) +} + +.tabulator .tabulator-footer .tabulator-calcs-holder { + border-bottom: 1px solid var(--tabulator-row-border-color); + border-top: 1px solid var(--tabulator-row-border-color); + box-sizing: border-box; + overflow: hidden; + text-align: left; + width: 100% +} + +.tabulator .tabulator-footer .tabulator-calcs-holder .tabulator-row { + display: inline-block +} + +.tabulator .tabulator-footer .tabulator-calcs-holder .tabulator-row .tabulator-col-resize-handle { + display: none +} + +.tabulator .tabulator-footer .tabulator-calcs-holder:only-child { + border-bottom: none; + margin-bottom: -5px +} + +.tabulator .tabulator-footer>*+.tabulator-page-counter { + margin-left: 10px +} + +.tabulator .tabulator-footer .tabulator-page-counter { + font-weight: 400 +} + +.tabulator .tabulator-footer .tabulator-paginator { + color: var(--tabulator-footer-text-color); + flex: 1; + font-family: inherit; + font-size: inherit; + font-weight: inherit; + text-align: right +} + +.tabulator .tabulator-footer .tabulator-page-size { + border: 1px solid var(--tabulator-footer-border-color); + border-radius: 3px; + display: inline-block; + margin: 0 5px; + padding: 2px 5px +} + +.tabulator .tabulator-footer .tabulator-pages { + margin: 0 7px +} + +.tabulator .tabulator-footer .tabulator-page { + background: hsla(0, 0%, 100%, .2); + border: 1px solid var(--tabulator-footer-border-color); + border-radius: 3px; + display: inline-block; + margin: 0 2px; + padding: 2px 5px +} + +.tabulator .tabulator-footer .tabulator-page.active { + color: var(--tabulator-footer-active-color) +} + +.tabulator .tabulator-footer .tabulator-page:disabled { + opacity: .5 +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-footer .tabulator-page:not(disabled):hover { + background: rgba(0, 0, 0, .2); + color: #fff; + cursor: pointer + } +} + +.tabulator .tabulator-col-resize-handle { + display: inline-block; + margin-left: -3px; + margin-right: -3px; + position: relative; + vertical-align: middle; + width: 6px; + z-index: 11 +} + +@media (hover:hover) and (pointer:fine) { + .tabulator .tabulator-col-resize-handle:hover { + cursor: ew-resize + } +} + +.tabulator .tabulator-col-resize-handle:last-of-type { + margin-right: 0; + width: 3px +} + +.tabulator .tabulator-col-resize-guide { + height: 100%; + margin-left: -.5px; + top: 0; + width: 4px +} + +.tabulator .tabulator-col-resize-guide, +.tabulator .tabulator-row-resize-guide { + background-color: var(--tabulator-column-resize-guide-color); + opacity: .5; + position: 
absolute +} + +.tabulator .tabulator-row-resize-guide { + height: 4px; + left: 0; + margin-top: -.5px; + width: 100% +} + +.tabulator .tabulator-alert { + align-items: center; + background: rgba(0, 0, 0, .4); + display: flex; + height: 100%; + left: 0; + position: absolute; + text-align: center; + top: 0; + width: 100%; + z-index: 100 +} + +.tabulator .tabulator-alert .tabulator-alert-msg { + background: #fff; + border-radius: 10px; + display: inline-block; + font-size: 16px; + font-weight: 700; + margin: 0 auto; + padding: 10px 20px +} + +.tabulator .tabulator-alert .tabulator-alert-msg.tabulator-alert-state-msg { + border: 4px solid #333; + color: #000 +} + +.tabulator .tabulator-alert .tabulator-alert-msg.tabulator-alert-state-error { + border: 4px solid #d00; + color: #590000 +} + +.tabulator-row { + background-color: var(--tabulator-row-background-color); + box-sizing: border-box; + min-height: calc(var(--tabulator-text-size) + var(--tabulator-header-margin)*2); + position: relative +} + +.tabulator-row.tabulator-row-even { + background-color: var(--tabulator-row-alt-background-color) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row.tabulator-selectable:hover { + background-color: var(--tabulator-row-hover-background); + cursor: pointer + } +} + +.tabulator-row.tabulator-selected { + background-color: var(--tabulator-row-selected-background) !important; +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row.tabulator-selected:hover { + background-color: var(--tabulator-row-selected-background-hover) !important; + cursor: pointer + } +} + +.tabulator-row.tabulator-row-moving { + background: #fff; + border: 1px solid #000 +} + +.tabulator-row.tabulator-moving { + border-bottom: 1px solid var(--tabulator-row-border-color); + border-top: 1px solid var(--tabulator-row-border-color); + pointer-events: none; + position: absolute; + z-index: 15 +} + +.tabulator-row.tabulator-range-highlight .tabulator-cell.tabulator-range-row-header { + background-color: var(--tabulator-range-header-highlight-background); + color: var(--tabulator-range-header-text-highlight-background) +} + +.tabulator-row.tabulator-range-highlight.tabulator-range-selected .tabulator-cell.tabulator-range-row-header, +.tabulator-row.tabulator-range-selected .tabulator-cell.tabulator-range-row-header { + background-color: var(--tabulator-range-header-selected-background); + color: var(--tabulator-range-header-selected-text-color) +} + +.tabulator-row .tabulator-row-resize-handle { + bottom: 0; + height: 5px; + left: 0; + position: absolute; + right: 0 +} + +.tabulator-row .tabulator-row-resize-handle.prev { + bottom: auto; + top: 0 +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row .tabulator-row-resize-handle:hover { + cursor: ns-resize + } +} + +.tabulator-row .tabulator-responsive-collapse { + border-bottom: 1px solid var(--tabulator-row-border-color); + border-top: 1px solid var(--tabulator-row-border-color); + box-sizing: border-box; + padding: 5px +} + +.tabulator-row .tabulator-responsive-collapse:empty { + display: none +} + +.tabulator-row .tabulator-responsive-collapse table { + font-size: var(--tabulator-text-size) +} + +.tabulator-row .tabulator-responsive-collapse table tr td { + position: relative +} + +.tabulator-row .tabulator-responsive-collapse table tr td:first-of-type { + padding-right: 10px +} + +.tabulator-row .tabulator-cell { + border-right: 1px solid var(--tabulator-row-border-color); + box-sizing: border-box; + display: inline-block; + outline: none; + overflow: 
hidden; + padding: 4px; + position: relative; + text-overflow: ellipsis; + vertical-align: middle; + white-space: nowrap +} + +.tabulator-row .tabulator-cell.tabulator-row-header { + border-bottom: 1px solid var(--tabulator-row-border-color) +} + +.tabulator-row .tabulator-cell.tabulator-frozen { + background-color: inherit; + display: inline-block; + left: 0; + position: sticky; + z-index: 11 +} + +.tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-left { + border-right: 2px solid var(--tabulator-row-border-color) +} + +.tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-right { + border-left: 2px solid var(--tabulator-row-border-color) +} + +.tabulator-row .tabulator-cell.tabulator-editing { + border: 1px solid var(--tabulator-edit-box-color); + outline: none; + padding: 0 +} + +.tabulator-row .tabulator-cell.tabulator-editing input, +.tabulator-row .tabulator-cell.tabulator-editing select { + background: transparent; + border: 1px; + outline: none +} + +.tabulator-row .tabulator-cell.tabulator-validation-fail { + border: 1px solid var(--tabulator-error-color) +} + +.tabulator-row .tabulator-cell.tabulator-validation-fail input, +.tabulator-row .tabulator-cell.tabulator-validation-fail select { + background: transparent; + border: 1px; + color: var(--tabulator-error-color) +} + +.tabulator-row .tabulator-cell.tabulator-row-handle { + align-items: center; + display: inline-flex; + justify-content: center; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none +} + +.tabulator-row .tabulator-cell.tabulator-row-handle .tabulator-row-handle-box { + width: 80% +} + +.tabulator-row .tabulator-cell.tabulator-row-handle .tabulator-row-handle-box .tabulator-row-handle-bar { + background: #666; + height: 3px; + margin-top: 2px; + width: 100% +} + +.tabulator-row .tabulator-cell.tabulator-range-selected:not(.tabulator-range-only-cell-selected):not(.tabulator-range-row-header) { + background-color: var(--tabulator-row-selected-background) +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-branch-empty { + display: inline-block; + width: 7px +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-branch { + border-bottom: 2px solid var(--tabulator-row-border-color); + border-bottom-left-radius: 1px; + border-left: 2px solid var(--tabulator-row-border-color); + display: inline-block; + height: 9px; + margin-right: 5px; + margin-top: -9px; + vertical-align: middle; + width: 7px +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control { + align-items: center; + background: rgba(0, 0, 0, .1); + border: 1px solid var(--tabulator-row-text-color); + border-radius: 2px; + display: inline-flex; + height: 11px; + justify-content: center; + margin-right: 5px; + overflow: hidden; + vertical-align: middle; + width: 11px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row .tabulator-cell .tabulator-data-tree-control:hover { + background: rgba(0, 0, 0, .2); + cursor: pointer + } +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse { + background: transparent; + display: inline-block; + height: 7px; + position: relative; + width: 1px +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after { + background: var(--tabulator-row-text-color); + content: ""; + height: 1px; + left: -3px; + position: absolute; + top: 3px; + width: 7px +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control 
.tabulator-data-tree-control-expand { + background: var(--tabulator-row-text-color); + display: inline-block; + height: 7px; + position: relative; + width: 1px +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand:after { + background: var(--tabulator-row-text-color); + content: ""; + height: 1px; + left: -3px; + position: absolute; + top: 3px; + width: 7px +} + +.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle { + align-items: center; + background: #666; + border-radius: 20px; + color: var(--tabulator-row-background-color); + display: inline-flex; + font-size: 1.1em; + font-weight: 700; + height: 15px; + justify-content: center; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + width: 15px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle:hover { + cursor: pointer; + opacity: .7 + } +} + +.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle.open .tabulator-responsive-collapse-toggle-close { + display: initial +} + +.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle.open .tabulator-responsive-collapse-toggle-open { + display: none +} + +.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle svg { + stroke: var(--tabulator-row-background-color) +} + +.tabulator-row .tabulator-cell .tabulator-responsive-collapse-toggle .tabulator-responsive-collapse-toggle-close { + display: none +} + +.tabulator-row .tabulator-cell .tabulator-traffic-light { + border-radius: 14px; + display: inline-block; + height: 14px; + width: 14px +} + +.tabulator-row.tabulator-group { + background: #ccc; + border-bottom: 1px solid #999; + border-right: 1px solid var(--tabulator-row-border-color); + border-top: 1px solid #999; + box-sizing: border-box; + font-weight: 700; + min-width: 100%; + padding: 5px 5px 5px 10px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-row.tabulator-group:hover { + background-color: rgba(0, 0, 0, .1); + cursor: pointer + } +} + +.tabulator-row.tabulator-group.tabulator-group-visible .tabulator-arrow { + border-bottom: 0; + border-left: 6px solid transparent; + border-right: 6px solid transparent; + border-top: 6px solid var(--tabulator-sort-arrow-active); + margin-right: 10px +} + +.tabulator-row.tabulator-group.tabulator-group-level-1 { + padding-left: 30px +} + +.tabulator-row.tabulator-group.tabulator-group-level-2 { + padding-left: 50px +} + +.tabulator-row.tabulator-group.tabulator-group-level-3 { + padding-left: 70px +} + +.tabulator-row.tabulator-group.tabulator-group-level-4 { + padding-left: 90px +} + +.tabulator-row.tabulator-group.tabulator-group-level-5 { + padding-left: 110px +} + +.tabulator-row.tabulator-group .tabulator-group-toggle { + display: inline-block +} + +.tabulator-row.tabulator-group .tabulator-arrow { + border-bottom: 6px solid transparent; + border-left: 6px solid var(--tabulator-sort-arrow-active); + border-right: 0; + border-top: 6px solid transparent; + display: inline-block; + height: 0; + margin-right: 16px; + vertical-align: middle; + width: 0 +} + +.tabulator-row.tabulator-group span { + color: #d00 +} + +.tabulator-toggle { + background: #dcdcdc; + border: 1px solid #ccc; + box-sizing: border-box; + display: flex; + flex-direction: row +} + +.tabulator-toggle.tabulator-toggle-on { + background: #1c6cc2 +} + +.tabulator-toggle .tabulator-toggle-switch { + background: #fff; + border: 1px solid #ccc; + 
box-sizing: border-box +} + +.tabulator-popup-container { + -webkit-overflow-scrolling: touch; + background: var(--tabulator-row-background-color); + border: 1px solid var(--tabulator-row-border-color); + box-shadow: 0 0 5px 0 rgba(0, 0, 0, .2); + box-sizing: border-box; + display: inline-block; + font-size: var(--tabulator-text-size); + overflow-y: auto; + position: absolute; + z-index: 10000 +} + +.tabulator-popup { + border-radius: 3px; + padding: 5px +} + +.tabulator-tooltip { + border-radius: 2px; + box-shadow: none; + font-size: 12px; + max-width: min(500px, 100%); + padding: 3px 5px; + pointer-events: none +} + +.tabulator-menu .tabulator-menu-item { + box-sizing: border-box; + padding: 5px 10px; + position: relative; + user-select: none +} + +.tabulator-menu .tabulator-menu-item.tabulator-menu-item-disabled { + opacity: .5 +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-menu .tabulator-menu-item:not(.tabulator-menu-item-disabled):hover { + background: var(--tabulator-row-alt-background-color); + cursor: pointer + } +} + +.tabulator-menu .tabulator-menu-item.tabulator-menu-item-submenu { + padding-right: 25px +} + +.tabulator-menu .tabulator-menu-item.tabulator-menu-item-submenu:after { + border-color: var(--tabulator-row-border-color); + border-style: solid; + border-width: 1px 1px 0 0; + content: ""; + display: inline-block; + height: 7px; + position: absolute; + right: 10px; + top: calc(5px + .4em); + transform: rotate(45deg); + vertical-align: top; + width: 7px +} + +.tabulator-menu .tabulator-menu-separator { + border-top: 1px solid var(--tabulator-row-border-color) +} + +.tabulator-edit-list { + -webkit-overflow-scrolling: touch; + font-size: var(--tabulator-text-size); + max-height: 200px; + overflow-y: auto +} + +.tabulator-edit-list .tabulator-edit-list-item { + color: var(--tabulator-row-text-color); + outline: none; + padding: 4px +} + +.tabulator-edit-list .tabulator-edit-list-item.active { + background: var(--tabulator-edit-box-color); + color: var(--tabulator-row-background-color) +} + +.tabulator-edit-list .tabulator-edit-list-item.active.focused { + outline: 1px solid rgba(var(--tabulator-row-background-color), .5) +} + +.tabulator-edit-list .tabulator-edit-list-item.focused { + outline: 1px solid var(--tabulator-edit-box-color) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-edit-list .tabulator-edit-list-item:hover { + background: var(--tabulator-edit-box-color); + color: var(--tabulator-row-background-color); + cursor: pointer + } +} + +.tabulator-edit-list .tabulator-edit-list-placeholder { + color: var(--tabulator-row-text-color); + padding: 4px; + text-align: center +} + +.tabulator-edit-list .tabulator-edit-list-group { + border-bottom: 1px solid var(--tabulator-row-border-color); + color: var(--tabulator-row-text-color); + font-weight: 700; + padding: 6px 4px 4px +} + +.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-2, +.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-2 { + padding-left: 12px +} + +.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-3, +.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-3 { + padding-left: 20px +} + +.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-4, +.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-4 { + padding-left: 28px +} + +.tabulator-edit-list .tabulator-edit-list-group.tabulator-edit-list-group-level-5, 
+.tabulator-edit-list .tabulator-edit-list-item.tabulator-edit-list-group-level-5 { + padding-left: 36px +} + +.tabulator.tabulator-ltr { + direction: ltr +} + +.tabulator.tabulator-rtl { + direction: rtl; + text-align: initial +} + +.tabulator.tabulator-rtl .tabulator-header .tabulator-col { + border-left: 1px solid var(--tabulator-header-border-color); + border-right: initial; + text-align: initial +} + +.tabulator.tabulator-rtl .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols { + margin-left: -1px; + margin-right: 0 +} + +.tabulator.tabulator-rtl .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title { + padding-left: 25px; + padding-right: 0 +} + +.tabulator.tabulator-rtl .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter { + left: 8px; + right: auto +} + +.tabulator.tabulator-rtl .tabulator-tableholder .tabulator-range-overlay .tabulator-range.tabulator-range-active:after { + background-color: var(--tabulator-range-handle-color); + border-radius: 999px; + bottom: -3px; + content: ""; + height: 6px; + left: -3px; + position: absolute; + right: auto; + width: 6px +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-cell { + border-left: 1px solid var(--tabulator-row-border-color); + border-right: initial +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-cell .tabulator-data-tree-branch { + border-bottom-left-radius: 0; + border-bottom-right-radius: 1px; + border-left: initial; + border-right: 2px solid var(--tabulator-row-border-color); + margin-left: 5px; + margin-right: 0 +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-cell .tabulator-data-tree-control { + margin-left: 5px; + margin-right: 0 +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-left { + border-left: 2px solid var(--tabulator-row-border-color) +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-cell.tabulator-frozen.tabulator-frozen-right { + border-right: 2px solid var(--tabulator-row-border-color) +} + +.tabulator.tabulator-rtl .tabulator-row .tabulator-col-resize-handle:last-of-type { + margin-left: 0; + margin-right: -3px; + width: 3px +} + +.tabulator.tabulator-rtl .tabulator-footer .tabulator-calcs-holder { + text-align: initial +} + +.tabulator-print-fullscreen { + bottom: 0; + left: 0; + position: absolute; + right: 0; + top: 0; + z-index: 10000 +} + +body.tabulator-print-fullscreen-hide>:not(.tabulator-print-fullscreen) { + display: none !important +} + +.tabulator-print-table { + border-collapse: collapse +} + +.tabulator-print-table .tabulator-data-tree-branch { + border-bottom: 2px solid var(--tabulator-row-border-color); + border-bottom-left-radius: 1px; + border-left: 2px solid var(--tabulator-row-border-color); + display: inline-block; + height: 9px; + margin-right: 5px; + margin-top: -9px; + vertical-align: middle; + width: 7px +} + +.tabulator-print-table .tabulator-print-table-group { + background: #ccc; + border-bottom: 1px solid #999; + border-right: 1px solid var(--tabulator-row-border-color); + border-top: 1px solid #999; + box-sizing: border-box; + font-weight: 700; + min-width: 100%; + padding: 5px 5px 5px 10px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-print-table .tabulator-print-table-group:hover { + background-color: rgba(0, 0, 0, .1); + cursor: pointer + } +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-visible .tabulator-arrow { + border-bottom: 0; + border-left: 6px solid transparent; + border-right: 6px 
solid transparent; + border-top: 6px solid var(--tabulator-sort-arrow-active); + margin-right: 10px +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-1 td { + padding-left: 30px !important +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-2 td { + padding-left: 50px !important +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-3 td { + padding-left: 70px !important +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-4 td { + padding-left: 90px !important +} + +.tabulator-print-table .tabulator-print-table-group.tabulator-group-level-5 td { + padding-left: 110px !important +} + +.tabulator-print-table .tabulator-print-table-group .tabulator-group-toggle { + display: inline-block +} + +.tabulator-print-table .tabulator-print-table-group .tabulator-arrow { + border-bottom: 6px solid transparent; + border-left: 6px solid var(--tabulator-sort-arrow-active); + border-right: 0; + border-top: 6px solid transparent; + display: inline-block; + height: 0; + margin-right: 16px; + vertical-align: middle; + width: 0 +} + +.tabulator-print-table .tabulator-print-table-group span { + color: #d00 +} + +.tabulator-print-table .tabulator-data-tree-control { + align-items: center; + background: rgba(0, 0, 0, .1); + border: 1px solid var(--tabulator-row-text-color); + border-radius: 2px; + display: inline-flex; + height: 11px; + justify-content: center; + margin-right: 5px; + overflow: hidden; + vertical-align: middle; + width: 11px +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-print-table .tabulator-data-tree-control:hover { + background: rgba(0, 0, 0, .2); + cursor: pointer + } +} + +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse { + background: transparent; + display: inline-block; + height: 7px; + position: relative; + width: 1px +} + +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after { + background: var(--tabulator-row-text-color); + content: ""; + height: 1px; + left: -3px; + position: absolute; + top: 3px; + width: 7px +} + +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand { + background: var(--tabulator-row-text-color); + display: inline-block; + height: 7px; + position: relative; + width: 1px +} + +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand:after { + background: var(--tabulator-row-text-color); + content: ""; + height: 1px; + left: -3px; + position: absolute; + top: 3px; + width: 7px +} + +.tabulator { + background-color: var(--tabulator-background-color); + max-width: 100%; + width: 100% +} + +.tabulator .tabulator-header { + color: inherit +} + +.tabulator .tabulator-header .tabulator-col { + border-top: none +} + +.tabulator .tabulator-header .tabulator-col:first-of-type { + border-left: none +} + +.tabulator .tabulator-header .tabulator-col:last-of-type { + border-right: none +} + +.tabulator .tabulator-header .tabulator-col:not(first-of-type), +.tabulator .tabulator-header .tabulator-col:not(last-of-type) { + border-right: 1px solid var(--tabulator-header-border-color) +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content { + padding: var(--tabulator-cell-padding) +} + +.tabulator .tabulator-header .tabulator-col .tabulator-col-content .tabulator-col-sorter { + right: -10px +} + +.tabulator .tabulator-header .tabulator-col.tabulator-col-group .tabulator-col-group-cols { + 
border-top: 1px solid var(--tabulator-border-color) +} + +.tabulator .tabulator-header .tabulator-col.tabulator-sortable .tabulator-col-title { + padding-right: 10px +} + +.tabulator .tabulator-header .tabulator-calcs-holder { + border-bottom: 1px solid var(--tabulator-header-separator-color); + width: 100% +} + +.tabulator .tabulator-header .tabulator-frozen-rows-holder { + min-width: 600% +} + +.tabulator .tabulator-header .tabulator-frozen-rows-holder:empty { + display: none +} + +.tabulator .tabulator-header .tabulator-frozen .tabulator-frozen-left, +.tabulator .tabulator-header .tabulator-frozen .tabulator-frozen-right { + background: inherit +} + +.tabulator .tabulator-tableholder .tabulator-table { + color: inherit +} + +.tabulator .tabulator-footer { + background-color: var(--tabulator-footer-background-color); + color: inherit +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab { + font-weight: 400; + padding: 8px 12px +} + +.tabulator .tabulator-footer .tabulator-spreadsheet-tabs .tabulator-spreadsheet-tab.tabulator-spreadsheet-tab-active { + color: var(--tabulator-footer-active-color) +} + +.tabulator .tabulator-footer .tabulator-paginator { + color: inherit +} + +.tabulator .tabulator-footer .tabulator-page { + background: var(--tabulator-pagination-button-background); + border-radius: 0; + border-right: none; + color: var(--tabulator-pagination-button-color); + margin: 5px 0 0; + padding: 8px 12px +} + +.tabulator .tabulator-footer .tabulator-page:first-of-type, +.tabulator .tabulator-footer .tabulator-page[data-page=next] { + border-bottom-left-radius: 4px; + border-top-left-radius: 4px +} + +.tabulator .tabulator-footer .tabulator-page:last-of-type, +.tabulator .tabulator-footer .tabulator-page[data-page=prev] { + border: 1px solid var(--tabulator-footer-border-color); + border-bottom-right-radius: 4px; + border-top-right-radius: 4px +} + +.tabulator .tabulator-footer .tabulator-page:not(disabled):hover { + background: var(--tabulator-pagination-button-background-hover); + color: var(--tabulator-pagination-button-color-hover) +} + +.tabulator .tabulator-footer .tabulator-page.active, +.tabulator .tabulator-footer .tabulator-page[data-page=first] :not(disabled):not(:hover), +.tabulator .tabulator-footer .tabulator-page[data-page=last] :not(disabled):not(:hover), +.tabulator .tabulator-footer .tabulator-page[data-page=next] :not(disabled):not(:hover), +.tabulator .tabulator-footer .tabulator-page[data-page=prev] :not(disabled):not(:hover) { + color: var(--tabulator-pagination-button-color-active) +} + +.tabulator.striped .tabulator-row:nth-child(2n) { + background-color: var(--tabulator-row-alt-background-color) +} + +.tabulator.striped .tabulator-row:nth-child(2n).tabulator-selected { + background-color: var(--tabulator-row-selected-background) !important +} + +@media (hover:hover) and (pointer:fine) { + .tabulator.striped .tabulator-row:nth-child(2n).tabulator-selectable:hover { + background-color: var(--tabulator-row-hover-background); + cursor: pointer + } + + .tabulator.striped .tabulator-row:nth-child(2n).tabulator-selected:hover { + background-color: var(--tabulator-row-selected-background-hover) !important; + cursor: pointer + } +} + +.tabulator-row { + border-bottom: 1px solid var(--tabulator-row-border-color); + min-height: calc(var(--tabulator-text-size) + var(--tabulator-cell-padding)*2) +} + +.tabulator-row.tabulator-row-even { + background-color: var(--tabulator-row-background-color) +} + +.tabulator-row .tabulator-cell { + 
padding: var(--tabulator-cell-padding) +} + +.tabulator-row .tabulator-cell:last-of-type { + border-right: none +} + +.tabulator-row .tabulator-cell.tabulator-row-header { + background: var(--tabulator-header-background-color); + border-bottom: none; + border-right: 1px solid var(--tabulator-border-color) +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control { + border: 1px solid #ccc +} + +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after, +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand, +.tabulator-row .tabulator-cell .tabulator-data-tree-control .tabulator-data-tree-control-expand:after { + background: #ccc +} + +.tabulator-row.tabulator-group { + background: #fafafa +} + +.tabulator-row.tabulator-group span { + color: #666; + margin-left: 10px +} + +.tabulator-edit-select-list { + background: var(--tabulator-header-background-color) +} + +.tabulator-edit-select-list .tabulator-edit-select-list-item { + color: inherit +} + +.tabulator-edit-select-list .tabulator-edit-select-list-item.active { + color: var(--tabulator-header-background-color) +} + +.tabulator-edit-select-list .tabulator-edit-select-list-item.active.focused { + outline: 1px solid rgba(var(--tabulator-header-background-color), .5) +} + +@media (hover:hover) and (pointer:fine) { + .tabulator-edit-select-list .tabulator-edit-select-list-item:hover { + color: var(--tabulator-header-background-color) + } +} + +.tabulator-edit-select-list .tabulator-edit-select-list-group, +.tabulator-edit-select-list .tabulator-edit-select-list-notice { + color: inherit +} + +.tabulator.tabulator-rtl .tabulator-header .tabulator-col { + border-left: none; + border-right: none +} + +.tabulator-print-table .tabulator-print-table-group { + background: #fafafa +} + +.tabulator-print-table .tabulator-print-table-group span { + color: #666; + margin-left: 10px +} + +.tabulator-print-table .tabulator-data-tree-control { + border: 1px solid #ccc +} + +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-collapse:after, +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand, +.tabulator-print-table .tabulator-data-tree-control .tabulator-data-tree-control-expand:after { + background: #ccc +} + /*# sourceMappingURL=tabulator_pysyft.min.css.map */ \ No newline at end of file diff --git a/packages/syft/src/syft/assets/jinja/table.jinja2 b/packages/syft/src/syft/assets/jinja/table.jinja2 index 1eb580ef01a..f750a80d0ab 100644 --- a/packages/syft/src/syft/assets/jinja/table.jinja2 +++ b/packages/syft/src/syft/assets/jinja/table.jinja2 @@ -38,11 +38,13 @@ diff --git a/packages/syft/src/syft/assets/js/table.js b/packages/syft/src/syft/assets/js/table.js index 1a257627d6a..ff6c5e84d12 100644 --- a/packages/syft/src/syft/assets/js/table.js +++ b/packages/syft/src/syft/assets/js/table.js @@ -6,7 +6,6 @@ TABULATOR_CSS = document.querySelectorAll(".escape-unfocus").forEach((input) => { input.addEventListener("keydown", (event) => { if (event.key === "Escape") { - console.log("Escape key pressed"); event.stopPropagation(); input.blur(); } @@ -58,7 +57,14 @@ function load_tabulator(elementId) { }); } -function buildTable(columns, rowHeader, data, uid) { +function buildTable( + columns, + rowHeader, + data, + uid, + pagination = true, + maxHeight = null, +) { const tableId = `table-${uid}`; const searchBarId = `search-${uid}`; const numrowsId = `numrows-${uid}`; @@ -73,11 +79,13 @@ function 
buildTable(columns, rowHeader, data, uid) { data: data, columns: columns, rowHeader: rowHeader, + index: "_table_repr_index", layout: "fitDataStretch", resizableColumnFit: true, resizableColumnGuide: true, - pagination: "local", + pagination: pagination, paginationSize: 5, + maxHeight: maxHeight, }); // Events needed for cell overflow: @@ -100,6 +108,7 @@ function buildTable(columns, rowHeader, data, uid) { numrowsElement.innerHTML = data.length; } + configureHighlightSingleRow(table, uid); configureSearch(table, searchBarId, columns); return table; @@ -129,3 +138,64 @@ function configureSearch(table, searchBarId, columns) { table.setFilter([filterArray]); }); } + +function configureHighlightSingleRow(table, uid) { + // Listener for rowHighlight events, with fields: + // uid: string, table uid + // index: number | string, row index to highlight + // jumpToRow: bool, if true, jumps to page where the row is located + document.addEventListener("rowHighlight", function (e) { + if (e.detail.uid === uid) { + let row_idx = e.detail.index; + let rows = table.getRows(); + for (let row of rows) { + if (row.getIndex() == row_idx) { + row.select(); + if (e.detail.jumpToRow) { + table.setPageToRow(row_idx); + table.scrollToRow(row_idx, "top", false); + } + } else { + row.deselect(); + } + } + } + }); +} + +function waitForTable(uid, timeout = 1000) { + return new Promise((resolve, reject) => { + // Check if the table is ready immediately + if (window["table_" + uid]) { + resolve(); + } else { + // Otherwise, check every 100ms until the table is ready or the timeout is reached + var startTime = Date.now(); + var checkTableInterval = setInterval(function () { + if (window["table_" + uid]) { + clearInterval(checkTableInterval); + resolve(); + } else if (Date.now() - startTime > timeout) { + clearInterval(checkTableInterval); + reject(`Timeout: table_"${uid}" not found.`); + } + }, 100); + } + }); +} + +function highlightSingleRow(uid, index = null, jumpToRow = false) { + // Highlight a single row in the table with the given uid + // If index is not provided or doesn't exist, all rows are deselected + waitForTable(uid) + .then(() => { + document.dispatchEvent( + new CustomEvent("rowHighlight", { + detail: { uid, index, jumpToRow }, + }), + ); + }) + .catch((error) => { + console.log(error); + }); +} diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 3cb360eafa4..f176f852afa 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -11,6 +11,7 @@ from typing import Literal # third party +import ipywidgets from loguru import logger import pandas as pd from pydantic import model_validator @@ -1118,6 +1119,12 @@ class NodeDiff(SyftObject): include_ignored: bool = False + def resolve(self) -> ipywidgets.Widget: + # relative + from .resolve_widget import PaginatedResolveWidget + + return PaginatedResolveWidget(batches=self.batches).build() + def __getitem__(self, idx: Any) -> ObjectDiffBatch: return self.batches[idx] diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index dd9dadc505e..9d0f21cdcda 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -1,11 +1,14 @@ # stdlib +from collections.abc import Callable from enum import Enum from enum import auto import html +import secrets from typing import Any from uuid import uuid4 # third 
party +from IPython import display import ipywidgets as widgets from ipywidgets import Button from ipywidgets import Checkbox @@ -22,6 +25,8 @@ from ...util.notebook_ui.components.sync import MainDescription from ...util.notebook_ui.components.sync import SyncWidgetHeader from ...util.notebook_ui.components.sync import TypeLabel +from ...util.notebook_ui.components.tabulator_template import build_tabulator_table +from ...util.notebook_ui.components.tabulator_template import highlight_single_row from ...util.notebook_ui.styles import CSS_CODE from ..action.action_object import ActionObject from ..api.api import TwinAPIEndpoint @@ -590,3 +595,158 @@ def separator(self) -> widgets.HTML: def build_header(self) -> HTML: header_html = SyncWidgetHeader(diff_batch=self.obj_diff_batch).to_html() return HTML(value=header_html) + + +class PaginationControl: + def __init__(self, data: list, callback: Callable[[int], None]): + self.data = data + self.callback = callback + self.current_index = 0 + self.index_label = widgets.Label(value=f"Index: {self.current_index}") + + self.first_button = widgets.Button(description="First") + self.previous_button = widgets.Button(description="Previous") + self.next_button = widgets.Button(description="Next") + self.last_button = widgets.Button(description="Last") + + self.first_button.on_click(self.go_to_first) + self.previous_button.on_click(self.go_to_previous) + self.next_button.on_click(self.go_to_next) + self.last_button.on_click(self.go_to_last) + self.output = widgets.Output() + + self.buttons = widgets.HBox( + [ + self.first_button, + self.previous_button, + self.next_button, + self.last_button, + ] + ) + self.update_buttons() + self.update_index_callback() + + def update_index_label(self) -> None: + self.index_label.value = f"Current: {self.current_index}" + + def update_buttons(self) -> None: + self.first_button.disabled = self.current_index == 0 + self.previous_button.disabled = self.current_index == 0 + self.next_button.disabled = self.current_index == len(self.data) - 1 + self.last_button.disabled = self.current_index == len(self.data) - 1 + + def go_to_first(self, b: Button) -> None: + self.current_index = 0 + self.update_index_callback() + + def go_to_previous(self, b: Button) -> None: + if self.current_index > 0: + self.current_index -= 1 + self.update_index_callback() + + def go_to_next(self, b: Button) -> None: + if self.current_index < len(self.data) - 1: + self.current_index += 1 + self.update_index_callback() + + def go_to_last(self, b: Button) -> None: + self.current_index = len(self.data) - 1 + self.update_index_callback() + + def update_index_callback(self) -> None: + self.update_index_label() + self.update_buttons() + + # NOTE self.output is required to display IPython.display.HTML + # IPython.display.HTML is used to execute JS code + with self.output: + self.callback(self.current_index) + + def build(self) -> widgets.VBox: + return widgets.VBox( + [widgets.HBox([self.buttons, self.index_label]), self.output] + ) + + +class PaginatedWidget: + def __init__( + self, children: list, on_paginate_callback: Callable[[int], None] | None = None + ): + # on_paginate_callback is an optional secondary callback, + # called after updating the page index and displaying the new widget + self.children = children + self.on_paginate_callback = on_paginate_callback + self.current_index = 0 + self.container = widgets.VBox() + + self.pagination_control = PaginationControl(children, self.on_paginate) + + # Initial display + 
self.on_paginate(self.pagination_control.current_index) + + def __getitem__(self, index: int) -> widgets.Widget: + return self.children[index] + + def on_paginate(self, index: int) -> None: + self.container.children = [self.children[index]] + if self.on_paginate_callback: + self.on_paginate_callback(index) + + def build(self) -> widgets.VBox: + return widgets.VBox([self.pagination_control.build(), self.container]) + + +class PaginatedResolveWidget: + """ + PaginatedResolveWidget is a widget that displays + a ResolveWidget for each ObjectDiffBatch, + paginated by a PaginationControl widget. + """ + + def __init__(self, batches: list[ObjectDiffBatch]): + self.batches = batches + self.resolve_widgets = [ + ResolveWidget(obj_diff_batch=batch) for batch in self.batches + ] + + self.table_uid = secrets.token_hex(4) + + # Disable the table pagination to avoid the double pagination buttons + self.batch_table = build_tabulator_table( + obj=batches, + uid=self.table_uid, + max_height=500, + pagination=False, + ) + + self.paginated_widget = PaginatedWidget( + children=[widget.widget for widget in self.resolve_widgets], + on_paginate_callback=self.on_paginate, + ) + + self.table_output = widgets.Output() + with self.table_output: + display.display(display.HTML(self.batch_table)) + highlight_single_row( + self.table_uid, self.paginated_widget.current_index, jump_to_row=True + ) + + def on_paginate(self, index: int) -> None: + return highlight_single_row(self.table_uid, index, jump_to_row=True) + + def build(self) -> widgets.VBox: + return widgets.VBox([self.table_output, self.paginated_widget.build()]) + + def click_sync(self, index: int) -> SyftSuccess | SyftError: + return self.resolve_widgets[index].click_sync() + + def click_share_all_private_data(self, index: int) -> None: + self.resolve_widgets[index].click_share_all_private_data() + + def _share_all(self) -> None: + for widget in self.resolve_widgets: + widget.click_share_all_private_data() + + def _sync_all(self) -> None: + for widget in self.resolve_widgets: + widget.click_sync() diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 60ba8da4915..e79016ab7a9 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -92,7 +92,12 @@ def format_table_data(table_data: list[dict[str, Any]]) -> list[dict[str, str]]: return formatted -def build_tabulator_table(obj: Any) -> str | None: +def build_tabulator_table( + obj: Any, + uid: str | None = None, + max_height: int | None = None, + pagination: bool = True, +) -> str | None: try: table_data, table_metadata = prepare_table_data(obj) if len(table_data) == 0: @@ -113,10 +118,11 @@ def build_tabulator_table(obj: Any) -> str | None: if icon is None: icon = Icon.TABLE.svg + uid = uid if uid is not None else secrets.token_hex(4) column_data, row_header = create_tabulator_columns(table_metadata["columns"]) table_data = format_table_data(table_data) table_html = table_template.render( - uid=secrets.token_hex(4), + uid=uid, columns=json.dumps(column_data), row_header=json.dumps(row_header), data=json.dumps(table_data), @@ -127,6 +133,8 @@ def build_tabulator_table(obj: Any) -> str | None: name=table_metadata["name"], tabulator_js=tabulator_js, tabulator_css=tabulator_css, + max_height=json.dumps(max_height), + pagination=json.dumps(pagination), ) return table_html @@ -140,3 +148,12 @@ def 
show_table(obj: Any) -> None: table = build_tabulator_table(obj) if table is not None: display(HTML(table)) + + +def highlight_single_row( + table_uid: str, + index: int | str | None = None, + jump_to_row: bool = True, +) -> None: + js_code = f"" + display(HTML(js_code)) From 667c4d4d0c940cc4036bc38b003cf0f4816029a8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 27 May 2024 10:34:34 +0200 Subject: [PATCH 057/309] minor fixes --- packages/syft/src/syft/service/sync/diff_state.py | 2 +- packages/syft/src/syft/service/sync/resolve_widget.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index f176f852afa..1b2b36bf2f6 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1123,7 +1123,7 @@ def resolve(self) -> ipywidgets.Widget: # relative from .resolve_widget import PaginatedResolveWidget - return PaginatedResolveWidget(batches=self.batches).build() + return PaginatedResolveWidget(batches=self.batches) def __getitem__(self, idx: Any) -> ObjectDiffBatch: return self.batches[idx] diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 9d0f21cdcda..41aa071be18 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -731,6 +731,11 @@ def __init__(self, batches: list[ObjectDiffBatch]): self.table_uid, self.paginated_widget.current_index, jump_to_row=True ) + self.widget = self.build() + + def __getitem__(self, index: int) -> ResolveWidget: + return self.resolve_widgets[index] + def on_paginate(self, index: int) -> None: return highlight_single_row(self.table_uid, index, jump_to_row=True) @@ -750,3 +755,6 @@ def _share_all(self) -> None: def _sync_all(self) -> None: for widget in self.resolve_widgets: widget.click_sync() + + def _repr_mimebundle_(self, **kwargs: dict) -> dict[str, str] | None: + return self.widget._repr_mimebundle_(**kwargs) From 6262647b50adebc129475264211011226d5c51bd Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 27 May 2024 14:51:55 +0200 Subject: [PATCH 058/309] add decision column, disable sort --- .../syft/src/syft/assets/jinja/table.jinja2 | 1 + packages/syft/src/syft/assets/js/table.js | 29 ++++++++++++- .../syft/src/syft/service/sync/diff_state.py | 19 ++++++++- .../src/syft/service/sync/resolve_widget.py | 41 +++++++++++++++---- .../components/tabulator_template.py | 19 ++++++++- 5 files changed, 97 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/assets/jinja/table.jinja2 b/packages/syft/src/syft/assets/jinja/table.jinja2 index f750a80d0ab..7a44d798540 100644 --- a/packages/syft/src/syft/assets/jinja/table.jinja2 +++ b/packages/syft/src/syft/assets/jinja/table.jinja2 @@ -45,6 +45,7 @@ "{{ uid }}", pagination={{ pagination }}, maxHeight={{ max_height }}, + headerSort={{ header_sort }}, ) diff --git a/packages/syft/src/syft/assets/js/table.js b/packages/syft/src/syft/assets/js/table.js index ff6c5e84d12..35fee482bd9 100644 --- a/packages/syft/src/syft/assets/js/table.js +++ b/packages/syft/src/syft/assets/js/table.js @@ -64,6 +64,7 @@ function buildTable( uid, pagination = true, maxHeight = null, + headerSort = true, ) { const tableId = `table-${uid}`; const searchBarId = `search-${uid}`; @@ -86,6 +87,7 @@ function buildTable( pagination: pagination, paginationSize: 5, maxHeight: maxHeight, + headerSort: 
headerSort, }); // Events needed for cell overflow: @@ -152,7 +154,8 @@ function configureHighlightSingleRow(table, uid) { if (row.getIndex() == row_idx) { row.select(); if (e.detail.jumpToRow) { - table.setPageToRow(row_idx); + // catch promise in case the table does not have pagination + table.setPageToRow(row_idx).catch((_) => {}); table.scrollToRow(row_idx, "top", false); } } else { @@ -169,7 +172,7 @@ function waitForTable(uid, timeout = 1000) { if (window["table_" + uid]) { resolve(); } else { - // Otherwise, check every 100ms until the table is ready or the timeout is reached + // Otherwise, poll until the table is ready or timeout var startTime = Date.now(); var checkTableInterval = setInterval(function () { if (window["table_" + uid]) { @@ -199,3 +202,25 @@ function highlightSingleRow(uid, index = null, jumpToRow = false) { console.log(error); }); } + +function updateTableCell(uid, index, field, value) { + // Update the value of a cell in the table with the given uid + waitForTable(uid) + .then(() => { + const table = window["table_" + uid]; + if (!table) { + throw new Error(`Table with uid ${uid} not found.`); + } + + const row = table.getRow(index); + if (!row) { + throw new Error(`Row with index ${index} not found.`); + } + + // Update the cell value + row.update({ [field]: value }); + }) + .catch((error) => { + console.error(error); + }); +} diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 1b2b36bf2f6..efef97e6335 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -39,6 +39,7 @@ from ...types.uid import UID from ...util import options from ...util.colors import SURFACE +from ...util.notebook_ui.components.sync import Label from ...util.notebook_ui.components.sync import SyncTableObject from ...util.notebook_ui.icons import Icon from ...util.notebook_ui.styles import FONT_CSS @@ -705,6 +706,21 @@ def root_id(self) -> UID: def root_type(self) -> type: return self.root_diff.obj_type + def decision_badge(self) -> str: + if self.decision is None: + return "" + if self.decision == SyncDecision.IGNORE: + decision_str = "IGNORED" + badge_color = "label-red" + if self.decision == SyncDecision.SKIP: + decision_str = "SKIPPED" + badge_color = "label-gray" + else: + decision_str = "SYNCED" + badge_color = "label-green" + + return Label(value=decision_str, label_class=badge_color).to_html() + @property def is_ignored(self) -> bool: return self.decision == SyncDecision.IGNORE @@ -847,9 +863,10 @@ def _coll_repr_(self) -> dict[str, Any]: high_html = SyncTableObject(object=self.root_diff.high_obj).to_html() return { - "Merge status": self.status_badge(), + "Diff status": self.status_badge(), "Public Sync State": low_html, "Private sync state": high_html, + "Decision": self.decision_badge(), } @property diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 41aa071be18..60856d752f2 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -2,6 +2,7 @@ from collections.abc import Callable from enum import Enum from enum import auto +from functools import partial import html import secrets from typing import Any @@ -27,6 +28,7 @@ from ...util.notebook_ui.components.sync import TypeLabel from ...util.notebook_ui.components.tabulator_template import build_tabulator_table from 
...util.notebook_ui.components.tabulator_template import highlight_single_row +from ...util.notebook_ui.components.tabulator_template import update_table_cell from ...util.notebook_ui.styles import CSS_CODE from ..action.action_object import ActionObject from ..api.api import TwinAPIEndpoint @@ -411,11 +413,14 @@ def _on_share_private_data_change(self, change: Any) -> None: class ResolveWidget: - def __init__(self, obj_diff_batch: ObjectDiffBatch): + def __init__( + self, obj_diff_batch: ObjectDiffBatch, on_sync_callback: Callable | None = None + ): self.obj_diff_batch: ObjectDiffBatch = obj_diff_batch self.id2widget: dict[ UID, CollapsableObjectDiffWidget | MainObjectDiffWidget ] = {} + self.on_sync_callback = on_sync_callback self.main_widget = self.build() self.result_widget = VBox() # Placeholder for SyftSuccess / SyftError self.widget = VBox( @@ -468,6 +473,8 @@ def click_sync(self, *args: list, **kwargs: dict) -> SyftSuccess | SyftError: ) self.set_widget_result_state(res) + if self.on_sync_callback: + self.on_sync_callback() return res @property @@ -635,21 +642,21 @@ def update_buttons(self) -> None: self.next_button.disabled = self.current_index == len(self.data) - 1 self.last_button.disabled = self.current_index == len(self.data) - 1 - def go_to_first(self, b: Button) -> None: + def go_to_first(self, b: Button | None) -> None: self.current_index = 0 self.update_index_callback() - def go_to_previous(self, b: Button) -> None: + def go_to_previous(self, b: Button | None) -> None: if self.current_index > 0: self.current_index -= 1 self.update_index_callback() - def go_to_next(self, b: Button) -> None: + def go_to_next(self, b: Button | None) -> None: if self.current_index < len(self.data) - 1: self.current_index += 1 self.update_index_callback() - def go_to_last(self, b: Button) -> None: + def go_to_last(self, b: Button | None) -> None: self.current_index = len(self.data) - 1 self.update_index_callback() @@ -705,8 +712,12 @@ class PaginatedResolveWidget: def __init__(self, batches: list[ObjectDiffBatch]): self.batches = batches - self.resolve_widgets = [ - ResolveWidget(obj_diff_batch=batch) for batch in self.batches + self.resolve_widgets: list[ResolveWidget] = [ + ResolveWidget( + obj_diff_batch=batch, + on_sync_callback=partial(self.on_click_sync, i), + ) + for i, batch in enumerate(self.batches) ] self.table_uid = secrets.token_hex(4) @@ -717,6 +728,7 @@ def __init__(self, batches: list[ObjectDiffBatch]): uid=self.table_uid, max_height=500, pagination=False, + header_sort=False, ) self.paginated_widget = PaginatedWidget( @@ -733,6 +745,21 @@ def __init__(self, batches: list[ObjectDiffBatch]): self.widget = self.build() + def on_click_sync(self, index: int) -> None: + self.update_table_sync_decision(index) + if self.batches[index].decision is not None: + self.paginated_widget.pagination_control.go_to_next(None) + + def update_table_sync_decision(self, index: int) -> None: + new_decision = self.batches[index].decision_badge() + with self.table_output: + update_table_cell( + uid=self.table_uid, + index=index, + field="Decision", + value=new_decision, + ) + def __getitem__(self, index: int) -> ResolveWidget: return self.resolve_widgets[index] diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index e79016ab7a9..ee0576cc206 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ 
b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -23,6 +23,7 @@ def create_tabulator_columns( column_names: list[str], column_widths: dict | None = None, + header_sort: bool = True, ) -> tuple[list[dict], dict | None]: """Returns tuple of (columns, row_header) for tabulator table""" if column_widths is None: @@ -33,10 +34,10 @@ def create_tabulator_columns( if TABLE_INDEX_KEY in column_names: row_header = { "field": TABLE_INDEX_KEY, - "headerSort": True, "frozen": True, "widthGrow": 0.3, "minWidth": 60, + "headerSort": header_sort, } for colname in column_names: @@ -48,6 +49,7 @@ def create_tabulator_columns( "resizable": True, "minWidth": 60, "maxInitialWidth": 500, + "headerSort": header_sort, } if colname in column_widths: column["widthGrow"] = column_widths[colname] @@ -97,6 +99,7 @@ def build_tabulator_table( uid: str | None = None, max_height: int | None = None, pagination: bool = True, + header_sort: bool = True, ) -> str | None: try: table_data, table_metadata = prepare_table_data(obj) @@ -119,7 +122,9 @@ def build_tabulator_table( icon = Icon.TABLE.svg uid = uid if uid is not None else secrets.token_hex(4) - column_data, row_header = create_tabulator_columns(table_metadata["columns"]) + column_data, row_header = create_tabulator_columns( + table_metadata["columns"], header_sort=header_sort + ) table_data = format_table_data(table_data) table_html = table_template.render( uid=uid, @@ -135,6 +140,7 @@ def build_tabulator_table( tabulator_css=tabulator_css, max_height=json.dumps(max_height), pagination=json.dumps(pagination), + header_sort=json.dumps(header_sort), ) return table_html @@ -157,3 +163,12 @@ def highlight_single_row( ) -> None: js_code = f"" display(HTML(js_code)) + + +def update_table_cell(uid: str, index: int, field: str, value: str) -> None: + js_code = f""" + + """ + display(HTML(js_code)) From 17cc6d384af599394bc2bbda0aacacb0e79b7f66 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 27 May 2024 15:08:08 +0200 Subject: [PATCH 059/309] fix old resolve method + typing --- packages/syft/src/syft/client/syncing.py | 12 +++++++++--- packages/syft/src/syft/service/sync/diff_state.py | 8 ++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 7185cb5316e..52bf80a5486 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -86,9 +86,15 @@ def get_user_input_for_resolve() -> SyncDecision: print(f"Please choose between {options_str}") -def resolve(obj_diff_batch: ObjectDiffBatch) -> ResolveWidget: - widget = ResolveWidget(obj_diff_batch) - return widget +def resolve(obj: ObjectDiffBatch | NodeDiff) -> ResolveWidget: + if isinstance(obj, NodeDiff): + return obj.resolve() + elif isinstance(obj, ObjectDiffBatch): + return ResolveWidget(obj) + else: + raise ValueError( + f"Invalid type: could not resolve object with type {type(obj).__qualname__}" + ) @deprecated(reason="resolve_single has been renamed to resolve", return_syfterror=True) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index efef97e6335..e91b405136a 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -9,9 +9,9 @@ from typing import Any from typing import ClassVar from typing import Literal +from typing import TYPE_CHECKING # third party -import ipywidgets from loguru import logger import pandas as pd 
from pydantic import model_validator @@ -61,6 +61,10 @@ from ..user.user import UserView from .sync_state import SyncState +if TYPE_CHECKING: + # relative + from .resolve_widget import PaginatedResolveWidget + sketchy_tab = "‎ " * 4 @@ -1136,7 +1140,7 @@ class NodeDiff(SyftObject): include_ignored: bool = False - def resolve(self) -> ipywidgets.Widget: + def resolve(self) -> "PaginatedResolveWidget": # relative from .resolve_widget import PaginatedResolveWidget From 810bd655210d73af61f3abd6d319b7c3540e3e71 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 27 May 2024 16:36:16 +0200 Subject: [PATCH 060/309] add comments --- packages/syft/src/syft/client/syncing.py | 29 +++++-------------- .../syft/src/syft/service/sync/diff_state.py | 7 +++++ .../src/syft/service/sync/resolve_widget.py | 7 ++++- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 52bf80a5486..428117634ef 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -8,6 +8,7 @@ from ..service.sync.diff_state import NodeDiff from ..service.sync.diff_state import ObjectDiffBatch from ..service.sync.diff_state import SyncInstruction +from ..service.sync.resolve_widget import PaginatedResolveWidget from ..service.sync.resolve_widget import ResolveWidget from ..service.sync.sync_state import SyncState from ..types.uid import UID @@ -71,34 +72,18 @@ def compare_clients( ) -def get_user_input_for_resolve() -> SyncDecision: - options = [x.value for x in SyncDecision] - options_str = ", ".join(options[:-1]) + f" or {options[-1]}" - print(f"How do you want to sync these objects? choose between {options_str}") - - while True: - decision = input() - decision = decision.lower() - - try: - return SyncDecision(decision) - except ValueError: - print(f"Please choose between {options_str}") - - -def resolve(obj: ObjectDiffBatch | NodeDiff) -> ResolveWidget: - if isinstance(obj, NodeDiff): - return obj.resolve() - elif isinstance(obj, ObjectDiffBatch): - return ResolveWidget(obj) - else: +def resolve(obj: ObjectDiffBatch | NodeDiff) -> ResolveWidget | PaginatedResolveWidget: + if not isinstance(obj, ObjectDiffBatch | NodeDiff): raise ValueError( f"Invalid type: could not resolve object with type {type(obj).__qualname__}" ) + return obj.resolve() @deprecated(reason="resolve_single has been renamed to resolve", return_syfterror=True) -def resolve_single(obj_diff_batch: ObjectDiffBatch) -> ResolveWidget: +def resolve_single( + obj_diff_batch: ObjectDiffBatch, +) -> ResolveWidget | PaginatedResolveWidget: return resolve(obj_diff_batch) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index e91b405136a..f943174d75a 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -64,6 +64,7 @@ if TYPE_CHECKING: # relative from .resolve_widget import PaginatedResolveWidget + from .resolve_widget import ResolveWidget sketchy_tab = "‎ " * 4 @@ -560,6 +561,12 @@ class ObjectDiffBatch(SyftObject): root_diff: ObjectDiff sync_direction: SyncDirection | None + def resolve(self) -> "ResolveWidget": + # relative + from .resolve_widget import ResolveWidget + + return ResolveWidget(self) + def walk_graph( self, deps: dict[UID, list[UID]], diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 60856d752f2..7d683139aae 100644 
--- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -699,8 +699,13 @@ def on_paginate(self, index: int) -> None: if self.on_paginate_callback: self.on_paginate_callback(index) + def spacer(self, height: int) -> widgets.HTML: + return widgets.HTML(f"
") + def build(self) -> widgets.VBox: - return widgets.VBox([self.pagination_control.build(), self.container]) + return widgets.VBox( + [self.pagination_control.build(), self.spacer(8), self.container] + ) class PaginatedResolveWidget: From 307c4b8f6a5a3b0d60d75d82f00c2947594a2994 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 27 May 2024 16:50:06 +0200 Subject: [PATCH 061/309] fix tests --- packages/syft/src/syft/service/sync/resolve_widget.py | 2 +- .../syft/tests/syft/service/sync/sync_resolve_single_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 7d683139aae..261ed28e075 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -719,7 +719,7 @@ def __init__(self, batches: list[ObjectDiffBatch]): self.batches = batches self.resolve_widgets: list[ResolveWidget] = [ ResolveWidget( - obj_diff_batch=batch, + batch, on_sync_callback=partial(self.on_click_sync, i), ) for i, batch in enumerate(self.batches) diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index b3972532521..07585c6de87 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -36,7 +36,7 @@ def compare_and_resolve( diff_state_before = compare_clients(from_client, to_client) for obj_diff_batch in diff_state_before.active_batches: widget = resolve( - obj_diff_batch=obj_diff_batch, + obj_diff_batch, ) if decision_callback: decision = decision_callback(obj_diff_batch) From 5a931f675da29a04c2b85ed0301bcb838da65939 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Thu, 16 May 2024 17:23:29 -0300 Subject: [PATCH 062/309] Use autosplat for settings update. Add docstring to Update service --- .../syft/service/settings/settings_service.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 30e8242ceb9..f93a019e26a 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -61,10 +61,36 @@ def set( else: return SyftError(message=result.err()) - @service_method(path="settings.update", name="update") + @service_method(path="settings.update", name="update", autosplat=["settings"]) def update( self, context: AuthedServiceContext, settings: NodeSettingsUpdate ) -> Result[SyftSuccess, SyftError]: + """ + Update the Node Settings using the provided values. + + Args: + name: Optional[str] + Node name + organization: Optional[str] + Organization name + description: Optional[str] + Node description + on_board: Optional[bool] + Show onboarding panel when a user logs in for the first time + signup_enabled: Optional[bool] + Enable/Disable registration + admin_email: Optional[str] + Administrator email + association_request_auto_approval: Optional[bool] + + Returns: + Result[SyftSuccess, SyftError]: A result indicating the success or failure of the update operation. + + Example: + >>> node_client.update(name='foo', organization='bar', description='baz', signup_enabled=True) + SyftSuccess: Settings updated successfully. 
+ """ + result = self.stash.get_all(context.credentials) if result.is_ok(): current_settings = result.ok() From b1d96e43ab6d38f8920a6e426b9099a48676de3a Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 20 May 2024 12:27:10 -0300 Subject: [PATCH 063/309] Fix autosplat when it's called from inner services --- packages/syft/src/syft/service/service.py | 3 ++- packages/syft/src/syft/service/settings/settings_service.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index c28fc1157d1..e7ccd5c25ad 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -340,6 +340,7 @@ def wrapper(func: Any) -> Callable: _path = class_name + "." + func_name signature = inspect.signature(func) signature = signature_remove_self(signature) + signature_with_context = deepcopy(signature) signature = signature_remove_context(signature) input_signature = deepcopy(signature) @@ -353,7 +354,7 @@ def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable: ) if autosplat is not None and len(autosplat) > 0: args, kwargs = reconstruct_args_kwargs( - signature=input_signature, + signature=signature_with_context, autosplat=autosplat, args=args, kwargs=kwargs, diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index f93a019e26a..80ac0cdcc55 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -165,7 +165,7 @@ def allow_guest_signup( result = method(context=context, settings=settings) - if result.is_err(): + if isinstance(result, SyftError): return SyftError(message=f"Failed to update settings: {result.err()}") message = "enabled" if enable else "disabled" From 0245797694e60a3d8ccdc59bfc4bd1d9a53ee808 Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 20 May 2024 14:32:22 -0300 Subject: [PATCH 064/309] Fix unit/notebook tests --- packages/syft/src/syft/service/service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index e7ccd5c25ad..9576ac2b92e 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -288,6 +288,7 @@ def reconstruct_args_kwargs( final_kwargs[param_key] = param.default else: raise Exception(f"Missing {param_key} not in kwargs.") + final_kwargs['context'] = kwargs['context'] if 'context' in kwargs else None return (args, final_kwargs) @@ -340,7 +341,6 @@ def wrapper(func: Any) -> Callable: _path = class_name + "." 
+ func_name signature = inspect.signature(func) signature = signature_remove_self(signature) - signature_with_context = deepcopy(signature) signature = signature_remove_context(signature) input_signature = deepcopy(signature) @@ -354,7 +354,7 @@ def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable: ) if autosplat is not None and len(autosplat) > 0: args, kwargs = reconstruct_args_kwargs( - signature=signature_with_context, + signature=input_signature, autosplat=autosplat, args=args, kwargs=kwargs, From ba11b4dc769f880bd2d456ae70c0f7ebda0b8c6d Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 20 May 2024 15:27:42 -0300 Subject: [PATCH 065/309] Update settings test --- .../syft/tests/syft/settings/settings_service_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index 13592539a7f..56bf414f373 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -151,7 +151,7 @@ def mock_stash_get_all(root_verify_key) -> Ok: monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all) # update the settings in the settings stash using settings_service - response = settings_service.update(authed_context, update_settings) + response = settings_service.update(context=authed_context, settings=update_settings) # not_updated_settings = response.ok()[1] @@ -174,7 +174,7 @@ def mock_stash_get_all_error(credentials) -> Err: return Err(mock_error_message) monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all_error) - response = settings_service.update(authed_context, update_settings) + response = settings_service.update(context=authed_context, settings=update_settings) assert isinstance(response, SyftError) assert response.message == mock_error_message @@ -185,7 +185,7 @@ def test_settingsservice_update_stash_empty( update_settings: NodeSettingsUpdate, authed_context: AuthedServiceContext, ) -> None: - response = settings_service.update(authed_context, update_settings) + response = settings_service.update(context=authed_context, settings=update_settings) assert isinstance(response, SyftError) assert response.message == "No settings found" @@ -214,7 +214,7 @@ def mock_stash_update_error(credentials, update_settings: NodeSettings) -> Err: monkeypatch.setattr(settings_service.stash, "update", mock_stash_update_error) - response = settings_service.update(authed_context, update_settings) + response = settings_service.update(context=authed_context, settings=update_settings) assert isinstance(response, SyftError) assert response.message == mock_update_error_message From 552a8ed2e85f7e37f5efca1b738f92005b593256 Mon Sep 17 00:00:00 2001 From: alexnicita Date: Fri, 24 May 2024 15:53:30 -0400 Subject: [PATCH 066/309] rename new_project.send() to new_project.start() --- notebooks/api/0.8/01-submit-code.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/api/0.8/01-submit-code.ipynb b/notebooks/api/0.8/01-submit-code.ipynb index 761d1a96e7a..8448680171b 100644 --- a/notebooks/api/0.8/01-submit-code.ipynb +++ b/notebooks/api/0.8/01-submit-code.ipynb @@ -482,7 +482,7 @@ "outputs": [], "source": [ "# Once we start the project, it will submit the project along with the code request to the Domain Server\n", - "project = new_project.send()\n", + "project = new_project.start()\n", "project" ] }, From 
46a0a765b4817d229c7242d33062a2cce0341626 Mon Sep 17 00:00:00 2001 From: S Rasswanth <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 27 May 2024 16:36:07 +0530 Subject: [PATCH 067/309] Revert "rename new_project.send() to new_project.start()" --- notebooks/api/0.8/01-submit-code.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/api/0.8/01-submit-code.ipynb b/notebooks/api/0.8/01-submit-code.ipynb index 8448680171b..761d1a96e7a 100644 --- a/notebooks/api/0.8/01-submit-code.ipynb +++ b/notebooks/api/0.8/01-submit-code.ipynb @@ -482,7 +482,7 @@ "outputs": [], "source": [ "# Once we start the project, it will submit the project along with the code request to the Domain Server\n", - "project = new_project.start()\n", + "project = new_project.send()\n", "project" ] }, From 7a8368a64928ce6c7913cd903d81e834935e8eb5 Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 27 May 2024 09:51:40 -0300 Subject: [PATCH 068/309] Fix lint --- packages/syft/src/syft/service/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index 9576ac2b92e..77a4679e1a3 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -288,7 +288,7 @@ def reconstruct_args_kwargs( final_kwargs[param_key] = param.default else: raise Exception(f"Missing {param_key} not in kwargs.") - final_kwargs['context'] = kwargs['context'] if 'context' in kwargs else None + final_kwargs["context"] = kwargs["context"] if "context" in kwargs else None return (args, final_kwargs) From 30100454ad1a91c77a99830adbc535ea8f612e35 Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 27 May 2024 09:58:57 -0300 Subject: [PATCH 069/309] Fix unit tests --- packages/syft/tests/syft/users/user_service_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 54a32bef836..7c0cc32562a 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -229,7 +229,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err: expected_output = [guest_user.to(UserView)] # Search via id - response = user_service.search(authed_context, id=guest_user.id) + response = user_service.search(context=authed_context, id=guest_user.id) assert isinstance(response, list) assert all( r.to_dict() == expected.to_dict() @@ -238,7 +238,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err: # assert response.to_dict() == expected_output.to_dict() # Search via email - response = user_service.search(authed_context, email=guest_user.email) + response = user_service.search(context=authed_context, email=guest_user.email) assert isinstance(response, list) assert all( r.to_dict() == expected.to_dict() @@ -246,7 +246,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err: ) # Search via name - response = user_service.search(authed_context, name=guest_user.name) + response = user_service.search(context=authed_context, name=guest_user.name) assert isinstance(response, list) assert all( r.to_dict() == expected.to_dict() @@ -255,7 +255,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err: # Search via verify_key response = user_service.search( - authed_context, + context=authed_context, verify_key=guest_user.verify_key, ) assert 
isinstance(response, list) @@ -266,7 +266,7 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Ok | Err: # Search via multiple kwargs response = user_service.search( - authed_context, name=guest_user.name, email=guest_user.email + context=authed_context, name=guest_user.name, email=guest_user.email ) assert isinstance(response, list) assert all( @@ -279,7 +279,7 @@ def test_userservice_search_with_invalid_kwargs( user_service: UserService, authed_context: AuthedServiceContext ) -> None: # Search with invalid kwargs - response = user_service.search(authed_context, role=ServiceRole.GUEST) + response = user_service.search(context=authed_context, role=ServiceRole.GUEST) assert isinstance(response, SyftError) assert "Invalid Search parameters" in response.message From fc2b703327a2793e3c5babcb411100678beff1ee Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Mon, 27 May 2024 10:25:50 -0300 Subject: [PATCH 070/309] Fix lint --- packages/syft/src/syft/service/service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index 77a4679e1a3..55f2c1f4b5d 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -288,7 +288,10 @@ def reconstruct_args_kwargs( final_kwargs[param_key] = param.default else: raise Exception(f"Missing {param_key} not in kwargs.") - final_kwargs["context"] = kwargs["context"] if "context" in kwargs else None + + if "context": + final_kwargs["context"] = kwargs["context"] + return (args, final_kwargs) From 3e9286311d246fb8c53d5b08e5696f0b0c5a9e8c Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 29 May 2024 11:01:28 +1000 Subject: [PATCH 071/309] Fixed linting / tests adding error handling for None types --- packages/grid/backend/grid/core/config.py | 3 +++ .../backend/backend-statefulset.yaml | 2 +- .../service/network/association_request.py | 2 +- .../syft/service/network/network_service.py | 19 +++++++++++++++++-- .../src/syft/service/network/node_peer.py | 2 +- .../syft/src/syft/service/network/rathole.py | 4 ++-- .../syft/service/network/rathole_service.py | 2 ++ tox.ini | 2 ++ 8 files changed, 29 insertions(+), 7 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 8c55b8cd3f7..33d65719fe8 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,6 +155,9 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) + REVERSE_TUNNEL_RATHOLE_ENABLED: bool = str_to_bool( + os.getenv("REVERSE_TUNNEL_RATHOLE_ENABLED", "false") + ) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index f9b2dd42353..1fed68448dc 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -90,7 +90,7 @@ spec: {{- if .Values.rathole.enabled }} - name: RATHOLE_PORT value: {{ .Values.rathole.port | quote }} - - name: RATHOLE_ENABLED + - name: REVERSE_TUNNEL_RATHOLE_ENABLED value: "true" {{- end }} # MongoDB diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 2590b3b42fe..0d1363b54e6 
100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -46,7 +46,7 @@ def _run( rathole_route = self.remote_peer.get_rathole_route() - if rathole_route.rathole_token is None: + if rathole_route and rathole_route.rathole_token is None: try: remote_client: SyftClient = self.remote_peer.client_with_context( context=service_ctx diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index de560747ac6..f8e9b29658d 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -31,7 +31,9 @@ from ...types.uid import UID from ...util.telemetry import instrument from ...util.util import generate_token +from ...util.util import get_env from ...util.util import prompt_warning_message +from ...util.util import str_to_bool from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey from ..metadata.node_metadata import NodeMetadataV3 @@ -61,6 +63,12 @@ NodeTypePartitionKey = PartitionKey(key="node_type", type_=NodeType) OrderByNamePartitionKey = PartitionKey(key="name", type_=str) +REVERSE_TUNNEL_RATHOLE_ENABLED = "REVERSE_TUNNEL_RATHOLE_ENABLED" + + +def get_rathole_enabled() -> bool: + return str_to_bool(get_env(REVERSE_TUNNEL_RATHOLE_ENABLED, "false")) + @serializable() class NodePeerAssociationStatus(Enum): @@ -151,7 +159,8 @@ class NetworkService(AbstractService): def __init__(self, store: DocumentStore) -> None: self.store = store self.stash = NetworkStash(store=store) - self.rathole_service = RatholeService() + if get_rathole_enabled(): + self.rathole_service = RatholeService() # TODO: Check with MADHAVA, can we even allow guest user to introduce routes to # domain nodes? @@ -281,8 +290,14 @@ def exchange_credentials_with( if result.is_err(): return SyftError(message="Failed to update route information.") - if reverse_tunnel: + if reverse_tunnel and get_rathole_enabled(): rathole_route = self_node_peer.get_rathole_route() + if not rathole_route: + raise Exception( + "Failed to exchange credentials. 
" + + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" + ) + remote_url = GridURL( host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port ) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index e6ac045e9f1..4fe2b5f38b2 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -269,7 +269,7 @@ def pick_highest_priority_route(self) -> NodeRoute: highest_priority_route = route return highest_priority_route - def get_rathole_route(self) -> NodeRoute | None: + def get_rathole_route(self) -> HTTPNodeRoute | None: for route in self.node_routes: if hasattr(route, "rathole_token") and route.rathole_token: return route diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py index e102134ada6..d311bfa9c2b 100644 --- a/packages/syft/src/syft/service/network/rathole.py +++ b/packages/syft/src/syft/service/network/rathole.py @@ -1,5 +1,5 @@ -# stdlib -from typing import Self +# third party +from typing_extensions import Self # relative from ...serde.serializable import serializable diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index ad5e783c7dd..2bcde4fd2f4 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -36,6 +36,8 @@ def add_host_to_server(self, peer: NodePeer) -> None: """ rathole_route = peer.get_rathole_route() + if not rathole_route: + raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") random_port = self.get_random_port() diff --git a/tox.ini b/tox.ini index 35cb08c9fb7..c282c971edc 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,7 @@ [tox] envlist = + dev.k8s.launch.domain + dev.k8s.launch.gateway dev.k8s.registry dev.k8s.start dev.k8s.deploy From 74ad27711ca2a4904feceff77fdcce63fa323f2b Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 29 May 2024 16:52:07 +1000 Subject: [PATCH 072/309] Re-add changed node route --- packages/syft/src/syft/service/network/network_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index f8e9b29658d..4532201aaae 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -189,6 +189,7 @@ def exchange_credentials_with( _rathole_route = self_node_peer.node_routes[-1] _rathole_route.rathole_token = generate_token() _rathole_route.host_or_ip = f"{self_node_peer.name}.syft.local" + self_node_peer.node_routes[-1] = _rathole_route if isinstance(self_node_peer, SyftError): return self_node_peer From 47d5852c03b36436d05fabe3cda2d9d590bfd0e9 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Thu, 30 May 2024 16:30:45 +1000 Subject: [PATCH 073/309] Change rathole port in gateway dev mode --- packages/grid/devspace.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 616eaa3ea4c..ed6dd9bbf11 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -164,6 +164,11 @@ profiles: path: dev.backend.containers.backend-container.ssh.localPort value: 3481 + # Mongo + - op: replace + path: dev.rathole.ports[0].port + value: 2334:2333 + - name: gcp patches: - op: replace From 
85467f1ea00f4b78c60d536e96ad515f77b04535 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 13:30:51 +0200 Subject: [PATCH 074/309] add proper object registry with canonical name and version, add old types, refactor serialization, add client facing API's for migration --- packages/syft/src/syft/node/node.py | 8 +- .../syft/src/syft/service/job/job_stash.py | 13 +- packages/syft/src/syft/service/log/log.py | 16 +- .../src/syft/service/migration/__init__.py | 0 .../service/migration/migration_service.py | 359 ++++++++++++++++++ .../migration/object_migration_state.py | 82 ++++ .../syft/service/notifier/notifier_service.py | 6 +- .../src/syft/service/output/output_service.py | 13 +- .../syft/src/syft/store/kv_document_store.py | 16 +- packages/syft/src/syft/types/syft_object.py | 11 +- .../src/syft/types/syft_object_registry.py | 67 ++-- 11 files changed, 537 insertions(+), 54 deletions(-) create mode 100644 packages/syft/src/syft/service/migration/__init__.py create mode 100644 packages/syft/src/syft/service/migration/migration_service.py create mode 100644 packages/syft/src/syft/service/migration/object_migration_state.py diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 14632f914a0..2647790259d 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -405,7 +405,6 @@ def __init__( self.create_initial_settings(admin_email=root_email) - self.init_blob_storage(config=blob_storage_config) # Migrate data before any operation on db @@ -425,8 +424,6 @@ def __init__( if background_tasks: self.run_peer_health_checks(context=context) - - NodeRegistry.set_node_for(self.id, self) @property @@ -1637,11 +1634,12 @@ def create_admin_new( return result.ok() else: raise Exception(f"Could not create user: {result}") - except Exception as e: + except Exception: # import ipdb # ipdb.set_trace() + # stdlib import traceback - + print("Unable to create new admin", traceback.format_exc()) return None diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index d86f7363904..b91851ac27b 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -1,10 +1,11 @@ # stdlib +from collections.abc import Callable from datetime import datetime from datetime import timedelta from enum import Enum import random from string import Template -from typing import Any, Callable +from typing import Any # third party from pydantic import Field @@ -12,8 +13,6 @@ from result import Err from result import Ok from result import Result -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default from typing_extensions import Self # relative @@ -30,10 +29,14 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_2, SYFT_OBJECT_VERSION_4 +from ...types.syft_migration import migrate +from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SYFT_OBJECT_VERSION_5 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID from ...util import options from ...util.colors import SURFACE @@ -752,10 +755,12 @@ def 
get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # def upgrade_job() -> list[Callable]: return [make_set_default("requested_by", UID())] + @migrate(JobV4, Job) def downgrade_job() -> list[Callable]: return [drop("requested_by")] + @serializable() class JobInfo(SyftObject): __canonical_name__ = "JobInfo" diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index 48431952d90..a08beac55f0 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,15 +1,17 @@ # stdlib -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from typing import ClassVar -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default - # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext -from ...types.syft_object import SYFT_OBJECT_VERSION_3, SYFT_OBJECT_VERSION_4 +from ...types.syft_migration import migrate +from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID @@ -28,6 +30,7 @@ class SyftLogV3(SyncableSyftObject): stdout: str = "" stderr: str = "" + @serializable() class SyftLog(SyncableSyftObject): __canonical_name__ = "SyftLog" @@ -59,12 +62,13 @@ def get_sync_dependencies( ) -> list[UID]: # type: ignore return [self.job_id] + @migrate(SyftLogV3, SyftLog) def upgrade_syftlog() -> list[Callable]: # TODO: FIX return [make_set_default("job_id", UID())] + @migrate(SyftLog, SyftLogV3) def downgrade_syftlog() -> list[Callable]: return [drop("job_id")] - diff --git a/packages/syft/src/syft/service/migration/__init__.py b/packages/syft/src/syft/service/migration/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py new file mode 100644 index 00000000000..c6708ebc850 --- /dev/null +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -0,0 +1,359 @@ +# stdlib + +# stdlib + +# stdlib + +# third party +from result import Err +from result import Ok +from result import Result + +# relative +from ...serde.serializable import serializable +from ...store.document_store import DocumentStore +from ...types.syft_object import SyftObject +from ..action.action_object import Action +from ..action.action_object import ActionObject +from ..context import AuthedServiceContext +from ..response import SyftError +from ..response import SyftSuccess +from ..service import AbstractService +from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL +from .object_migration_state import SyftMigrationStateStash +from .object_migration_state import SyftObjectMigrationState + + +@serializable() +class MigrationService(AbstractService): + store: DocumentStore + stash: SyftMigrationStateStash + + def __init__(self, store: DocumentStore) -> None: + self.store = store + self.stash = SyftMigrationStateStash(store=store) + + @service_method(path="migration", name="get_version") + def get_version( + self, context: AuthedServiceContext, canonical_name: str + ) -> int | SyftError: + """Search for the metadata for an object.""" + + result = self.stash.get_by_name( + 
canonical_name=canonical_name, credentials=context.credentials + ) + + if result.is_err(): + return SyftError(message=f"{result.err()}") + + migration_state = result.ok() + + if migration_state is None: + return SyftError( + message=f"No migration state exists for canonical name: {canonical_name}" + ) + + return migration_state.current_version + + @service_method(path="migration", name="get_state") + def get_state( + self, context: AuthedServiceContext, canonical_name: str + ) -> bool | SyftError: + result = self.stash.get_by_name( + canonical_name=canonical_name, credentials=context.credentials + ) + + if result.is_err(): + return SyftError(message=f"{result.err()}") + + return result.ok() + + @service_method(path="migration", name="register_migration_state") + def register_migration_state( + self, + context: AuthedServiceContext, + current_version: int, + canonical_name: str, + ) -> SyftObjectMigrationState | SyftError: + obj = SyftObjectMigrationState( + current_version=current_version, canonical_name=canonical_name + ) + result = self.stash.set(migration_state=obj, credentials=context.credentials) + + if result.is_err(): + return SyftError(message=f"{result.err()}") + + return result.ok() + + def _find_klasses_pending_for_migration( + self, context: AuthedServiceContext, object_types: list[type[SyftObject]] + ) -> list[SyftObject]: + klasses_to_be_migrated = [] + + for object_type in object_types: + canonical_name = object_type.__canonical_name__ + object_version = object_type.__version__ + + migration_state = self.get_state(context, canonical_name) + if isinstance(migration_state, SyftError): + raise Exception( + f"Failed to get migration state for {canonical_name}. Error: {migration_state}" + ) + if ( + migration_state is not None + and migration_state.current_version != migration_state.latest_version + ): + klasses_to_be_migrated.append(object_type) + else: + self.register_migration_state( + context, + current_version=object_version, + canonical_name=canonical_name, + ) + + return klasses_to_be_migrated + + @service_method( + path="migration.get_migration_objects", + name="get_migration_objects", + roles=ADMIN_ROLE_LEVEL, + ) + def get_migration_objects( + self, + context: AuthedServiceContext, + document_store_object_types: list[type[SyftObject]] | None = None, + ) -> dict | SyftError: + res = self._get_migration_objects(context, document_store_object_types) + if res.is_err(): + return SyftError(message=res.value) + else: + return res.ok() + + def _get_migration_objects( + self, + context: AuthedServiceContext, + document_store_object_types: list[type[SyftObject]] | None = None, + ) -> Result[dict, str]: + if document_store_object_types is None: + document_store_object_types = [ + partition.settings.object_type + for partition in self.store.partitions.values() + ] + + klasses_to_migrate = self._find_klasses_pending_for_migration( + context=context, object_types=document_store_object_types + ) + + if klasses_to_migrate: + print( + f"Classes in Document Store that need migration: {klasses_to_migrate}" + ) + + result = {} + + for klass in klasses_to_migrate: + canonical_name = klass.__canonical_name__ + object_partition = self.store.partitions.get(canonical_name) + if object_partition is None: + continue + objects_result = object_partition.all( + context.credentials, has_permission=True + ) + if objects_result.is_err(): + return objects_result + objects = objects_result.ok() + result[klass] = objects + return Ok(result) + + @service_method( + path="migration.update_migrated_objects", 
+ name="update_migrated_objects", + roles=ADMIN_ROLE_LEVEL, + ) + def update_migrated_objects( + self, context: AuthedServiceContext, migrated_objects: list[SyftObject] + ) -> SyftSuccess | SyftError: + res = self._update_migrated_objects(context, migrated_objects) + if res.is_err(): + return SyftError(message=res.value) + else: + return SyftSuccess(message=res.ok()) + + def _update_migrated_objects( + self, context: AuthedServiceContext, migrated_objects: list[SyftObject] + ) -> Result[str, str]: + for migrated_object in migrated_objects: + klass = type(migrated_object) + canonical_name = klass.__canonical_name__ + object_partition = self.store.partitions.get(canonical_name) + qk = object_partition.settings.store_key.with_obj(migrated_object.id) + result = object_partition._update( + context.credentials, + qk=qk, + obj=migrated_object, + has_permission=True, + overwrite=True, + allow_missing_keys=True, + ) + + if result.is_err(): + return result.err() + return Ok(value="success") + + @service_method( + path="migration.migrate_data", + name="migrate_data", + roles=ADMIN_ROLE_LEVEL, + ) + def migrate_data( + self, + context: AuthedServiceContext, + document_store_object_types: list[type[SyftObject]] | None = None, + ) -> SyftSuccess | SyftError: + # Track all object type that need migration for document store + + # get all objects, keyed by type (because we might want to have different rules for different types) + # Q: will this be tricky with the protocol???? + # A: For now we will assume that the client will have the same version + + # Then, locally we write stuff that says + # for klass, objects in migration_dict.items(): + # for object in objects: + # if isinstance(object, X): + # do something custom + # else: + # migrated_value = object.migrate_to(klass.__version__, context) + # + # migrated_values = [SyftObject] + # client.migration.write_migrated_values(migrated_values) + + migration_objects_result = self._get_migration_objects( + context, document_store_object_types + ) + if migration_objects_result.is_err(): + return migration_objects_result + migration_objects = migration_objects_result.ok() + + migrated_objects = [] + + for klass, objects in migration_objects.items(): + canonical_name = klass.__canonical_name__ + # Migrate data for objects in document store + print(f"Migrating data for: {canonical_name} table.") + for object in objects: + try: + migrated_value = object.migrate_to(klass.__version__, context) + migrated_objects.append(migrated_value) + except Exception: + # stdlib + import traceback + + print(traceback.format_exc()) + return Err( + f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" + ) + + objects_update_update_result = self._update_migrated_objects( + context, migrated_objects + ) + if objects_update_update_result.is_err(): + return SyftError(message=objects_update_update_result.value) + + # now action objects + migration_actionobjects_result: dict[type[SyftObject], list[SyftObject]] = ( + self._get_migration_actionobjects(context) + ) + + if migration_actionobjects_result.is_err(): + return migration_actionobjects_result + migration_actionobjects = migration_actionobjects_result.ok() + + migrated_actionobjects = [] + for klass, action_objects in migration_actionobjects.items(): + # these are Actions, ActionObjects, and possibly others + for object in action_objects: + try: + migrated_actionobject = object.migrate_to( + klass.__version__, context + ) + migrated_actionobjects.append(migrated_actionobject) + except Exception: + # stdlib + 
import traceback + + print(traceback.format_exc()) + return Err( + f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" + ) + + actionobjects_update_update_result = self._update_migrated_actionobjects( + context, migrated_actionobjects + ) + if actionobjects_update_update_result.is_err(): + return SyftError(message=actionobjects_update_update_result.err()) + + return SyftSuccess(message="Data upgraded to the latest version") + + @service_method( + path="migration.get_migration_actionobjects", + name="get_migration_actionobjects", + roles=ADMIN_ROLE_LEVEL, + ) + def get_migration_actionobjects(self, context: AuthedServiceContext): + res = self._get_migration_actionobjects(context) + if res.is_ok(): + return res.ok() + else: + return SyftError(message=res.value) + + def _get_migration_actionobjects( + self, context: AuthedServiceContext + ) -> Result[dict[type[SyftObject], SyftObject], str]: + # Track all object types from action store + action_object_types = [Action, ActionObject] + action_object_types.extend(ActionObject.__subclasses__()) + + action_object_pending_migration = self._find_klasses_pending_for_migration( + context=context, object_types=action_object_types + ) + result_dict = {x: [] for x in action_object_pending_migration} + action_store = context.node.action_store + action_store_objects_result = action_store._all( + context.credentials, has_permission=True + ) + if action_store_objects_result.is_err(): + return action_store_objects_result + action_store_objects = action_store_objects_result.ok() + + for obj in action_store_objects: + if type(obj) in result_dict: + result_dict[type(obj)].append(obj) + return Ok(result_dict) + + @service_method( + path="migration.update_migrated_actionobjects", + name="update_migrated_actionobjects", + roles=ADMIN_ROLE_LEVEL, + ) + def update_migrated_actionobjects( + self, context: AuthedServiceContext, objects: list[SyftObject] + ) -> SyftSuccess | SyftError: + res = self._update_migrated_actionobjects(context, objects) + if res.is_ok(): + return SyftSuccess(message="succesfully migrated actionobjects") + else: + return SyftError(message=res.value) + + def _update_migrated_actionobjects( + self, context: AuthedServiceContext, objects: list[SyftObject] + ) -> Result[str, str]: + # Track all object types from action store + action_store = context.node.action_store + for obj in objects: + res = action_store.set( + uid=obj.id, credentials=context.credentials, syft_object=obj + ) + if res.is_err(): + return res + return Ok("success") diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py new file mode 100644 index 00000000000..f5b3a043ea1 --- /dev/null +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -0,0 +1,82 @@ +# stdlib + +# third party +from result import Result + +# relative +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionKey +from ...store.document_store import PartitionSettings +from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SyftMigrationRegistry +from ...types.syft_object import SyftObject +from ..action.action_permissions import ActionObjectPermission + + +@serializable() +class SyftObjectMigrationState(SyftObject): + __canonical_name__ = "SyftObjectMigrationState" + 
__version__ = SYFT_OBJECT_VERSION_2 + + __attr_unique__ = ["canonical_name"] + + canonical_name: str + current_version: int + + @property + def latest_version(self) -> int | None: + available_versions = SyftMigrationRegistry.get_versions( + canonical_name=self.canonical_name, + ) + if not available_versions: + return None + + return sorted(available_versions, reverse=True)[0] + + @property + def supported_versions(self) -> list: + return SyftMigrationRegistry.get_versions(self.canonical_name) + + +KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) + + +@serializable() +class SyftMigrationStateStash(BaseStash): + object_type = SyftObjectMigrationState + settings: PartitionSettings = PartitionSettings( + name=SyftObjectMigrationState.__canonical_name__, + object_type=SyftObjectMigrationState, + ) + + def __init__(self, store: DocumentStore) -> None: + super().__init__(store=store) + + def set( + self, + credentials: SyftVerifyKey, + migration_state: SyftObjectMigrationState, + add_permissions: list[ActionObjectPermission] | None = None, + add_storage_permission: bool = True, + ignore_duplicates: bool = False, + ) -> Result[SyftObjectMigrationState, str]: + res = self.check_type(migration_state, self.object_type) + # we dont use and_then logic here as it is hard because of the order of the arguments + if res.is_err(): + return res + return super().set( + credentials=credentials, + obj=res.ok(), + add_permissions=add_permissions, + add_storage_permission=add_storage_permission, + ignore_duplicates=ignore_duplicates, + ) + + def get_by_name( + self, canonical_name: str, credentials: SyftVerifyKey + ) -> Result[SyftObjectMigrationState, str]: + qks = KlassNamePartitionKey.with_obj(canonical_name) + return self.query_one(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 3913c533514..b203a681880 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -2,12 +2,14 @@ # stdlib +# stdlib +import traceback + # third party from pydantic import EmailStr from result import Err from result import Ok from result import Result -import traceback # relative from ...abstract_node import AbstractNode @@ -277,7 +279,7 @@ def init_notifier( notifier_stash.set(node.signing_key.verify_key, notifier) return Ok("Notifier initialized successfully") - except Exception as e: + except Exception: raise Exception(f"Error initializing notifier. \n {traceback.format_exc()}") # This is not a public API. 
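A minimal client-side sketch of the flow the MigrationService above enables, mirroring the 1b-connect-and-migrate-via-api notebook added later in this series. The admin login `client` and the use of a bare Context() are assumptions taken from that notebook, not part of the service code itself:

# syft absolute
import syft as sy
from syft.types.syft_object import Context

# fetch every document-store object whose class still needs migration, keyed by class
migration_dict = client.services.migration.get_migration_objects()

context = Context()
migrated_objects = []
for klass, objects in migration_dict.items():
    for obj in objects:
        # default path: reuse the object's own migrate_to; per-type custom logic can replace this
        migrated_objects.append(obj.migrate_to(klass.__version__, context))

# write the migrated objects back through the service
res = client.services.migration.update_migrated_objects(migrated_objects)
assert isinstance(res, sy.SyftSuccess)

The same pattern is repeated for action objects via get_migration_actionobjects / update_migrated_actionobjects.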
diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index c552a753852..d6cb8a5234b 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -1,13 +1,12 @@ # stdlib -from typing import Callable, ClassVar +from collections.abc import Callable +from typing import ClassVar # third party from pydantic import model_validator from result import Err from result import Ok from result import Result -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default # relative from ...client.api import APIRegistry @@ -20,8 +19,12 @@ from ...store.document_store import QueryKeys from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_1, SYFT_OBJECT_VERSION_2 +from ...types.syft_migration import migrate +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_object import ActionObject @@ -275,11 +278,11 @@ def get_by_output_policy_id( ) - @migrate(ExecutionOutputV1, ExecutionOutput) def upgrade_execution_output() -> list[Callable]: return [make_set_default("job_id", None)] + @migrate(ExecutionOutput, ExecutionOutputV1) def downgrade_execution_output() -> list[Callable]: return [drop("job_id")] diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index bab9e2251e0..7c921f26eb6 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -26,8 +26,9 @@ from ..service.response import SyftSuccess from ..types.syft_object import SyftObject from ..types.uid import UID -from .document_store import BaseStash, PartitionKeys +from .document_store import BaseStash from .document_store import PartitionKey +from .document_store import PartitionKeys from .document_store import QueryKey from .document_store import QueryKeys from .document_store import StorePartition @@ -482,10 +483,10 @@ def _update( return Err(f"Failed to update obj {obj}, you have no permission") except Exception as e: - import ipdb - - ipdb.set_trace() + # third party + # stdlib import traceback + print(traceback.format_exc()) return Err(f"Failed to update obj {obj} with error: {e}") @@ -658,9 +659,9 @@ def _set_data_and_keys( ck_col[pk_value] = store_query_key.value self.unique_keys[pk_key] = ck_col - self.unique_keys[store_query_key.key][ + self.unique_keys[store_query_key.key][store_query_key.value] = ( store_query_key.value - ] = store_query_key.value + ) sqks = searchable_query_keys.all for qk in sqks: @@ -691,6 +692,7 @@ def _migrate_data( try: migrated_value = value.migrate_to(to_klass.__version__, context) except Exception: + # stdlib import traceback print(traceback.format_exc()) @@ -704,7 +706,7 @@ def _migrate_data( obj=migrated_value, has_permission=has_permission, overwrite=True, - allow_missing_keys=True + allow_missing_keys=True, ) if result.is_err(): diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 922ae642da9..c100d36d7fa 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ 
b/packages/syft/src/syft/types/syft_object.py @@ -50,7 +50,6 @@ if TYPE_CHECKING: # relative from ..service.sync.diff_state import AttrDiff - from .syft_object_registry import SyftObjectRegistry IntStr = int | str AbstractSetIntStr = Set[IntStr] @@ -295,9 +294,10 @@ def get_migration_for_version( ] -class RegisteredSyftObject(): +class RegisteredSyftObject: def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) + # relative from .syft_object_registry import SyftObjectRegistry as reg if hasattr(reg, "__canonical_name__") and hasattr(reg, "__version__"): @@ -325,6 +325,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: # only if the cls has not been registered do we want to register it reg.__object_version_registry__[mapping_string] = reg + class SyftObject(SyftBaseObject, RegisteredSyftObject, SyftMigrationRegistry): __canonical_name__ = "SyftObject" __version__ = SYFT_OBJECT_VERSION_2 @@ -477,7 +478,9 @@ def __getitem__(self, key: str | int) -> Any: return self.__dict__.__getitem__(key) # type: ignore def _upgrade_version(self, latest: bool = True) -> "SyftObject": + # relative from .syft_object_registry import SyftObjectRegistry + constructor = SyftObjectRegistry.versioned_class( name=self.__canonical_name__, version=self.__version__ + 1 ) @@ -492,7 +495,9 @@ def _upgrade_version(self, latest: bool = True) -> "SyftObject": # transform from one supported type to another def to(self, projection: type, context: Context | None = None) -> Any: + # relative from .syft_object_registry import SyftObjectRegistry + # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) @@ -718,7 +723,9 @@ def short_uid(uid: UID | None) -> str | None: class StorableObjectType: def to(self, projection: type, context: Context | None = None) -> Any: # 🟡 TODO 19: Could we do an mro style inheritence conversion? Risky? 
+ # relative from .syft_object_registry import SyftObjectRegistry + transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 0b89bf89267..5df03a0c588 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -1,20 +1,18 @@ # stdlib from collections.abc import Callable -from typing import Any from typing import TYPE_CHECKING -from syft.util.util import get_fully_qualified_name +# relative +from ..util.util import get_fully_qualified_name -SYFT_086_PROTOCOL_VERSION = '4' +SYFT_086_PROTOCOL_VERSION = "4" # third party # relative if TYPE_CHECKING: # relative - from syft.types.syft_object import SyftObject - - + from .syft_object import SyftObject class SyftObjectRegistry: @@ -22,7 +20,7 @@ class SyftObjectRegistry: str, type["SyftObject"] | type["SyftObjectRegistry"] ] = {} __object_transform_registry__: dict[str, Callable] = {} - __object_serialization_registry__: dict[tuple[str, str]: tuple] = {} + __object_serialization_registry__: dict[tuple[str, str] : tuple] = {} @classmethod def get_canonical_name(cls, obj, is_type: bool): @@ -39,13 +37,17 @@ def get_canonical_name(cls, obj, is_type: bool): @classmethod def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tuple: - from syft.serde.recursive import TYPE_BANK + # relative + from ..serde.recursive import TYPE_BANK + if canonical_name != "" and canonical_name is not None: return cls.__object_serialization_registry__[canonical_name, version] else: # this is for backward compatibility with 0.8.6 try: - from syft.protocol.data_protocol import get_data_protocol + # relative + from ..protocol.data_protocol import get_data_protocol + serde_props = TYPE_BANK[fqn] klass = serde_props[7] is_syftobject = hasattr(klass, "__canonical_name__") @@ -53,40 +55,52 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tu canonical_name = klass.__canonical_name__ dp = get_data_protocol() try: - version_mutations = dp.protocol_history[SYFT_086_PROTOCOL_VERSION]["object_versions"][canonical_name] + version_mutations = dp.protocol_history[ + SYFT_086_PROTOCOL_VERSION + ]["object_versions"][canonical_name] except Exception: print(f"could not find {canonical_name} in protocol history") raise - version_086 = max([int(k) for k, v in version_mutations.items() if v["action"] == "add"]) + version_086 = max( + [ + int(k) + for k, v in version_mutations.items() + if v["action"] == "add" + ] + ) try: - res = cls.__object_serialization_registry__[canonical_name, version_086] + res = cls.__object_serialization_registry__[ + canonical_name, version_086 + ] except Exception: - print(f"could not find {canonical_name} {version_086} in ObjectRegistry") + print( + f"could not find {canonical_name} {version_086} in ObjectRegistry" + ) raise return res else: # TODO, add refactoring for non syftobject versions canonical_name = fqn.split(".")[-1] version = 1 - return cls.__object_serialization_registry__[canonical_name, version] + return cls.__object_serialization_registry__[ + canonical_name, version + ] except Exception as e: print(e) - import ipdb - ipdb.set_trace() - + # third party @classmethod - def has_serde_class(cls, fqn:str, canonical_name: str, version: str) -> tuple: - from syft.serde.recursive import TYPE_BANK + def has_serde_class(cls, fqn: str, canonical_name: str, version: str) -> 
tuple: + # relative + from ..serde.recursive import TYPE_BANK + if canonical_name != "" and canonical_name is not None: return (canonical_name, version) in cls.__object_serialization_registry__ else: # this is for backward compatibility with 0.8.6 return fqn in TYPE_BANK - - @classmethod def add_transform( cls, @@ -103,7 +117,10 @@ def add_transform( def get_transform( cls, type_from: type["SyftObject"], type_to: type["SyftObject"] ) -> Callable: - from .syft_object import SyftBaseObject, SyftObject + # relative + from .syft_object import SyftBaseObject + from .syft_object import SyftObject + for type_from_mro in type_from.mro(): if issubclass(type_from_mro, SyftObject): klass_from = type_from_mro.__canonical_name__ @@ -123,7 +140,9 @@ def get_transform( f"{klass_from}_{version_from}_x_{klass_to}_{version_to}" ) if mapping_string in SyftObjectRegistry.__object_transform_registry__: - return SyftObjectRegistry.__object_transform_registry__[mapping_string] + return SyftObjectRegistry.__object_transform_registry__[ + mapping_string + ] raise Exception( f"No mapping found for: {type_from} to {type_to} in" f"the registry: {SyftObjectRegistry.__object_transform_registry__.keys()}" @@ -133,7 +152,9 @@ def get_transform( def versioned_class( cls, name: str, version: int ) -> type["SyftObject"] | type["SyftObjectRegistry"] | None: + # relative from .syft_object_registry import SyftObjectRegistry + mapping_string = f"{name}_{version}" if mapping_string not in SyftObjectRegistry.__object_version_registry__: return None From 1c668b5ca1ee03dfeb62c762122a949aa8a7aac7 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 13:55:10 +0200 Subject: [PATCH 075/309] - --- packages/syft/src/syft/service/action/action_graph.py | 3 ++- packages/syft/tests/syft/stores/store_mocks_test.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_graph.py b/packages/syft/src/syft/service/action/action_graph.py index 3a928da9f0c..80c269f94a4 100644 --- a/packages/syft/src/syft/service/action/action_graph.py +++ b/packages/syft/src/syft/service/action/action_graph.py @@ -31,7 +31,7 @@ from ...store.locks import SyftLock from ...store.locks import ThreadingLockingConfig from ...types.datetime import DateTime -from ...types.syft_object import PartialSyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.uid import UID @@ -356,6 +356,7 @@ class ActionGraphStore: @serializable() class InMemoryActionGraphStore(ActionGraphStore): __canonical_name__ = "InMemoryActionGraphStore" + __version__ = SYFT_OBJECT_VERSION_1 def __init__(self, store_config: StoreConfig, reset: bool = False): self.store_config: StoreConfig = store_config diff --git a/packages/syft/tests/syft/stores/store_mocks_test.py b/packages/syft/tests/syft/stores/store_mocks_test.py index 39aa2700829..14226a80873 100644 --- a/packages/syft/tests/syft/stores/store_mocks_test.py +++ b/packages/syft/tests/syft/stores/store_mocks_test.py @@ -47,11 +47,13 @@ def __getitem__(self, key: Any) -> Any: @serializable() class MockObjectType(SyftObject): __canonical_name__ = "mock_type" + __version__ = 1 @serializable() class MockStore(DocumentStore): __canonical_name__ = "MockStore" + __version__ = 1 pass @@ -64,6 +66,7 @@ class MockSyftObject(SyftObject): @serializable() class MockStoreConfig(StoreConfig): __canonical_name__ = "MockStoreConfig" + __version__ = 
1 store_type: type[DocumentStore] = MockStore db_name: str = "testing" backing_store: type[KeyValueBackingStore] = MockKeyValueBackingStore From a88d5966e2c19265e825ef4455d8e6c97035aeea Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 17:50:13 +0200 Subject: [PATCH 076/309] fix canonical name for non syft objects --- packages/syft/src/syft/serde/recursive.py | 4 ++-- .../syft/src/syft/types/syft_object_registry.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index e7e7e6bd4fe..1833b1216c6 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -165,7 +165,7 @@ def recursive_serde_register( version = cls.__version__ else: # TODO: refactor - canonical_name = fqn.split(".")[-1] + canonical_name = fqn version = 1 SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] = serde_attributes @@ -229,7 +229,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild msg = recursive_scheme.new_message() # todo: rewrite and make sure every object has a canonical name and version - canonical_name = SyftObjectRegistry.get_canonical_name(self, is_type=is_type) + canonical_name = SyftObjectRegistry.get_canonical_name(self) if is_type: version = 1 else: diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 5df03a0c588..fed0c46c117 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -23,17 +23,18 @@ class SyftObjectRegistry: __object_serialization_registry__: dict[tuple[str, str] : tuple] = {} @classmethod - def get_canonical_name(cls, obj, is_type: bool): - if is_type: - # TODO: this is different for builtin types, make more generic - return "ModelMetaclass" + def get_canonical_name(cls, obj): + # if is_type: + # # TODO: this is different for builtin types, make more generic + # return "ModelMetaclass" + is_type = isinstance(obj, type) res = getattr(obj, "__canonical_name__", None) - if res is not None: + if res is not None and not is_type: return res else: fqn = get_fully_qualified_name(obj) - return fqn.split(".")[-1] + return fqn @classmethod def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tuple: @@ -81,7 +82,7 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tu return res else: # TODO, add refactoring for non syftobject versions - canonical_name = fqn.split(".")[-1] + canonical_name = fqn version = 1 return cls.__object_serialization_registry__[ canonical_name, version From 7f9e7aaa81b18549750819ad4021cfc2fc5b4b08 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 18:04:40 +0200 Subject: [PATCH 077/309] fix linter --- packages/syft/src/syft/serde/recursive.py | 23 ++++++++----------- .../syft/src/syft/service/job/job_stash.py | 3 +-- .../src/syft/types/syft_object_registry.py | 12 ++++++---- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 1833b1216c6..bc0f4cc1e66 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -13,10 +13,9 @@ # syft absolute import syft as sy -from syft.types.syft_object_registry import SyftObjectRegistry # relative -from ..util.util import 
get_fully_qualified_name +from ..types.syft_object_registry import SyftObjectRegistry from ..util.util import index_syft_by_module_name from .capnp import get_capnp_schema from .util import compatible_with_large_file_writes_capnp @@ -158,7 +157,6 @@ def recursive_serde_register( version, ) - TYPE_BANK[fqn] = serde_attributes if hasattr(cls, "__canonical_name__"): canonical_name = cls.__canonical_name__ @@ -168,8 +166,9 @@ def recursive_serde_register( canonical_name = fqn version = 1 - SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] = serde_attributes - + SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] = ( + serde_attributes + ) if isinstance(alias_fqn, tuple): for alias in alias_fqn: @@ -234,15 +233,13 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild version = 1 else: version = getattr(self, "__version__", 1) - + if not SyftObjectRegistry.has_serde_class("", canonical_name, version): # third party - import ipdb - ipdb.set_trace() raise Exception(f"{canonical_name} version {version} not in SyftObjectRegistry") msg.canonicalName = canonical_name - msg.version=version + msg.version = version ( nonrecursive, @@ -315,11 +312,11 @@ def rs_bytes2object(blob: bytes) -> Any: return rs_proto2object(msg) -def map_fqns_for_backward_compatibility(fqn): +def map_fqns_for_backward_compatibility(fqn: str) -> str: """for backwards compatibility with 0.8.6. Sometimes classes where moved to another file. Which is exactly why we are implementing it differently""" mapping = { - 'syft.service.dataset.dataset.MarkdownDescription': "syft.util.misc_objs.MarkdownDescription" + "syft.service.dataset.dataset.MarkdownDescription": "syft.util.misc_objs.MarkdownDescription" } if fqn in mapping: return mapping[fqn] @@ -327,7 +324,6 @@ def map_fqns_for_backward_compatibility(fqn): return fqn - def rs_proto2object(proto: _DynamicStructBuilder) -> Any: # relative from .deserialize import _deserialize @@ -360,8 +356,7 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: fqn = getattr(proto, "fullyQualifiedName", "") fqn = map_fqns_for_backward_compatibility(fqn) if not SyftObjectRegistry.has_serde_class(fqn, canonical_name, version): - import ipdb - ipdb.set_trace() + # third party raise Exception(f"{canonical_name} version {version} not in SyftObjectRegistry") # TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 12e4bf89067..ed0146f608f 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -30,11 +30,10 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime +from ...types.datetime import format_timedelta from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_4 -from ...types.syft_object import SYFT_OBJECT_VERSION_5 -from ...types.datetime import format_timedelta from ...types.syft_object import SYFT_OBJECT_VERSION_6 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index fed0c46c117..964bf8ca4bf 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ 
b/packages/syft/src/syft/types/syft_object_registry.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from typing import Any from typing import TYPE_CHECKING # relative @@ -20,10 +21,10 @@ class SyftObjectRegistry: str, type["SyftObject"] | type["SyftObjectRegistry"] ] = {} __object_transform_registry__: dict[str, Callable] = {} - __object_serialization_registry__: dict[tuple[str, str] : tuple] = {} + __object_serialization_registry__: dict[tuple[str, int], tuple] = {} @classmethod - def get_canonical_name(cls, obj): + def get_canonical_name(cls, obj: Any) -> str: # if is_type: # # TODO: this is different for builtin types, make more generic # return "ModelMetaclass" @@ -37,7 +38,7 @@ def get_canonical_name(cls, obj): return fqn @classmethod - def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tuple: + def get_serde_properties(cls, fqn: str, canonical_name: str, version: int) -> tuple: # relative from ..serde.recursive import TYPE_BANK @@ -89,10 +90,13 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: str) -> tu ] except Exception as e: print(e) + raise # third party @classmethod - def has_serde_class(cls, fqn: str, canonical_name: str, version: str) -> tuple: + def has_serde_class( + cls, fqn: str, canonical_name: str | None, version: int + ) -> bool: # relative from ..serde.recursive import TYPE_BANK From 835fd381e2d3ec4f147b18d5f91cba76e5ed93e6 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 18:35:30 +0200 Subject: [PATCH 078/309] add some lost changes --- packages/syft/src/syft/node/node.py | 136 ++++++++++-------- .../src/syft/service/action/action_store.py | 10 ++ 2 files changed, 85 insertions(+), 61 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index edb418f90c1..c1360db00dc 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -66,6 +66,7 @@ from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService from ..service.metadata.node_metadata import NodeMetadataV3 +from ..service.migration.migration_service import MigrationService from ..service.network.network_service import NetworkService from ..service.network.utils import PeerHealthCheckTask from ..service.notification.notification_service import NotificationService @@ -406,13 +407,16 @@ def __init__( self.post_init() + if migrate: + self.find_and_migrate_data() + else: + self.find_and_migrate_data([NodeSettings]) + self.create_initial_settings(admin_email=root_email) self.init_blob_storage(config=blob_storage_config) # Migrate data before any operation on db - if migrate: - self.find_and_migrate_data() # first migrate, for backwards compatibility self.init_queue_manager(queue_config=self.queue_config) @@ -712,69 +716,79 @@ def _find_klasses_pending_for_migration( return klasses_to_be_migrated - def find_and_migrate_data(self) -> None: - # Track all object type that need migration for document store + def find_and_migrate_data( + self, document_store_object_types: list[type[SyftObject]] | None = None + ) -> None: context = AuthedServiceContext( node=self, credentials=self.verify_key, role=ServiceRole.ADMIN, ) - document_store_object_types = [ - partition.settings.object_type - for partition in self.document_store.partitions.values() - ] - - object_pending_migration = self._find_klasses_pending_for_migration( - object_types=document_store_object_types - ) - - if object_pending_migration: - print( - 
"Object in Document Store that needs migration: ", - object_pending_migration, - ) - - # Migrate data for objects in document store - for object_type in object_pending_migration: - canonical_name = object_type.__canonical_name__ - object_partition = self.document_store.partitions.get(canonical_name) - if object_partition is None: - continue - - print(f"Migrating data for: {canonical_name} table.") - migration_status = object_partition.migrate_data( - to_klass=object_type, context=context - ) - if migration_status.is_err(): - raise Exception( - f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" - ) - - # Track all object types from action store - action_object_types = [Action, ActionObject] - action_object_types.extend(ActionObject.__subclasses__()) - action_object_pending_migration = self._find_klasses_pending_for_migration( - action_object_types - ) - - if action_object_pending_migration: - print( - "Object in Action Store that needs migration: ", - action_object_pending_migration, - ) - - # Migrate data for objects in action store - for object_type in action_object_pending_migration: - canonical_name = object_type.__canonical_name__ - - migration_status = self.action_store.migrate_data( - to_klass=object_type, credentials=self.verify_key - ) - if migration_status.is_err(): - raise Exception( - f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" - ) - print("Data Migrated to latest version !!!") + migration_service = self.get_service("migrationservice") + return migration_service.migrate_data(context, document_store_object_types) + + # # Track all object type that need migration for document store + # context = AuthedServiceContext( + # node=self, + # credentials=self.verify_key, + # role=ServiceRole.ADMIN, + # ) + # document_store_object_types = [ + # partition.settings.object_type + # for partition in self.document_store.partitions.values() + # ] + + # object_pending_migration = self._find_klasses_pending_for_migration( + # object_types=document_store_object_types + # ) + + # if object_pending_migration: + # print( + # "Object in Document Store that needs migration: ", + # object_pending_migration, + # ) + + # # Migrate data for objects in document store + # for object_type in object_pending_migration: + # canonical_name = object_type.__canonical_name__ + # object_partition = self.document_store.partitions.get(canonical_name) + # if object_partition is None: + # continue + + # print(f"Migrating data for: {canonical_name} table.") + # migration_status = object_partition.migrate_data( + # to_klass=object_type, context=context + # ) + # if migration_status.is_err(): + # raise Exception( + # f"Failed to migrate data for {canonical_name}. 
Error: {migration_status.err()}" + # ) + + # # Track all object types from action store + # action_object_types = [Action, ActionObject] + # action_object_types.extend(ActionObject.__subclasses__()) + # action_object_pending_migration = self._find_klasses_pending_for_migration( + # action_object_types + # ) + + # if action_object_pending_migration: + # print( + # "Object in Action Store that needs migration: ", + # action_object_pending_migration, + # ) + + # # Migrate data for objects in action store + # for object_type in action_object_pending_migration: + # canonical_name = object_type.__canonical_name__ + + # migration_status = self.action_store.migrate_data( + # to_klass=object_type, credentials=self.verify_key + # ) + # if migration_status.is_err(): + # raise Exception( + # f"Failed to migrate data for {canonical_name}. Error: {migration_status.err()}" + # ) + # print("Data Migrated to latest version !!!") @property def guest_client(self) -> SyftClient: @@ -926,7 +940,7 @@ def _construct_services(self) -> None: {"svc": CodeHistoryService}, {"svc": MetadataService}, {"svc": BlobStorageService}, - {"svc": MigrateStateService}, + {"svc": MigrationService}, {"svc": SyftWorkerImageService}, {"svc": SyftWorkerPoolService}, {"svc": SyftImageRegistryService}, diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index fae5921828e..fc5ae3c1958 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -301,6 +301,16 @@ def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: return Ok(self.storage_permissions[uid]) return Err(f"No storage permissions found for uid: {uid}") + def _all( + self, + credentials: SyftVerifyKey, + has_permission: bool | None = False, + ) -> Result[list[SyftObject], str]: + # this checks permissions + res = [self.get(uid, credentials, has_permission) for uid in self.data.keys()] + result = [x.ok() for x in res if x.is_ok()] + return Ok(result) + def migrate_data( self, to_klass: SyftObject, credentials: SyftVerifyKey ) -> Result[bool, str]: From 077015ee38662eddcc44365424163766a4cd6369 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 30 May 2024 18:48:50 +0200 Subject: [PATCH 079/309] notebooks --- .../0-prepare-migration-data.ipynb | 243 ++++++++++++++++++ .../1a-connect-and-migrate.ipynb | 88 +++++++ .../1b-connect-and-migrate-via-api.ipynb | 243 ++++++++++++++++++ .../2-post-migration-tests.ipynb | 193 ++++++++++++++ 4 files changed, 767 insertions(+) create mode 100644 notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb create mode 100644 notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb create mode 100644 notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb create mode 100644 notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb new file mode 100644 index 00000000000..9c24c78e53f --- /dev/null +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "import numpy as np\n", + "\n", + "# syft absolute\n", + "import syft as sy" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, 
+ "source": [ + "# Verify Version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "pip_info = !pip index versions syft\n", + "latest_deployed_version = pip_info[-1].split(\"LATEST: \")[-1].strip()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# this notebook should only be used to run the latest deployed version of syft\n", + "# the notebooks after this (1a/1b and 2), will test migrating from that latest version\n", + "assert latest_deployed_version == sy.__version__" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "# Launch Node" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " reset=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "client.register(\n", + " email=\"ds@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "# Prepare some data to be migrated" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "client_ds = node.login(email=\"ds@openmined.org\", password=\"pw\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = sy.Dataset(\n", + " name=\"my-dataset\",\n", + " description=\"abc\",\n", + " asset_list=[\n", + " sy.Asset(\n", + " name=\"numpy-data\",\n", + " mock=np.array([10, 11, 12, 13, 14]),\n", + " data=np.array([15, 16, 17, 18, 19]),\n", + " mock_is_real=True,\n", + " )\n", + " ],\n", + ")\n", + "\n", + "client.upload_dataset(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "data_low = client_ds.datasets[0].assets[0]\n", + "\n", + "\n", + "@sy.syft_function_single_use(data=data_low)\n", + "def compute_mean(domain, data) -> float:\n", + " # launch another job\n", + " print(\"Computing mean...\")\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "req = client_ds.code.request_code_execution(compute_mean)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "client.requests[0].approve()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "job = client_ds.code.compute_mean(data=data_low, blocking=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "res = job.wait()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "res.get()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "# todo: add more data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:syft086] *", + "language": "python", + "name": "conda-env-syft086-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb b/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb new file mode 100644 index 00000000000..aca3e671f5c --- /dev/null +++ b/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb @@ -0,0 +1,88 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "import numpy as np\n", + "\n", + "# syft absolute\n", + "import syft as sy\n", + "from syft.service.job.job_stash import Job" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "# Serialization tests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "assert sy.serialize(np.array([1, 2, 3])).canonicalName == \"numpy.ndarray\"\n", + "assert sy.serialize(bool).canonicalName == \"builtins.type\"\n", + "assert (\n", + " sy.serialize(Job).canonicalName\n", + " == \"pydantic._internal._model_construction.ModelMetaclass\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb new file mode 100644 index 00000000000..50f34b603f0 --- /dev/null +++ b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "from syft.service.log.log import SyftLogV3\n", + "from syft.types.syft_object import Context\n", + "from syft.types.syft_object import SyftObject" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "TODOS\n", + "- [x] action objects\n", + "- [x] maybe an example of how to migrate one object type in a custom way\n", + "- 
[x] check SyftObjectRegistry and compare with current implementation\n", + "- [x] run unit tests\n", + "- [ ] finalize notebooks for testing, run in CI\n", + "- [ ] other tasks defined in tickets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "# Client side migrations" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## document store objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "migration_dict = client.services.migration.get_migration_objects()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "def custom_migration_function(context, obj: SyftObject, klass) -> SyftObject:\n", + " # Here, we are just doing the same, but this is where you would write your custom logic\n", + " return obj.migrate_to(klass.__version__, context)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# this wont work in the cases where the context is actually used,\n", + "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", + "\n", + "context = Context()\n", + "migrated_objects = []\n", + "for klass, objects in migration_dict.items():\n", + " for obj in objects:\n", + " if isinstance(obj, SyftLogV3):\n", + " migrated_obj = custom_migration_function(context, obj, klass)\n", + " else:\n", + " migrated_obj = obj.migrate_to(klass.__version__, context)\n", + " migrated_objects.append(migrated_obj)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "migrated_objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "res = client.services.migration.update_migrated_objects(migrated_objects)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(res, sy.SyftSuccess)" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## Actions and ActionObjects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "migration_action_dict = client.services.migration.get_migration_actionobjects()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# this wont work in the cases where the context is actually used, but since this you would need custom logic here anyway\n", + "# it doesnt matter\n", + "context = Context()\n", + "migrated_actionobjects = []\n", + "for klass, objects in migration_action_dict.items():\n", + " for obj in objects:\n", + " # custom migration logic here\n", + " migrated_actionobject 
= obj.migrate_to(klass.__version__, context)\n", + " migrated_actionobjects.append(migrated_actionobject)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "migrated_actionobjects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "res = client.services.migration.update_migrated_objects(migrated_actionobjects)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(res, sy.SyftSuccess)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb new file mode 100644 index 00000000000..ef89ccbeb16 --- /dev/null +++ b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "# Post migration tests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")\n", + "client_ds = node.login(email=\"ds@openmined.org\", password=\"pw\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "- [x] log in\n", + "- [x] get request / datasets / \n", + "- [x] check request is approved\n", + "- [x] run function\n", + "- [ ] run new function\n", + "- [x] repr (request, code, job)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "req1 = client.requests[0]\n", + "req2 = client_ds.requests[0]\n", + "assert req1.status.name == \"APPROVED\" and req2.status.name == \"APPROVED\"\n", + "assert isinstance(req1._repr_html_(), str)\n", + "assert isinstance(req2._repr_html_(), str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "jobs = client_ds.jobs\n", + "assert isinstance(jobs[0]._repr_html_(), str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "ds = client_ds.datasets\n", + "asset = ds[0].assets[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + 
"source": [ + "res = client_ds.code.compute_mean(data=asset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "import numpy as np\n", + "\n", + "assert all(res == np.array([15, 16, 17, 18, 19]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "jobs = client_ds.jobs.get_all()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "job = jobs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "job.logs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "logs = job.logs(_print=False)\n", + "assert isinstance(logs, str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 6ac32533e23400c722830df08e8c603c85a09d66 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 31 May 2024 00:24:41 +0530 Subject: [PATCH 080/309] update network service to add override rathole config is already exists --- .../service/network/association_request.py | 35 +++- .../syft/service/network/network_service.py | 170 ++++++++---------- .../src/syft/service/network/node_peer.py | 10 ++ .../syft/service/network/rathole_service.py | 61 ++----- 4 files changed, 131 insertions(+), 145 deletions(-) diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 0d1363b54e6..043bdbc101e 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -33,6 +33,16 @@ class AssociationRequestChange(Change): def _run( self, context: ChangeContext, apply: bool ) -> Result[tuple[bytes, NodePeer], SyftError]: + """ + Executes the association request. + + Args: + context (ChangeContext): The change context. + apply (bool): A flag indicating whether to apply the association request. + + Returns: + Result[tuple[bytes, NodePeer], SyftError]: The result of the association request. 
+ """ # relative from .network_service import NetworkService @@ -42,11 +52,25 @@ def _run( SyftError(message="Undo not supported for AssociationRequestChange") ) + # Get the network service service_ctx = context.to_service_ctx() + network_service = cast( + NetworkService, service_ctx.node.get_service(NetworkService) + ) + network_stash = network_service.stash + # Check if remote peer to be added is via rathole rathole_route = self.remote_peer.get_rathole_route() + add_rathole_route = ( + rathole_route is not None + and self.remote_peer.latest_added_route == rathole_route + ) - if rathole_route and rathole_route.rathole_token is None: + # If the remote peer is added via rathole, we don't need to ping the peer + if add_rathole_route: + network_service.rathole_service.add_host_to_server(self.remote_peer) + else: + # Pinging the remote peer to verify the connection try: remote_client: SyftClient = self.remote_peer.client_with_context( context=service_ctx @@ -77,14 +101,7 @@ def _run( except Exception as e: return Err(SyftError(message=str(e))) - network_service = cast( - NetworkService, service_ctx.node.get_service(NetworkService) - ) - - network_stash = network_service.stash - - network_service.rathole_service.add_host_to_server(self.remote_peer) - + # Adding the remote peer to the network stash result = network_stash.create_or_update_peer( service_ctx.node.verify_key, self.remote_peer ) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4532201aaae..a1607694ef4 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -1,10 +1,12 @@ # stdlib from collections.abc import Callable from enum import Enum +import logging import secrets from typing import Any # third party +from loguru import logger from result import Err from result import Result @@ -130,10 +132,9 @@ def create_or_update_peer( existing = existing.ok() existing.update_routes(peer.node_routes) result = self.update(credentials, existing) - return result else: result = self.set(credentials, peer) - return result + return result def get_by_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey @@ -202,115 +203,53 @@ def exchange_credentials_with( ) remote_node_peer = NodePeer.from_client(remote_client) - # ask the remote client to add this node (represented by `self_node_peer`) as a peer - # check locally if the remote node already exists as a peer - existing_peer_result = self.stash.get_by_uid( - context.node.verify_key, remote_node_peer.id + # Step 3: Check remotely if the self node already exists as a peer + # Update the peer if it exists, otherwise add it + remote_self_node_peer = remote_client.api.services.network.get_peer_by_name( + name=self_node_peer.name ) - if ( - existing_peer_result.is_ok() - and (existing_peer := existing_peer_result.ok()) is not None - ): - msg = [ - ( - f"{existing_peer.node_type} peer '{existing_peer.name}' already exist for " - f"{self_node_peer.node_type} '{self_node_peer.name}'." - ) - ] - if existing_peer != remote_node_peer: - result = self.stash.create_or_update_peer( - context.node.verify_key, - remote_node_peer, - ) - msg.append( - f"{existing_peer.node_type} peer '{existing_peer.name}' information change detected." - ) - if result.is_err(): - msg.append( - f"Attempt to update peer '{existing_peer.name}' information failed." 
- ) - return SyftError(message="\n".join(msg)) - msg.append( - f"{existing_peer.node_type} peer '{existing_peer.name}' information successfully updated." - ) - # Also check remotely if the self node already exists as a peer - remote_self_node_peer = remote_client.api.services.network.get_peer_by_name( - name=self_node_peer.name - ) - if isinstance(remote_self_node_peer, NodePeer): - msg.append( - f"{self_node_peer.node_type} '{self_node_peer.name}' already exist " - f"as a peer for {remote_node_peer.node_type} '{remote_node_peer.name}'." - ) - if remote_self_node_peer != self_node_peer: - result = remote_client.api.services.network.update_peer( - peer=self_node_peer, - ) - msg.append( - f"{self_node_peer.node_type} peer '{self_node_peer.name}' information change detected." - ) - if isinstance(result, SyftError): - msg.apnpend( - f"Attempt to remotely update {self_node_peer.node_type} peer " - f"'{self_node_peer.name}' information remotely failed." - ) - return SyftError(message="\n".join(msg)) - msg.append( - f"{self_node_peer.node_type} peer '{self_node_peer.name}' " - f"information successfully updated." - ) - msg.append( - f"Routes between {remote_node_peer.node_type} '{remote_node_peer.name}' and " - f"{self_node_peer.node_type} '{self_node_peer.name}' already exchanged." + association_request_approved = True + if isinstance(remote_self_node_peer, NodePeer): + result = remote_client.api.services.network.update_peer(peer=self_node_peer) + if isinstance(result, SyftError): + return SyftError( + message=f"Failed to add peer information on remote client : {remote_client.id}" ) - return SyftSuccess(message="\n".join(msg)) # If peer does not exist, ask the remote client to add this node # (represented by `self_node_peer`) as a peer - random_challenge = secrets.token_bytes(16) - remote_res = remote_client.api.services.network.add_peer( - peer=self_node_peer, - challenge=random_challenge, - self_node_route=remote_node_route, - verify_key=remote_node_verify_key, - ) - - if isinstance(remote_res, SyftError): - return SyftError( - message=f"returned error from add peer: {remote_res.message}" + if remote_self_node_peer is None: + random_challenge = secrets.token_bytes(16) + remote_res = remote_client.api.services.network.add_peer( + peer=self_node_peer, + challenge=random_challenge, + self_node_route=remote_node_route, + verify_key=remote_node_verify_key, ) - association_request_approved = not isinstance(remote_res, Request) + if isinstance(remote_res, SyftError): + return SyftError( + message=f"Failed to add peer to remote client: {remote_client.id}. Error: {remote_res.message}" + ) + + association_request_approved = not isinstance(remote_res, Request) - # save the remote peer for later + # Step 4: Save the remote peer for later result = self.stash.create_or_update_peer( context.node.verify_key, remote_node_peer, ) if result.is_err(): + logging.error( + f"Failed to save peer: {remote_node_peer}. Error: {result.err()}" + ) return SyftError(message="Failed to update route information.") + # Step 5: Save rathole config to enable reverse tunneling if reverse_tunnel and get_rathole_enabled(): - rathole_route = self_node_peer.get_rathole_route() - if not rathole_route: - raise Exception( - "Failed to exchange credentials. 
" - + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" - ) - - remote_url = GridURL( - host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port - ) - rathole_remote_addr = remote_url.as_container_host() - - remote_addr = rathole_remote_addr.url_no_protocol - - self.rathole_service.add_host_to_client( - peer_name=self_node_peer.name, - peer_id=str(self_node_peer.id), - rathole_token=rathole_route.rathole_token, - remote_addr=remote_addr, + self._add_reverse_tunneling_config_for_peer( + self_node_peer=self_node_peer, remote_node_route=remote_node_route ) return ( @@ -319,6 +258,33 @@ def exchange_credentials_with( else remote_res ) + def _add_reverse_tunneling_config_for_peer( + self, + self_node_peer: NodePeer, + remote_node_route: NodeRoute, + ) -> None: + + rathole_route = self_node_peer.get_rathole_route() + if not rathole_route: + return SyftError( + "Failed to exchange routes via . " + + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" + ) + + remote_url = GridURL( + host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port + ) + rathole_remote_addr = remote_url.as_container_host() + + remote_addr = rathole_remote_addr.url_no_protocol + + self.rathole_service.add_host_to_client( + peer_name=self_node_peer.name, + peer_id=str(self_node_peer.id), + rathole_token=rathole_route.rathole_token, + remote_addr=remote_addr, + ) + @service_method(path="network.add_peer", name="add_peer", roles=GUEST_ROLE_LEVEL) def add_peer( self, @@ -524,6 +490,22 @@ def update_peer( return SyftError( message=f"Failed to update peer '{peer.name}'. Error: {result.err()}" ) + + if context.node.node_side_type == NodeType.GATEWAY: + rathole_route = peer.get_rathole_route() + self.rathole_service.add_host_to_server(peer) if rathole_route else None + else: + self_node_peer: NodePeer = context.node.settings.to(NodePeer) + rathole_route = self_node_peer.get_rathole_route() + ( + self._add_reverse_tunneling_config_for_peer( + self_node_peer=self_node_peer, + remote_node_route=peer.pick_highest_priority_route(), + ) + if rathole_route + else None + ) + return SyftSuccess( message=f"Peer '{result.ok().name}' information successfully updated." ) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 4fe2b5f38b2..59015e21dfb 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -215,6 +215,16 @@ def from_client(client: SyftClient) -> "NodePeer": peer.node_routes.append(route) return peer + @property + def latest_added_route(self) -> NodeRoute | None: + """ + Returns the latest added route from the list of node routes. + + Returns: + NodeRoute | None: The latest added route, or None if there are no routes. 
+ """ + return self.node_routes[-1] if self.node_routes else None + def client_with_context( self, context: NodeServiceContext ) -> Result[type[SyftClient], str]: diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 2bcde4fd2f4..e2d729de069 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -120,40 +120,6 @@ def add_host_to_client( # Update the rathole config map KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) - def forward_port_to_proxy( - self, config: RatholeConfig, entrypoint: str = "web" - ) -> None: - """Add a port to the rathole proxy config map.""" - - rathole_proxy_config_map = KubeUtils.get_configmap( - self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP - ) - - if rathole_proxy_config_map is None: - raise Exception("Rathole proxy config map not found.") - - rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] - - if not rathole_proxy: - rathole_proxy = {"http": {"routers": {}, "services": {}}} - else: - rathole_proxy = yaml.safe_load(rathole_proxy) - - rathole_proxy["http"]["services"][config.server_name] = { - "loadBalancer": {"servers": [{"url": "http://proxy:8001"}]} - } - - rathole_proxy["http"]["routers"][config.server_name] = { - "rule": "PathPrefix(`/`)", - "service": config.server_name, - "entryPoints": [entrypoint], - } - - KubeUtils.update_configmap( - config_map=rathole_proxy_config_map, - patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, - ) - def add_dynamic_addr_to_rathole( self, config: RatholeConfig, entrypoint: str = "web" ) -> None: @@ -206,14 +172,25 @@ def expose_port_on_rathole_service(self, port_name: str, port: int) -> None: config = rathole_service.raw - config["spec"]["ports"].append( - { - "name": port_name, - "port": port, - "targetPort": port, - "protocol": "TCP", - } - ) + existing_port_idx = None + for idx, port in enumerate(config["spec"]["ports"]): + if port["name"] == port_name: + print("Port already exists.", existing_port_idx, port_name) + existing_port_idx = idx + break + + if existing_port_idx is not None: + config["spec"]["ports"][existing_port_idx]["port"] = port + config["spec"]["ports"][existing_port_idx]["targetPort"] = port + else: + config["spec"]["ports"].append( + { + "name": port_name, + "port": port, + "targetPort": port, + "protocol": "TCP", + } + ) rathole_service.patch(config) From b5c97e9a29053470b6e14edc6338cc000320c613 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 31 May 2024 14:36:50 +0200 Subject: [PATCH 081/309] clean up services --- packages/syft/src/syft/node/node.py | 5 +- packages/syft/src/syft/serde/recursive.py | 3 +- .../object_search/migration_state_service.py | 73 ----------------- .../object_search/object_migration_state.py | 82 ------------------- 4 files changed, 3 insertions(+), 160 deletions(-) delete mode 100644 packages/syft/src/syft/service/object_search/migration_state_service.py delete mode 100644 packages/syft/src/syft/service/object_search/object_migration_state.py diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index c1360db00dc..eb7b6be0bbb 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -71,7 +71,6 @@ from ..service.network.utils import PeerHealthCheckTask from ..service.notification.notification_service import NotificationService from ..service.notifier.notifier_service 
import NotifierService -from ..service.object_search.migration_state_service import MigrateStateService from ..service.output.output_service import OutputService from ..service.policy.policy_service import PolicyService from ..service.project.project_service import ProjectService @@ -689,7 +688,7 @@ def _find_klasses_pending_for_migration( credentials=self.verify_key, role=ServiceRole.ADMIN, ) - migration_state_service = self.get_service(MigrateStateService) + migration_state_service = self.get_service(MigrationService) klasses_to_be_migrated = [] @@ -1662,8 +1661,6 @@ def create_admin_new( else: raise Exception(f"Could not create user: {result}") except Exception: - # import ipdb - # ipdb.set_trace() # stdlib import traceback diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index bc0f4cc1e66..d9d241ac051 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -316,7 +316,8 @@ def map_fqns_for_backward_compatibility(fqn: str) -> str: """for backwards compatibility with 0.8.6. Sometimes classes where moved to another file. Which is exactly why we are implementing it differently""" mapping = { - "syft.service.dataset.dataset.MarkdownDescription": "syft.util.misc_objs.MarkdownDescription" + "syft.service.dataset.dataset.MarkdownDescription": "syft.util.misc_objs.MarkdownDescription", + "syft.service.object_search.object_migration_state.SyftObjectMigrationState": "syft.service.migration.object_migration_state.SyftObjectMigrationState", # noqa: E501 } if fqn in mapping: return mapping[fqn] diff --git a/packages/syft/src/syft/service/object_search/migration_state_service.py b/packages/syft/src/syft/service/object_search/migration_state_service.py deleted file mode 100644 index ae415584d3c..00000000000 --- a/packages/syft/src/syft/service/object_search/migration_state_service.py +++ /dev/null @@ -1,73 +0,0 @@ -# stdlib - -# relative -from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ..context import AuthedServiceContext -from ..response import SyftError -from ..service import AbstractService -from ..service import service_method -from .object_migration_state import SyftMigrationStateStash -from .object_migration_state import SyftObjectMigrationState - - -@serializable() -class MigrateStateService(AbstractService): - store: DocumentStore - stash: SyftMigrationStateStash - - def __init__(self, store: DocumentStore) -> None: - self.store = store - self.stash = SyftMigrationStateStash(store=store) - - @service_method(path="migration", name="get_version") - def get_version( - self, context: AuthedServiceContext, canonical_name: str - ) -> int | SyftError: - """Search for the metadata for an object.""" - - result = self.stash.get_by_name( - canonical_name=canonical_name, credentials=context.credentials - ) - - if result.is_err(): - return SyftError(message=f"{result.err()}") - - migration_state = result.ok() - - if migration_state is None: - return SyftError( - message=f"No migration state exists for canonical name: {canonical_name}" - ) - - return migration_state.current_version - - @service_method(path="migration", name="get_state") - def get_state( - self, context: AuthedServiceContext, canonical_name: str - ) -> bool | SyftError: - result = self.stash.get_by_name( - canonical_name=canonical_name, credentials=context.credentials - ) - - if result.is_err(): - return SyftError(message=f"{result.err()}") - - return result.ok() - - 
@service_method(path="migration", name="register_migration_state") - def register_migration_state( - self, - context: AuthedServiceContext, - current_version: int, - canonical_name: str, - ) -> SyftObjectMigrationState | SyftError: - obj = SyftObjectMigrationState( - current_version=current_version, canonical_name=canonical_name - ) - result = self.stash.set(migration_state=obj, credentials=context.credentials) - - if result.is_err(): - return SyftError(message=f"{result.err()}") - - return result.ok() diff --git a/packages/syft/src/syft/service/object_search/object_migration_state.py b/packages/syft/src/syft/service/object_search/object_migration_state.py deleted file mode 100644 index f5b3a043ea1..00000000000 --- a/packages/syft/src/syft/service/object_search/object_migration_state.py +++ /dev/null @@ -1,82 +0,0 @@ -# stdlib - -# third party -from result import Result - -# relative -from ...node.credentials import SyftVerifyKey -from ...serde.serializable import serializable -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.syft_object import SyftMigrationRegistry -from ...types.syft_object import SyftObject -from ..action.action_permissions import ActionObjectPermission - - -@serializable() -class SyftObjectMigrationState(SyftObject): - __canonical_name__ = "SyftObjectMigrationState" - __version__ = SYFT_OBJECT_VERSION_2 - - __attr_unique__ = ["canonical_name"] - - canonical_name: str - current_version: int - - @property - def latest_version(self) -> int | None: - available_versions = SyftMigrationRegistry.get_versions( - canonical_name=self.canonical_name, - ) - if not available_versions: - return None - - return sorted(available_versions, reverse=True)[0] - - @property - def supported_versions(self) -> list: - return SyftMigrationRegistry.get_versions(self.canonical_name) - - -KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) - - -@serializable() -class SyftMigrationStateStash(BaseStash): - object_type = SyftObjectMigrationState - settings: PartitionSettings = PartitionSettings( - name=SyftObjectMigrationState.__canonical_name__, - object_type=SyftObjectMigrationState, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - def set( - self, - credentials: SyftVerifyKey, - migration_state: SyftObjectMigrationState, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> Result[SyftObjectMigrationState, str]: - res = self.check_type(migration_state, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().set( - credentials=credentials, - obj=res.ok(), - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ignore_duplicates=ignore_duplicates, - ) - - def get_by_name( - self, canonical_name: str, credentials: SyftVerifyKey - ) -> Result[SyftObjectMigrationState, str]: - qks = KlassNamePartitionKey.with_obj(canonical_name) - return self.query_one(credentials=credentials, qks=qks) From 4a29384d54daa5c982413f61f0ae9c9eafe0a4db Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 31 May 2024 15:19:22 +0200 Subject: [PATCH 082/309] clean up services --- .../syft/src/syft/protocol/data_protocol.py 
| 15 ++-- packages/syft/src/syft/serde/recursive.py | 10 +-- .../migration/object_migration_state.py | 6 +- .../syft/src/syft/types/syft_migration.py | 2 +- packages/syft/src/syft/types/syft_object.py | 86 ++++--------------- .../src/syft/types/syft_object_registry.py | 59 +++++++------ 6 files changed, 70 insertions(+), 108 deletions(-) diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 912537e1266..eb789e12fb6 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -209,7 +209,7 @@ def build_state(self, stop_key: str | None = None) -> dict: return state_dict @staticmethod - def obj_json(version, _hash, action="add"): + def obj_json(version: str | int, _hash: str, action: str = "add") -> dict: return { "version": int(version), "hash": _hash, @@ -219,7 +219,12 @@ def obj_json(version, _hash, action="add"): def diff_state(self, state: dict) -> tuple[dict, dict]: compare_dict: dict = defaultdict(dict) # what versions are in the latest code object_diff: dict = defaultdict(dict) # diff in latest code with saved json - for serde_properties in SyftObjectRegistry.__object_serialization_registry__.values(): + all_serde_propeties = [ + serde_properties + for version_dict in SyftObjectRegistry.__object_serialization_registry__.values() + for serde_properties in version_dict.values() + ] + for serde_properties in all_serde_propeties: cls, version = serde_properties[7], serde_properties[9] if issubclass(cls, SyftBaseObject): canonical_name = cls.__canonical_name__ @@ -417,9 +422,9 @@ def validate_release(self) -> None: # Update older file path to newer file path latest_protocol_fp.rename(new_protocol_file_path) - protocol_history[latest_protocol][ - "release_name" - ] = f"{current_syft_version}.json" + protocol_history[latest_protocol]["release_name"] = ( + f"{current_syft_version}.json" + ) # Save history self.file_path.write_text(json.dumps(protocol_history, indent=2) + "\n") diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index d9d241ac051..0afb967ba3e 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -166,9 +166,7 @@ def recursive_serde_register( canonical_name = fqn version = 1 - SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] = ( - serde_attributes - ) + SyftObjectRegistry.register_cls(canonical_name, version, serde_attributes) if isinstance(alias_fqn, tuple): for alias in alias_fqn: @@ -252,7 +250,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild _, _, _, - ) = SyftObjectRegistry.__object_serialization_registry__[canonical_name, version] + ) = SyftObjectRegistry.get_serde_properties(canonical_name, version) if nonrecursive or is_type: if serialize is None: @@ -376,7 +374,9 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: cls, _, version, - ) = SyftObjectRegistry.get_serde_properties(fqn, canonical_name, version) + ) = SyftObjectRegistry.get_serde_properties_bw_compatible( + fqn, canonical_name, version + ) if class_type == type(None) or fqn != "": # yes this looks stupid but it works and the opposite breaks diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index f5b3a043ea1..77ea6f60f51 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ 
b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -11,8 +11,8 @@ from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.syft_object import SyftMigrationRegistry from ...types.syft_object import SyftObject +from ...types.syft_object_registry import SyftObjectRegistry from ..action.action_permissions import ActionObjectPermission @@ -28,7 +28,7 @@ class SyftObjectMigrationState(SyftObject): @property def latest_version(self) -> int | None: - available_versions = SyftMigrationRegistry.get_versions( + available_versions = SyftObjectRegistry.get_versions( canonical_name=self.canonical_name, ) if not available_versions: @@ -38,7 +38,7 @@ def latest_version(self) -> int | None: @property def supported_versions(self) -> list: - return SyftMigrationRegistry.get_versions(self.canonical_name) + return SyftObjectRegistry.get_versions(self.canonical_name) KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) diff --git a/packages/syft/src/syft/types/syft_migration.py b/packages/syft/src/syft/types/syft_migration.py index f3205282194..bc9d8f5c4fb 100644 --- a/packages/syft/src/syft/types/syft_migration.py +++ b/packages/syft/src/syft/types/syft_migration.py @@ -44,7 +44,7 @@ def decorator(function: Callable) -> Callable: klass_from=klass_from, klass_to=klass_to, transforms=transforms ) - SyftMigrationRegistry.register_transform( + SyftMigrationRegistry.register_migration_function( klass_type_str=klass_from_str, version_from=version_from, version_to=version_to, diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 948afa9dd56..438b789dba6 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -142,7 +142,7 @@ class Context(SyftBaseObject): class SyftMigrationRegistry: __migration_version_registry__: dict[str, dict[int, str]] = {} - __migration_transform_registry__: dict[str, dict[str, Callable]] = {} + __migration_function_registry__: dict[str, dict[str, Callable]] = {} def __init_subclass__(cls, **kwargs: Any) -> None: """ @@ -179,16 +179,16 @@ def register_version(cls, klass: type) -> None: klass_version: fqn } - @classmethod - def get_versions(cls, canonical_name: str) -> list[int]: - available_versions: dict = cls.__migration_version_registry__.get( - canonical_name, - {}, - ) - return list(available_versions.keys()) + # @classmethod + # def get_versions(cls, canonical_name: str) -> list[int]: + # available_versions: dict = cls.__migration_version_registry__.get( + # canonical_name, + # {}, + # ) + # return list(available_versions.keys()) @classmethod - def register_transform( + def register_migration_function( cls, klass_type_str: str, version_from: int, version_to: int, method: Callable ) -> None: """ @@ -211,11 +211,9 @@ def register_transform( if versions_exists: mapping_string = f"{version_from}x{version_to}" - if klass_type_str not in cls.__migration_transform_registry__: - cls.__migration_transform_registry__[klass_type_str] = {} - cls.__migration_transform_registry__[klass_type_str][mapping_string] = ( - method - ) + if klass_type_str not in cls.__migration_function_registry__: + cls.__migration_function_registry__[klass_type_str] = {} + cls.__migration_function_registry__[klass_type_str][mapping_string] = method else: raise Exception( f"Available versions for {klass_type_str} are: {available_versions}." 
@@ -246,9 +244,9 @@ def get_migration( mapping_string = f"{version_from}x{version_to}" if ( mapping_string - in cls.__migration_transform_registry__[klass_from] + in cls.__migration_function_registry__[klass_from] ): - return cls.__migration_transform_registry__[klass_from][ + return cls.__migration_function_registry__[klass_from][ mapping_string ] raise ValueError( @@ -273,11 +271,9 @@ def get_migration_for_version( mapping_string = f"{version_from}x{version_to}" if ( mapping_string - in cls.__migration_transform_registry__[ - type_from.__canonical_name__ - ] + in cls.__migration_function_registry__[type_from.__canonical_name__] ): - return cls.__migration_transform_registry__[klass_from][ + return cls.__migration_function_registry__[klass_from][ mapping_string ] @@ -296,39 +292,7 @@ def get_migration_for_version( ] -class RegisteredSyftObject: - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - # relative - from .syft_object_registry import SyftObjectRegistry as reg - - if hasattr(reg, "__canonical_name__") and hasattr(reg, "__version__"): - mapping_string = f"{reg.__canonical_name__}_{reg.__version__}" - - if ( - mapping_string in reg.__object_version_registry__ - and not autoreload_enabled() - ): - current_cls = reg.__object_version_registry__[mapping_string] - if reg == current_cls: - # same class so noop - return None - - # user code is reinitialized which means it might have a new address - # in memory so for that we can just skip - if "syft.user" in reg.__module__: - # this happens every time we reload the user code - return None - else: - # this shouldn't happen and is usually a mistake of reusing the - # same __canonical_name__ and __version__ in two classes - raise Exception(f"Duplicate mapping for {mapping_string} and {reg}") - else: - # only if the cls has not been registered do we want to register it - reg.__object_version_registry__[mapping_string] = reg - - -class SyftObject(SyftBaseObject, RegisteredSyftObject, SyftMigrationRegistry): +class SyftObject(SyftBaseObject, SyftMigrationRegistry): __canonical_name__ = "SyftObject" __version__ = SYFT_OBJECT_VERSION_2 @@ -479,22 +443,6 @@ def keys(self) -> KeysView[str]: def __getitem__(self, key: str | int) -> Any: return self.__dict__.__getitem__(key) # type: ignore - def _upgrade_version(self, latest: bool = True) -> "SyftObject": - # relative - from .syft_object_registry import SyftObjectRegistry - - constructor = SyftObjectRegistry.versioned_class( - name=self.__canonical_name__, version=self.__version__ + 1 - ) - if not constructor: - return self - else: - # should we do some kind of recursive upgrades? 
- upgraded = constructor._from_previous_version(self) - if latest: - upgraded = upgraded._upgrade_version(latest=latest) - return upgraded - # transform from one supported type to another def to(self, projection: type, context: Context | None = None) -> Any: # relative diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 964bf8ca4bf..01a9ca4348f 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -17,11 +17,26 @@ class SyftObjectRegistry: - __object_version_registry__: dict[ - str, type["SyftObject"] | type["SyftObjectRegistry"] - ] = {} __object_transform_registry__: dict[str, Callable] = {} - __object_serialization_registry__: dict[tuple[str, int], tuple] = {} + __object_serialization_registry__: dict[str, dict[int, tuple]] = {} + + @classmethod + def register_cls( + cls, canonical_name: str, version: int, serde_attributes: tuple + ) -> None: + if canonical_name not in cls.__object_serialization_registry__: + cls.__object_serialization_registry__[canonical_name] = {} + cls.__object_serialization_registry__[canonical_name][version] = ( + serde_attributes + ) + + @classmethod + def get_versions(cls, canonical_name: str) -> list[int]: + available_versions: dict = cls.__object_serialization_registry__.get( + canonical_name, + {}, + ) + return list(available_versions.keys()) @classmethod def get_canonical_name(cls, obj: Any) -> str: @@ -38,12 +53,18 @@ def get_canonical_name(cls, obj: Any) -> str: return fqn @classmethod - def get_serde_properties(cls, fqn: str, canonical_name: str, version: int) -> tuple: + def get_serde_properties(cls, canonical_name: str, version: int) -> tuple: + return cls.__object_serialization_registry__[canonical_name][version] + + @classmethod + def get_serde_properties_bw_compatible( + cls, fqn: str, canonical_name: str, version: int + ) -> tuple: # relative from ..serde.recursive import TYPE_BANK if canonical_name != "" and canonical_name is not None: - return cls.__object_serialization_registry__[canonical_name, version] + return cls.get_serde_properties(canonical_name, version) else: # this is for backward compatibility with 0.8.6 try: @@ -72,9 +93,8 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: int) -> tu ] ) try: - res = cls.__object_serialization_registry__[ - canonical_name, version_086 - ] + res = cls.get_serde_properties(canonical_name, version_086) + except Exception: print( f"could not find {canonical_name} {version_086} in ObjectRegistry" @@ -85,9 +105,7 @@ def get_serde_properties(cls, fqn: str, canonical_name: str, version: int) -> tu # TODO, add refactoring for non syftobject versions canonical_name = fqn version = 1 - return cls.__object_serialization_registry__[ - canonical_name, version - ] + return cls.get_serde_properties(canonical_name, version) except Exception as e: print(e) raise @@ -101,7 +119,10 @@ def has_serde_class( from ..serde.recursive import TYPE_BANK if canonical_name != "" and canonical_name is not None: - return (canonical_name, version) in cls.__object_serialization_registry__ + return ( + canonical_name in cls.__object_serialization_registry__ + and version in cls.__object_serialization_registry__[canonical_name] + ) else: # this is for backward compatibility with 0.8.6 return fqn in TYPE_BANK @@ -152,15 +173,3 @@ def get_transform( f"No mapping found for: {type_from} to {type_to} in" f"the registry: 
{SyftObjectRegistry.__object_transform_registry__.keys()}" ) - - @classmethod - def versioned_class( - cls, name: str, version: int - ) -> type["SyftObject"] | type["SyftObjectRegistry"] | None: - # relative - from .syft_object_registry import SyftObjectRegistry - - mapping_string = f"{name}_{version}" - if mapping_string not in SyftObjectRegistry.__object_version_registry__: - return None - return SyftObjectRegistry.__object_version_registry__[mapping_string] From e27c5cd1a05d33aff6d20319da5b5bafb0f4fdd6 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sat, 1 Jun 2024 14:44:32 +0530 Subject: [PATCH 083/309] make id optional in NodeConnection and its subclasses - remove using route id to check for existing routes --- packages/syft/src/syft/client/client.py | 3 +-- packages/syft/src/syft/client/connection.py | 7 +++-- .../src/syft/protocol/protocol_version.json | 16 ++++++++++-- .../syft/service/network/network_service.py | 10 +++---- .../src/syft/service/network/node_peer.py | 26 +++++-------------- .../syft/service/network/rathole_service.py | 4 +-- .../syft/src/syft/service/network/routes.py | 5 ++-- 7 files changed, 36 insertions(+), 35 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 28e975efa03..33a509dfc18 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -48,7 +48,6 @@ from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..types.grid_url import GridURL -from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.uid import UID from ..util.logger import debug @@ -391,7 +390,7 @@ def get_client_type(self) -> type[SyftClient] | SyftError: @serializable() class PythonConnection(NodeConnection): __canonical_name__ = "PythonConnection" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 node: AbstractNode proxy_target_uid: UID | None = None diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py index e82db863e8a..0899532818a 100644 --- a/packages/syft/src/syft/client/connection.py +++ b/packages/syft/src/syft/client/connection.py @@ -2,13 +2,16 @@ from typing import Any # relative -from ..types.syft_object import SYFT_OBJECT_VERSION_2 +from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.syft_object import SyftObject +from ..types.uid import UID class NodeConnection(SyftObject): __canonical_name__ = "NodeConnection" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 + + id: UID | None = None # type: ignore def get_cache_key(self) -> str: raise NotImplementedError diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 395f1263ef1..a32451b058b 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -217,7 +217,7 @@ }, "3": { "version": 3, - "hash": "5e363abe2875beec89a3f4f4f5c53e15f9893fb98e5da71e2fa6c0f619883b1f", + "hash": "9162f038f0f6401c4cf4d1b517c40805d1f291bd69a6e76f3c1ee9e5095de2e5", "action": "add" } }, @@ -229,7 +229,7 @@ }, "3": { "version": 3, - "hash": "89ace8067c392b802fe23a99446a8ae464a9dad0b49d8b2c3871b631451acec4", + "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", "action": "add" } }, @@ -239,6 +239,18 @@ "hash": 
"010d9aaca95f3fdfc8d1f97d01c1bd66483da774a59275b310c08d6912f7f863", "action": "add" } + }, + "PythonConnection": { + "2": { + "version": 2, + "hash": "eb479c671fc112b2acbedb88bc5624dfdc9592856c04c22c66410f6c863e1708", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index a1607694ef4..fdfe6c184ec 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -4,9 +4,9 @@ import logging import secrets from typing import Any +from typing import cast # third party -from loguru import logger from result import Err from result import Result @@ -263,10 +263,9 @@ def _add_reverse_tunneling_config_for_peer( self_node_peer: NodePeer, remote_node_route: NodeRoute, ) -> None: - rathole_route = self_node_peer.get_rathole_route() if not rathole_route: - return SyftError( + raise Exception( "Failed to exchange routes via . " + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" ) @@ -491,7 +490,8 @@ def update_peer( message=f"Failed to update peer '{peer.name}'. Error: {result.err()}" ) - if context.node.node_side_type == NodeType.GATEWAY: + node_side_type = cast(NodeType, context.node.node_side_type) + if node_side_type == NodeType.GATEWAY: rathole_route = peer.get_rathole_route() self.rathole_service.add_host_to_server(peer) if rathole_route else None else: @@ -899,7 +899,7 @@ def _get_remote_node_peer_by_verify_key( remote_node_peer = remote_node_peer.ok() if remote_node_peer is None: return SyftError( - message=f"Can't retrive {remote_node_peer.name} from the store of peers (None)." + message=f"Can't retrieve {remote_node_peer.name} from the store of peers (None)." ) return remote_node_peer diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 59015e21dfb..b08a793e96f 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -86,22 +86,17 @@ class NodePeer(SyftObject): ping_status_message: str | None = None pinged_timestamp: DateTime | None = None - def existed_route( - self, route: NodeRouteType | None = None, route_id: UID | None = None - ) -> tuple[bool, int | None]: + def existed_route(self, route: NodeRouteType) -> tuple[bool, int | None]: """Check if a route exists in self.node_routes Args: route: the route to be checked. 
For now it can be either - HTTPNodeRoute or PythonNodeRoute or VeilidNodeRoute - route_id: the id of the route to be checked + HTTPNodeRoute or PythonNodeRoute Returns: if the route exists, returns (True, index of the existed route in self.node_routes) if the route does not exist returns (False, None) """ - if route_id is None and route is None: - raise ValueError("Either route or route_id should be provided in args") if route: if not isinstance(route, HTTPNodeRoute | PythonNodeRoute | VeilidNodeRoute): @@ -110,11 +105,6 @@ def existed_route( if route == r: return (True, i) - elif route_id: - for i, r in enumerate(self.node_routes): - if r.id == route_id: - return (True, i) - return (False, None) def assign_highest_priority(self, route: NodeRoute) -> NodeRoute: @@ -131,7 +121,7 @@ def assign_highest_priority(self, route: NodeRoute) -> NodeRoute: route.priority = current_max_priority + 1 return route - def update_route(self, route: NodeRoute) -> NodeRoute | None: + def update_route(self, route: NodeRoute) -> None: """ Update the route for the node. If the route already exists, return it. @@ -140,17 +130,13 @@ def update_route(self, route: NodeRoute) -> NodeRoute | None: Args: route (NodeRoute): The new route to be added to the peer. - - Returns: - NodeRoute | None: if the route already exists, return it, else returns None """ - existed, _ = self.existed_route(route) + existed, idx = self.existed_route(route) if existed: - return route + self.node_routes[idx] = route # type: ignore else: new_route = self.assign_highest_priority(route) self.node_routes.append(new_route) - return None def update_routes(self, new_routes: list[NodeRoute]) -> None: """ @@ -191,7 +177,7 @@ def update_existed_route_priority( message="Priority must be greater than 0. Now it is {priority}." 
) - existed, index = self.existed_route(route_id=route.id) + existed, index = self.existed_route(route=route) if not existed or index is None: return SyftError(message=f"Route with id {route.id} does not exist.") diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index e2d729de069..2b879f2df27 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -173,8 +173,8 @@ def expose_port_on_rathole_service(self, port_name: str, port: int) -> None: config = rathole_service.raw existing_port_idx = None - for idx, port in enumerate(config["spec"]["ports"]): - if port["name"] == port_name: + for idx, existing_port in enumerate(config["spec"]["ports"]): + if existing_port["name"] == port_name: print("Port already exists.", existing_port_idx, port_name) existing_port_idx = idx break diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 1d9ec116467..b1ff68f6b72 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -18,7 +18,6 @@ from ...node.worker_settings import WorkerSettings from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext @@ -90,6 +89,7 @@ class HTTPNodeRoute(SyftObject, NodeRoute): __canonical_name__ = "HTTPNodeRoute" __version__ = SYFT_OBJECT_VERSION_3 + id: UID | None = None # type: ignore host_or_ip: str private: bool = False protocol: str = "http" @@ -119,8 +119,9 @@ def __str__(self) -> str: @serializable() class PythonNodeRoute(SyftObject, NodeRoute): __canonical_name__ = "PythonNodeRoute" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 + id: UID | None = None # type: ignore worker_settings: WorkerSettings proxy_target_uid: UID | None = None priority: int = 1 From b9a0e6c72b529a9316167b91bec7c8bc0026dab5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 3 Jun 2024 11:17:08 +0530 Subject: [PATCH 084/309] add /rathole prefix to dynamic proxy config --- .../syft/src/syft/protocol/protocol_version.json | 12 ++++++++++++ .../syft/src/syft/service/network/rathole_service.py | 9 +++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index a32451b058b..9aa5efd25fb 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -251,6 +251,18 @@ "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", "action": "add" } + }, + "PythonNodeRoute": { + "2": { + "version": 2, + "hash": "3eca5767ae4a8fbe67744509e58c6d9fb78f38fa0a0f7fcf5960ab4250acc1f0", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index 2b879f2df27..abc3bc878a5 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -135,7 +135,7 @@ def 
add_dynamic_addr_to_rathole( rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] if not rathole_proxy: - rathole_proxy = {"http": {"routers": {}, "services": {}}} + rathole_proxy = {"http": {"routers": {}, "services": {}, "middlewares": {}}} else: rathole_proxy = yaml.safe_load(rathole_proxy) @@ -145,15 +145,20 @@ def add_dynamic_addr_to_rathole( } } + rathole_proxy["http"]["middlewares"]["strip-rathole-prefix"] = { + "replacePathRegex:": {"regex": "^/rathole/(.*)", "replacement": "/$1"} + } + proxy_rule = ( f"Host(`{config.server_name}.syft.local`) || " - f"HostHeader(`{config.server_name}.syft.local`) && PathPrefix(`/`)" + f"HostHeader(`{config.server_name}.syft.local`) && PathPrefix(`/rathole`)" ) rathole_proxy["http"]["routers"][config.server_name] = { "rule": proxy_rule, "service": config.server_name, "entryPoints": [entrypoint], + "middlewares": ["strip-rathole-prefix"], } KubeUtils.update_configmap( From a7de40706d813ddd41a0cc50e9d491e8bc70c1cc Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 3 Jun 2024 11:22:29 +0530 Subject: [PATCH 085/309] add rathole prefix in http connection when rathole token present --- packages/syft/src/syft/client/client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 33a509dfc18..7e74a19261d 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -124,6 +124,7 @@ class Routes(Enum): ROUTE_API_CALL = f"{API_PATH}/api_call" ROUTE_BLOB_STORE = "/blob" STREAM = f"{API_PATH}/stream" + RATHOLE = "/rathole" @serializable(attrs=["proxy_target_uid", "url", "rathole_token"]) @@ -191,7 +192,7 @@ def _make_get( url = self.url if self.rathole_token: - url = GridURL.from_url(INTERNAL_PROXY_URL) + url = GridURL.from_url(INTERNAL_PROXY_URL).with_path(Routes.RATHOLE.value) headers = {"Host": self.url.host_or_ip} url = url.with_path(path) @@ -226,7 +227,7 @@ def _make_post( url = self.url if self.rathole_token: - url = GridURL.from_url(INTERNAL_PROXY_URL) + url = GridURL.from_url(INTERNAL_PROXY_URL).with_path(Routes.RATHOLE.value) headers = {"Host": self.url.host_or_ip} url = url.with_path(path) @@ -338,7 +339,9 @@ def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError: headers = {} if self.rathole_token: - api_url = GridURL.from_url(INTERNAL_PROXY_URL) + api_url = GridURL.from_url(INTERNAL_PROXY_URL).with_path( + Routes.RATHOLE.value + ) api_url = api_url.with_path(self.routes.ROUTE_API_CALL.value) headers = {"Host": self.url.host_or_ip} else: From 006f4d47e2954311a62a3ea7bd367a2ba3fc5a9e Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Wed, 5 Jun 2024 11:56:53 +0200 Subject: [PATCH 086/309] - --- packages/syft/src/syft/service/migration/migration_service.py | 3 ++- packages/syft/src/syft/service/response.py | 2 +- packages/syft/src/syft/store/mongo_document_store.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index c6708ebc850..80a64ec8762 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -188,6 +188,7 @@ def _update_migrated_objects( canonical_name = klass.__canonical_name__ object_partition = self.store.partitions.get(canonical_name) qk = object_partition.settings.store_key.with_obj(migrated_object.id) + result = 
object_partition._update( context.credentials, qk=qk, @@ -198,7 +199,7 @@ def _update_migrated_objects( ) if result.is_err(): - return result.err() + return result return Ok(value="success") @service_method( diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 37227046c5c..0d5c04c64ad 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -42,7 +42,7 @@ def _repr_html_class_(self) -> str: def _repr_html_(self) -> str: return ( f'
' - + f"{type(self).__name__}: {self.message}

" + + f"{type(self).__name__}: {self.message.replace("\n", "
")}
" ) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index fa38d6c1ba8..65640052133 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -336,6 +336,7 @@ def _update( obj: SyftObject, has_permission: bool = False, overwrite: bool = False, + allow_missing_keys=False, ) -> Result[SyftObject, str]: collection_status = self.collection if collection_status.is_err(): From d52643bbb3c3e25ff499b56be248300e72ef4dbd Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sun, 9 Jun 2024 23:56:40 +0530 Subject: [PATCH 087/309] update internal proxy url to include /rathole - fix syntax in middleware to strip rathole prefix - fix proxy config map to ignore if PathPrefix doesn't match --- .../helm/syft/templates/proxy/proxy-configmap.yaml | 6 +++--- packages/syft/src/syft/client/client.py | 11 ++++------- .../syft/src/syft/service/network/network_service.py | 5 ++--- .../syft/src/syft/service/network/rathole_service.py | 2 +- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index e5055f03f95..748dfee78c4 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -28,17 +28,17 @@ data: - url: "http://rathole:2333" routers: rathole: - rule: "PathPrefix(`/`) && Headers(`Upgrade`, `websocket`)" + rule: "PathPrefix(`/`) && Headers(`Upgrade`, `websocket`) && !PathPrefix(`/rathole`)" entryPoints: - "web" service: "rathole" frontend: - rule: "PathPrefix(`/`)" + rule: "PathPrefix(`/`) && !PathPrefix(`/rathole`)" entryPoints: - "web" service: "frontend" backend: - rule: "PathPrefix(`/api`) || PathPrefix(`/docs`) || PathPrefix(`/redoc`)" + rule: "(PathPrefix(`/api`) || PathPrefix(`/docs`) || PathPrefix(`/redoc`)) && !PathPrefix(`/rathole`)" entryPoints: - "web" service: "backend" diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 7e74a19261d..3a1c047ac1f 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -113,7 +113,7 @@ def forward_message_to_proxy( API_PATH = "/api/v2" DEFAULT_PYGRID_PORT = 80 DEFAULT_PYGRID_ADDRESS = f"http://localhost:{DEFAULT_PYGRID_PORT}" -INTERNAL_PROXY_URL = "http://proxy:80" +INTERNAL_PROXY_TO_RATHOLE = "http://proxy:80/rathole/" class Routes(Enum): @@ -124,7 +124,6 @@ class Routes(Enum): ROUTE_API_CALL = f"{API_PATH}/api_call" ROUTE_BLOB_STORE = "/blob" STREAM = f"{API_PATH}/stream" - RATHOLE = "/rathole" @serializable(attrs=["proxy_target_uid", "url", "rathole_token"]) @@ -192,7 +191,7 @@ def _make_get( url = self.url if self.rathole_token: - url = GridURL.from_url(INTERNAL_PROXY_URL).with_path(Routes.RATHOLE.value) + url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) headers = {"Host": self.url.host_or_ip} url = url.with_path(path) @@ -227,7 +226,7 @@ def _make_post( url = self.url if self.rathole_token: - url = GridURL.from_url(INTERNAL_PROXY_URL).with_path(Routes.RATHOLE.value) + url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) headers = {"Host": self.url.host_or_ip} url = url.with_path(path) @@ -339,9 +338,7 @@ def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError: headers = {} if self.rathole_token: - api_url = GridURL.from_url(INTERNAL_PROXY_URL).with_path( - Routes.RATHOLE.value - ) + api_url = 
GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) api_url = api_url.with_path(self.routes.ROUTE_API_CALL.value) headers = {"Host": self.url.host_or_ip} else: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index fdfe6c184ec..adb117be61a 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -608,11 +608,10 @@ def add_route( if isinstance(remote_node_peer, SyftError): return remote_node_peer # add and update the priority for the peer - existed_route: NodeRoute | None = remote_node_peer.update_route(route) - if existed_route: + if route in remote_node_peer.node_routes: return SyftSuccess( message=f"The route already exists between '{context.node.name}' and " - f"peer '{remote_node_peer.name}' with id '{existed_route.id}'." + f"peer '{remote_node_peer.name}'." ) # update the peer in the store with the updated routes result = self.stash.update( diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_service.py index abc3bc878a5..afdf48503d7 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_service.py @@ -146,7 +146,7 @@ def add_dynamic_addr_to_rathole( } rathole_proxy["http"]["middlewares"]["strip-rathole-prefix"] = { - "replacePathRegex:": {"regex": "^/rathole/(.*)", "replacement": "/$1"} + "replacePathRegex": {"regex": "^/rathole/(.*)", "replacement": "/$1"} } proxy_rule = ( From ed6ce110c19cb62e419dfa25d2f46cac5b6f4906 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 00:43:22 +0530 Subject: [PATCH 088/309] update protocol version --- .../src/syft/protocol/protocol_version.json | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 900bf516616..98c6b4576ba 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -284,6 +284,30 @@ "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", "action": "add" } + }, + "HTTPConnection": { + "2": { + "version": 2, + "hash": "68409295f8916ceb22a8cf4abf89f5e4bcff0d75dc37e16ede37250ada28df59", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "9162f038f0f6401c4cf4d1b517c40805d1f291bd69a6e76f3c1ee9e5095de2e5", + "action": "add" + } + }, + "HTTPNodeRoute": { + "2": { + "version": 2, + "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", + "action": "add" + } } } } From 6ea2cff3e59bc7a0b6ec97615041529eb8d2aff8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 15:53:56 +0530 Subject: [PATCH 089/309] fix parameter name reference in update peer api - fix reference of node side instead of node side type from context - return the updated object in mongo --- .../syft/src/syft/service/network/network_service.py | 10 +++++++--- packages/syft/src/syft/store/mongo_document_store.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 86b744cc36b..eba9a7a6b3c 100644 --- 
a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -7,6 +7,7 @@ from typing import cast # third party +from loguru import logger from result import Result # relative @@ -216,9 +217,12 @@ def exchange_credentials_with( id=self_node_peer.id, node_routes=self_node_peer.node_routes ) result = remote_client.api.services.network.update_peer( - update_peer=updated_peer + peer_update=updated_peer ) if isinstance(result, SyftError): + logger.error( + f"Failed to update peer information on remote client. {result.message}" + ) return SyftError( message=f"Failed to add peer information on remote client : {remote_client.id}" ) @@ -502,8 +506,8 @@ def update_peer( peer = result.ok() - node_side_type = cast(NodeType, context.node.node_side_type) - if node_side_type == NodeType.GATEWAY: + node_side_type = cast(NodeType, context.node.node_type) + if node_side_type.value == NodeType.GATEWAY.value: rathole_route = peer.get_rathole_route() self.rathole_service.add_host_to_server(peer) if rathole_route else None else: diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index 59d6799c2bb..60040ce4a9d 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -376,7 +376,7 @@ def _update( except Exception as e: return Err(f"Failed to update obj: {obj} with qk: {qk}. Error: {e}") - return Ok(obj) + return Ok(prev_obj) else: return Err(f"Failed to update obj {obj}, you have no permission") From 4dd0b7fb878dd20d859d98219f513cbc86dfe996 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 17:16:36 +0530 Subject: [PATCH 090/309] added notebook --- notebooks/Experimental/Network.ipynb | 8036 ++++++++++++++++++++++++++ 1 file changed, 8036 insertions(+) create mode 100644 notebooks/Experimental/Network.ipynb diff --git a/notebooks/Experimental/Network.ipynb b/notebooks/Experimental/Network.ipynb new file mode 100644 index 00000000000..88240421fb3 --- /dev/null +++ b/notebooks/Experimental/Network.ipynb @@ -0,0 +1,8036 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bd9a2226-3e53-4f27-9213-75a8c3ff9176", + "metadata": {}, + "outputs": [], + "source": [ + "import syft as sy" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fddf8d07-d154-4284-a27b-d74e35d3f851", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged into as \n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gateway_client = sy.login(url=\"http://localhost\", port=9081, email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8f7b106d-b784-45d8-b54d-4ce2de2da453", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged into as \n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "domain_client = sy.login(url=\"http://localhost\", port=9082, email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ff504949-620d-4e26-beee-0d39e0e502eb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: Connected domain 'syft-dev-node' to gateway 'syft-dev-node'. Routes Exchanged

" + ], + "text/plain": [ + "SyftSuccess: Connected domain 'syft-dev-node' to gateway 'syft-dev-node'. Routes Exchanged" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain_client.connect_to_gateway(gateway_client, reverse_tunnel=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ba7bc71a-4e6a-4429-9588-7b3d0ed19e27", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n",
       "Request List\n",
       "Total: 0\n",
       "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gateway_client.api.services.request" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5b4984e1-331e-4fd8-b012-768fc613f48a", + "metadata": {}, + "outputs": [], + "source": [ + "# gateway_client.api.services.request[0].approve()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "90dc44bd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n",
       "NodePeer List\n",
       "Total: 0\n",
       "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "[syft.service.network.node_peer.NodePeer]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node_peers = gateway_client.api.network.get_all_peers()\n", + "node_peers" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8c06aaa6-4157-42d1-959f-9d47722a3420", + "metadata": {}, + "outputs": [], + "source": [ + "node_peer = gateway_client.api.network.get_all_peers()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cb63a77b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[syft.service.network.routes.HTTPNodeRoute]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node_peer.node_routes" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "61882e86", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'syft_node_location': ,\n", + " 'syft_client_verify_key': cc82ec7abd5d516e6972e787ddaafe8b04e223228436b09c026be97b80ad6246,\n", + " 'id': None,\n", + " 'host_or_ip': 'syft-dev-node.syft.local',\n", + " 'private': False,\n", + " 'protocol': 'http',\n", + " 'port': 9082,\n", + " 'proxy_target_uid': None,\n", + " 'priority': 1,\n", + " 'rathole_token': 'b95e8d239d563e6fcc3a4f44a5292177e608a7b0b1194e6106adc1998a1b68a1'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node_peer.node_routes[0].__dict__" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "fb19dbc6-869b-46dc-92e3-5e75ee6d0b06", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'syft_node_location': ,\n", + " 'syft_client_verify_key': 5ed40db70e275e30e9001cda808922c4542c84f7472a106c00158795b9388b0a,\n", + " 'id': None,\n", + " 'host_or_ip': 'host.k3d.internal',\n", + " 'private': False,\n", + " 'protocol': 'http',\n", + " 'port': 9081,\n", + " 'proxy_target_uid': None,\n", + " 'priority': 1,\n", + " 'rathole_token': None}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain_client.api.network.get_all_peers()[0].node_routes[0].__dict__" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "32d09a51", + "metadata": {}, + "outputs": [], + "source": [ + "# node_peer.client_with_key(sy.SyftSigningKey.generate())" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b7d9e41d", + "metadata": {}, + "outputs": [], + "source": [ + "# gateway_client.api.network.delete_route(node_peer.verify_key, node_peer.node_routes[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8fa24ec7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + "
\n", + " \"Logo\"\n",\n", + "

Welcome to syft-dev-node

\n", + "
\n", + " URL: http://localhost:9081
Node Type: Gateway
Node Side Type: High Side
Syft Version: 0.8.7-beta.10
\n", + "
\n", + "
\n", + " ⓘ \n", + " This node is run by the library PySyft to learn more about how it works visit\n", + " github.com/OpenMined/PySyft.\n", + "
\n", + "

Commands to Get Started

\n", + " \n", + "
    \n", + " \n", + "
  • <your_client>\n", + " .domains - list domains connected to this gateway
  • \n", + "
  • <your_client>\n", + " .proxy_client_for - get a connection to a listed domain
  • \n", + "
  • <your_client>\n", + " .login - log into the gateway
  • \n", + " \n", + "
\n", + " \n", + "

\n", + " " + ], + "text/plain": [ + ": HTTPConnection: http://localhost:9081>" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gateway_client" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3a081250-abc3-43a3-9e06-ff0c3a362ebf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + "

NodePeer List

\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

Total: 0

\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/markdown": [ + "```python\n", + "class ProxyClient:\n", + " id: str = 37073e9151ce4fa9b665501ec03924c8\n", + "\n", + "```" + ], + "text/plain": [ + "syft.client.gateway_client.ProxyClient" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gateway_client.peers" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b6fedfe4-9362-47c9-9342-5cf6eacde8ab", + "metadata": {}, + "outputs": [], + "source": [ + "domain_client_proxy = gateway_client.peers[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "f1940e00-0337-4b56-88c2-d70f397a7016", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "```python\n", + "class HTTPConnection:\n", + " id: str = None\n", + "\n", + "```" + ], + "text/plain": [ + "HTTPConnection: http://localhost:9081" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain_client_proxy.connection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "613125c5-6321-4238-852c-ff0cfcd9526a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.1.-1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 51b72095100444a405c6f8ae443c732417980c07 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 18:30:02 +0530 Subject: [PATCH 091/309] remove hagrid and docker compose files --- packages/grid/docker-compose.build.yml | 29 - packages/grid/docker-compose.dev.yml | 63 - packages/grid/docker-compose.pull.yml | 20 - packages/grid/docker-compose.yml | 249 -- packages/hagrid/hagrid/cli.py | 4425 ------------------------ 5 files changed, 4786 deletions(-) delete mode 100644 packages/grid/docker-compose.build.yml delete mode 100644 packages/grid/docker-compose.dev.yml delete mode 100644 packages/grid/docker-compose.pull.yml delete mode 100644 packages/grid/docker-compose.yml delete mode 100644 packages/hagrid/hagrid/cli.py diff --git a/packages/grid/docker-compose.build.yml b/packages/grid/docker-compose.build.yml deleted file mode 100644 index a0175bc762a..00000000000 --- a/packages/grid/docker-compose.build.yml +++ /dev/null @@ -1,29 +0,0 @@ -version: "3.8" -services: - frontend: - build: - context: ${RELATIVE_PATH}./frontend - dockerfile: frontend.dockerfile - target: "${FRONTEND_TARGET:-grid-ui-development}" - - backend: - build: - context: ${RELATIVE_PATH}../ - dockerfile: ./grid/backend/backend.dockerfile - target: "backend" - - seaweedfs: - build: - context: ${RELATIVE_PATH}./seaweedfs - dockerfile: seaweedfs.dockerfile - - worker: - build: - context: ${RELATIVE_PATH}../ - dockerfile: ./grid/backend/backend.dockerfile - target: "backend" - - rathole: - build: - context: ${RELATIVE_PATH}./rathole - dockerfile: rathole.dockerfile diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml deleted file mode 100644 index bcde98488e7..00000000000 --- a/packages/grid/docker-compose.dev.yml +++ /dev/null @@ -1,63 +0,0 @@ -version: "3.8" -services: - proxy: - ports: - - "8080" - command: - - "--api" # admin panel - - 
"--api.insecure=true" # admin panel no password - - frontend: - volumes: - - ${RELATIVE_PATH}./frontend/src:/app/src - - ${RELATIVE_PATH}./frontend/static:/app/static - - ${RELATIVE_PATH}./frontend/svelte.config.js:/app/svelte.config.js - - ${RELATIVE_PATH}./frontend/tsconfig.json:/app/tsconfig.json - - ${RELATIVE_PATH}./frontend/vite.config.ts:/app/vite.config.ts - environment: - - FRONTEND_TARGET=grid-ui-development - - mongo: - ports: - - "27017" - - backend: - volumes: - - ${RELATIVE_PATH}./backend/grid:/root/app/grid - - ${RELATIVE_PATH}../syft:/root/app/syft - - ${RELATIVE_PATH}./data/package-cache:/root/.cache - environment: - - DEV_MODE=True - stdin_open: true - tty: true - - worker: - volumes: - - ${RELATIVE_PATH}./backend/grid:/root/app/grid - - ${RELATIVE_PATH}../syft:/root/app/syft - - ${RELATIVE_PATH}./data/package-cache:/root/.cache - environment: - - DEV_MODE=True - - WATCHFILES_FORCE_POLLING=true - stdin_open: true - tty: true - - rathole: - volumes: - - ${RELATIVE_PATH}./rathole/:/root/app/ - environment: - - DEV_MODE=True - - APP_PORT=5555 - - APP_LOG_LEVEL=debug - stdin_open: true - tty: true - ports: - - 2333:2333 - - seaweedfs: - volumes: - - ./data/seaweedfs:/data - ports: - - "9333" # admin web port - - "8888" # filer web port - - "8333" # S3 API port diff --git a/packages/grid/docker-compose.pull.yml b/packages/grid/docker-compose.pull.yml deleted file mode 100644 index e68ed03d968..00000000000 --- a/packages/grid/docker-compose.pull.yml +++ /dev/null @@ -1,20 +0,0 @@ -version: "3.8" -services: - seaweedfs: - image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}:${VERSION-latest}" - - proxy: - image: ${DOCKER_IMAGE_TRAEFIK?Variable not set}:${TRAEFIK_VERSION?Variable not set} - - mongo: - image: "${MONGO_IMAGE}:${MONGO_VERSION}" - - jaeger: - image: jaegertracing/all-in-one:1.37 - - # Temporary fix until we refactor pull, build, launch UI step during hagrid launch - worker: - image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" - - rathole: - image: "${DOCKER_IMAGE_RATHOLE?Variable not set}:${VERSION-latest}" diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml deleted file mode 100644 index 153dfb718ba..00000000000 --- a/packages/grid/docker-compose.yml +++ /dev/null @@ -1,249 +0,0 @@ -version: "3.8" -services: - # docker-host: - # image: qoomon/docker-host - # cap_add: - # - net_admin - # - net_raw - - proxy: - restart: always - hostname: ${NODE_NAME?Variable not set} - image: ${DOCKER_IMAGE_TRAEFIK?Variable not set}:${TRAEFIK_VERSION?Variable not set} - profiles: - - proxy - networks: - - "${TRAEFIK_PUBLIC_NETWORK?Variable not set}" - - default - volumes: - - "./traefik/docker/traefik.yml:/etc/traefik/traefik.yml" - - "./traefik/docker/dynamic.yml:/etc/traefik/conf/dynamic.yml" - environment: - - SERVICE_NAME=proxy - - RELEASE=${RELEASE:-production} - - HOSTNAME=${NODE_NAME?Variable not set} - - HTTP_PORT=${HTTP_PORT} - - HTTPS_PORT=${HTTPS_PORT} - ports: - - "${HTTP_PORT}:81" - extra_hosts: - - "host.docker.internal:host-gateway" - labels: - - "orgs.openmined.syft=this is a syft proxy container" - - # depends_on: - # - "docker-host" - - frontend: - restart: always - image: "${DOCKER_IMAGE_FRONTEND?Variable not set}:${VERSION-latest}" - profiles: - - frontend - depends_on: - - proxy - environment: - - SERVICE_NAME=frontend - - RELEASE=${RELEASE:-production} - - NODE_TYPE=${NODE_TYPE?Variable not set} - - FRONTEND_TARGET=${FRONTEND_TARGET} - - VERSION=${VERSION} - - VERSION_HASH=${VERSION_HASH} - - PORT=80 - - 
HTTP_PORT=${HTTP_PORT} - - HTTPS_PORT=${HTTPS_PORT}RELOAD - - BACKEND_API_BASE_URL=${BACKEND_API_BASE_URL} - extra_hosts: - - "host.docker.internal:host-gateway" - labels: - - "orgs.openmined.syft=this is a syft frontend container" - - rathole: - restart: always - image: "${DOCKER_IMAGE_RATHOLE?Variable not set}:${VERSION-latest}" - profiles: - - rathole - depends_on: - - proxy - environment: - - SERVICE_NAME=rathole - - APP_LOG_LEVEL=${APP_LOG_LEVEL:-info} - - MODE=${MODE} - - DEV_MODE=${DEV_MODE} - - APP_PORT=${APP_PORT:-5555} - - RATHOLE_PORT=${RATHOLE_PORT:-2333} - extra_hosts: - - "host.docker.internal:host-gateway" - labels: - - "orgs.openmined.syft=this is a syft rathole container" - - worker: - restart: always - image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" - hostname: ${NODE_NAME?Variable not set} - profiles: - - worker - env_file: - - .env - environment: - - SERVICE_NAME=worker - - RELEASE=${RELEASE:-production} - - VERSION=${VERSION} - - VERSION_HASH=${VERSION_HASH} - - NODE_TYPE=${NODE_TYPE?Variable not set} - - NODE_NAME=${NODE_NAME?Variable not set} - - STACK_API_KEY=${STACK_API_KEY} - - PORT=${HTTP_PORT} - - IGNORE_TLS_ERRORS=${IGNORE_TLS_ERRORS?False} - - HTTP_PORT=${HTTP_PORT} - - HTTPS_PORT=${HTTPS_PORT} - - USE_BLOB_STORAGE=${USE_BLOB_STORAGE} - - CONTAINER_HOST=${CONTAINER_HOST} - - TRACE=False # TODO: Trace Mode is set to False, until jaegar is integrated - - JAEGER_HOST=${JAEGER_HOST} - - JAEGER_PORT=${JAEGER_PORT} - - ASSOCIATION_TIMEOUT=${ASSOCIATION_TIMEOUT} - - DEV_MODE=${DEV_MODE} - - QUEUE_PORT=${QUEUE_PORT} - - CREATE_PRODUCER=true - - NODE_SIDE_TYPE=${NODE_SIDE_TYPE} - - ENABLE_WARNINGS=${ENABLE_WARNINGS} - - INMEMORY_WORKERS=True # hardcoding is intentional, since single_container don't share databases - ports: - - "${HTTP_PORT}:${HTTP_PORT}" - volumes: - - credentials-data:/root/data/creds/ - - /var/run/docker.sock:/var/run/docker.sock - extra_hosts: - - "host.docker.internal:host-gateway" - labels: - - "orgs.openmined.syft=this is a syft worker container" - - backend: - restart: always - image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${VERSION-latest}" - profiles: - - backend - depends_on: - - proxy - - mongo - env_file: - - .env - environment: - - SERVICE_NAME=backend - - RELEASE=${RELEASE:-production} - - VERSION=${VERSION} - - VERSION_HASH=${VERSION_HASH} - - NODE_TYPE=${NODE_TYPE?Variable not set} - - NODE_NAME=${NODE_NAME?Variable not set} - - STACK_API_KEY=${STACK_API_KEY} - - PORT=8001 - - IGNORE_TLS_ERRORS=${IGNORE_TLS_ERRORS?False} - - HTTP_PORT=${HTTP_PORT} - - HTTPS_PORT=${HTTPS_PORT} - - USE_BLOB_STORAGE=${USE_BLOB_STORAGE} - - CONTAINER_HOST=${CONTAINER_HOST} - - TRACE=${TRACE} - - JAEGER_HOST=${JAEGER_HOST} - - JAEGER_PORT=${JAEGER_PORT} - - ASSOCIATION_TIMEOUT=${ASSOCIATION_TIMEOUT} - - DEV_MODE=${DEV_MODE} - - DEFAULT_ROOT_EMAIL=${DEFAULT_ROOT_EMAIL} - - DEFAULT_ROOT_PASSWORD=${DEFAULT_ROOT_PASSWORD} - - QUEUE_PORT=${QUEUE_PORT} - - CREATE_PRODUCER=true - - N_CONSUMERS=1 - - INMEMORY_WORKERS=${INMEMORY_WORKERS} - - HOST_GRID_PATH=${PWD} - command: "./grid/start.sh" - network_mode: service:proxy - volumes: - - ${CREDENTIALS_VOLUME}:/root/data/creds/ - - /var/run/docker.sock:/var/run/docker.sock - stdin_open: true - tty: true - labels: - - "orgs.openmined.syft=this is a syft backend container" - - seaweedfs: - profiles: - - blob-storage - depends_on: - - proxy - env_file: - - .env - image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}:${VERSION-latest}" - environment: - - 
SWFS_VOLUME_SIZE_LIMIT_MB=${SWFS_VOLUME_SIZE_LIMIT_MB:-1000} - - S3_ROOT_USER=${S3_ROOT_USER:-admin} - - S3_ROOT_PWD=${S3_ROOT_PWD:-admin} - - MOUNT_API_PORT=${MOUNT_API_PORT:-4001} - volumes: - - seaweedfs-data:/data - labels: - - "orgs.openmined.syft=this is a syft seaweedfs container" - - mongo: - image: "${MONGO_IMAGE}:${MONGO_VERSION}" - profiles: - - mongo - restart: always - environment: - - MONGO_INITDB_ROOT_USERNAME=${MONGO_USERNAME} - - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASSWORD} - volumes: - - mongo-data:/data/db - - mongo-config-data:/data/configdb - labels: - - "orgs.openmined.syft=this is a syft mongo container" - - jaeger: - profiles: - - telemetry - image: jaegertracing/all-in-one:1.37 - environment: - - COLLECTOR_ZIPKIN_HOST_PORT=9411 - - COLLECTOR_OTLP_ENABLED=true - extra_hosts: - - "host.docker.internal:host-gateway" - ports: - - "${JAEGER_PORT}:14268" # http collector - - "16686" # ui - # - "6831:6831/udp" - # - "6832:6832/udp" - # - "5778:5778" - # - "4317:4317" - # - "4318:4318" - # - "14250:14250" - # - "14269:14269" - # - "9411:9411" - volumes: - - jaeger-data:/tmp - labels: - - "orgs.openmined.syft=this is a syft jaeger container" - -volumes: - credentials-data: - labels: - orgs.openmined.syft: "this is a syft credentials volume" - seaweedfs-data: - labels: - orgs.openmined.syft: "this is a syft seaweedfs volume" - mongo-data: - labels: - orgs.openmined.syft: "this is a syft mongo volume" - mongo-config-data: - labels: - orgs.openmined.syft: "this is a syft mongo volume" - jaeger-data: - labels: - orgs.openmined.syft: "this is a syft jaeger volume" - -networks: - traefik-public: - # Allow setting it to false for testing - external: ${TRAEFIK_PUBLIC_NETWORK_IS_EXTERNAL-true} - labels: - orgs.openmined.syft: "this is a syft traefik public network" - default: - labels: - orgs.openmined.syft: "this is a syft default network" diff --git a/packages/hagrid/hagrid/cli.py b/packages/hagrid/hagrid/cli.py deleted file mode 100644 index 9407e49add8..00000000000 --- a/packages/hagrid/hagrid/cli.py +++ /dev/null @@ -1,4425 +0,0 @@ -# stdlib -from collections import namedtuple -from collections.abc import Callable -from enum import Enum -import json -import os -from pathlib import Path -import platform -from queue import Queue -import re -import shutil -import socket -import stat -import subprocess # nosec -import sys -import tempfile -from threading import Event -from threading import Thread -import time -from typing import Any -from typing import cast -from urllib.parse import urlparse -import webbrowser - -# third party -import click -import requests -import rich -from rich.console import Console -from rich.live import Live -from rich.progress import BarColumn -from rich.progress import Progress -from rich.progress import SpinnerColumn -from rich.progress import TextColumn -from virtualenvapi.manage import VirtualEnvironment - -# relative -from .art import RichEmoji -from .art import hagrid -from .art import quickstart_art -from .auth import AuthCredentials -from .cache import DEFAULT_BRANCH -from .cache import DEFAULT_REPO -from .cache import arg_cache -from .deps import DEPENDENCIES -from .deps import LATEST_BETA_SYFT -from .deps import allowed_hosts -from .deps import check_docker_service_status -from .deps import check_docker_version -from .deps import check_grid_docker -from .deps import gather_debug -from .deps import get_version_string -from .deps import is_windows -from .exceptions import MissingDependency -from .grammar import BadGrammar -from .grammar import 
GrammarVerb -from .grammar import parse_grammar -from .land import get_land_verb -from .launch import get_launch_verb -from .lib import GIT_REPO -from .lib import GRID_SRC_PATH -from .lib import GRID_SRC_VERSION -from .lib import check_api_metadata -from .lib import check_host -from .lib import check_jupyter_server -from .lib import check_login_page -from .lib import commit_hash -from .lib import docker_desktop_memory -from .lib import find_available_port -from .lib import generate_process_status_table -from .lib import generate_user_table -from .lib import gitpod_url -from .lib import hagrid_root -from .lib import is_gitpod -from .lib import name_tag -from .lib import save_vm_details_as_json -from .lib import update_repo -from .lib import use_branch -from .mode import EDITABLE_MODE -from .parse_template import deployment_dir -from .parse_template import get_template_yml -from .parse_template import manifest_cache_path -from .parse_template import render_templates -from .parse_template import setup_from_manifest_template -from .quickstart_ui import fetch_notebooks_for_url -from .quickstart_ui import fetch_notebooks_from_zipfile -from .quickstart_ui import quickstart_download_notebook -from .rand_sec import generate_sec_random_password -from .stable_version import LATEST_STABLE_SYFT -from .style import RichGroup -from .util import fix_windows_virtualenv_api -from .util import from_url -from .util import shell - -# fix VirtualEnvironment bug in windows -fix_windows_virtualenv_api(VirtualEnvironment) - - -class NodeSideType(Enum): - LOW_SIDE = "low" - HIGH_SIDE = "high" - - -def get_azure_image(short_name: str) -> str: - prebuild_070 = ( - "madhavajay1632269232059:openmined_mj_grid_domain_ubuntu_1:domain_070:latest" - ) - fresh_ubuntu = "Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest" - if short_name == "default": - return fresh_ubuntu - elif short_name == "domain_0.7.0": - return prebuild_070 - raise Exception(f"Image name doesn't exist: {short_name}. 
Try: default or 0.7.0") - - -@click.group(cls=RichGroup) -def cli() -> None: - pass - - -def get_compose_src_path( - node_name: str, - template_location: str | None = None, - **kwargs: Any, -) -> str: - grid_path = GRID_SRC_PATH() - tag = kwargs["tag"] - # Use local compose files if in editable mode and - # template_location is None and (kwargs["dev"] is True or tag is local) - if ( - EDITABLE_MODE - and template_location is None - and (kwargs["dev"] is True or tag == "local") - ): - path = grid_path - else: - path = deployment_dir(node_name) - - os.makedirs(path, exist_ok=True) - return path - - -@click.command( - help="Restore some part of the hagrid installation or deployment to its initial/starting state.", - context_settings={"show_default": True}, -) -@click.argument("location", type=str, nargs=1) -def clean(location: str) -> None: - if location == "library" or location == "volumes": - print("Deleting all Docker volumes in 2 secs (Ctrl-C to stop)") - time.sleep(2) - subprocess.call("docker volume rm $(docker volume ls -q)", shell=True) # nosec - - if location == "containers" or location == "pantry": - print("Deleting all Docker containers in 2 secs (Ctrl-C to stop)") - time.sleep(2) - subprocess.call("docker rm -f $(docker ps -a -q)", shell=True) # nosec - - if location == "images": - print("Deleting all Docker images in 2 secs (Ctrl-C to stop)") - time.sleep(2) - subprocess.call("docker rmi $(docker images -q)", shell=True) # nosec - - -@click.command( - help="Start a new PyGrid domain/network node!", - context_settings={"show_default": True}, -) -@click.argument("args", type=str, nargs=-1) -@click.option( - "--username", - default=None, - required=False, - type=str, - help="Username for provisioning the remote host", -) -@click.option( - "--key-path", - default=None, - required=False, - type=str, - help="Path to the key file for provisioning the remote host", -) -@click.option( - "--password", - default=None, - required=False, - type=str, - help="Password for provisioning the remote host", -) -@click.option( - "--repo", - default=None, - required=False, - type=str, - help="Repo to fetch source from", -) -@click.option( - "--branch", - default=None, - required=False, - type=str, - help="Branch to monitor for updates", -) -@click.option( - "--tail", - is_flag=True, - help="Tail logs on launch", -) -@click.option( - "--headless", - is_flag=True, - help="Start the frontend container", -) -@click.option( - "--cmd", - is_flag=True, - help="Print the cmd without running it", -) -@click.option( - "--jupyter", - is_flag=True, - help="Enable Jupyter Notebooks", -) -@click.option( - "--in-mem-workers", - is_flag=True, - help="Enable InMemory Workers", -) -@click.option( - "--enable-signup", - is_flag=True, - help="Enable Signup for Node", -) -@click.option( - "--build", - is_flag=True, - help="Disable forcing re-build", -) -@click.option( - "--no-provision", - is_flag=True, - help="Disable provisioning VMs", -) -@click.option( - "--node-count", - default=1, - required=False, - type=click.IntRange(1, 250), - help="Number of independent nodes/VMs to launch", -) -@click.option( - "--auth-type", - default=None, - type=click.Choice(["key", "password"], case_sensitive=False), -) -@click.option( - "--ansible-extras", - default="", - type=str, -) -@click.option("--tls", is_flag=True, help="Launch with TLS configuration") -@click.option("--test", is_flag=True, help="Launch with test configuration") -@click.option("--dev", is_flag=True, help="Shortcut for development mode") -@click.option( - 
"--release", - default="production", - required=False, - type=click.Choice(["production", "staging", "development"], case_sensitive=False), - help="Choose between production and development release", -) -@click.option( - "--deployment-type", - default="container_stack", - required=False, - type=click.Choice(["container_stack", "single_container"], case_sensitive=False), - help="Choose between container_stack and single_container deployment", -) -@click.option( - "--cert-store-path", - default="/home/om/certs", - required=False, - type=str, - help="Remote path to store and load TLS cert and key", -) -@click.option( - "--upload-tls-cert", - default="", - required=False, - type=str, - help="Local path to TLS cert to upload and store at --cert-store-path", -) -@click.option( - "--upload-tls-key", - default="", - required=False, - type=str, - help="Local path to TLS private key to upload and store at --cert-store-path", -) -@click.option( - "--no-blob-storage", - is_flag=True, - help="Disable blob storage", -) -@click.option( - "--image-name", - default=None, - required=False, - type=str, - help="Image to use for the VM", -) -@click.option( - "--tag", - default=None, - required=False, - type=str, - help="Container image tag to use", -) -@click.option( - "--smtp-username", - default=None, - required=False, - type=str, - help="Username used to auth in email server and enable notification via emails", -) -@click.option( - "--smtp-password", - default=None, - required=False, - type=str, - help="Password used to auth in email server and enable notification via emails", -) -@click.option( - "--smtp-port", - default=None, - required=False, - type=str, - help="Port used by email server to send notification via emails", -) -@click.option( - "--smtp-host", - default=None, - required=False, - type=str, - help="Address used by email server to send notification via emails", -) -@click.option( - "--smtp-sender", - default=None, - required=False, - type=str, - help="Sender email used to deliver PyGrid email notifications.", -) -@click.option( - "--build-src", - default=DEFAULT_BRANCH, - required=False, - type=str, - help="Git branch to use for launch / build operations", -) -@click.option( - "--platform", - default=None, - required=False, - type=str, - help="Run docker with a different platform like linux/arm64", -) -@click.option( - "--verbose", - is_flag=True, - help="Show verbose output", -) -@click.option( - "--trace", - required=False, - type=str, - help="Optional: allow trace to be turned on or off", -) -@click.option( - "--template", - required=False, - default=None, - help="Path or URL to manifest template", -) -@click.option( - "--template-overwrite", - is_flag=True, - help="Force re-downloading of template manifest", -) -@click.option( - "--no-health-checks", - is_flag=True, - help="Turn off auto health checks post node launch", -) -@click.option( - "--set-root-email", - default=None, - required=False, - type=str, - help="Set root email of node", -) -@click.option( - "--set-root-password", - default=None, - required=False, - type=str, - help="Set root password of node", -) -@click.option( - "--azure-resource-group", - default=None, - required=False, - type=str, - help="Azure Resource Group", -) -@click.option( - "--azure-location", - default=None, - required=False, - type=str, - help="Azure Resource Group Location", -) -@click.option( - "--azure-size", - default=None, - required=False, - type=str, - help="Azure VM Size", -) -@click.option( - "--azure-username", - default=None, - required=False, - 
type=str, - help="Azure VM Username", -) -@click.option( - "--azure-key-path", - default=None, - required=False, - type=str, - help="Azure Key Path", -) -@click.option( - "--azure-repo", - default=None, - required=False, - type=str, - help="Azure Source Repo", -) -@click.option( - "--azure-branch", - default=None, - required=False, - type=str, - help="Azure Source Branch", -) -@click.option( - "--render", - is_flag=True, - help="Render Docker Files", -) -@click.option( - "--no-warnings", - is_flag=True, - help="Enable API warnings on the node.", -) -@click.option( - "--low-side", - is_flag=True, - help="Launch a low side node type else a high side node type", -) -@click.option( - "--set-s3-username", - default=None, - required=False, - type=str, - help="Set root username for s3 blob storage", -) -@click.option( - "--set-s3-password", - default=None, - required=False, - type=str, - help="Set root password for s3 blob storage", -) -@click.option( - "--set-volume-size-limit-mb", - default=1024, - required=False, - type=click.IntRange(1024, 50000), - help="Set the volume size limit (in MBs)", -) -@click.option( - "--association-request-auto-approval", - is_flag=True, - help="Enable auto approval of association requests", -) -@click.option( - "--rathole", - is_flag=True, - help="Enable rathole service", -) -def launch(args: tuple[str], **kwargs: Any) -> None: - verb = get_launch_verb() - try: - grammar = parse_grammar(args=args, verb=verb) - verb.load_grammar(grammar=grammar) - except BadGrammar as e: - print(e) - return - - node_name = verb.get_named_term_type(name="node_name") - snake_name = str(node_name.snake_input) - node_type = verb.get_named_term_type(name="node_type") - - # For enclave currently it is only a single container deployment - # This would change when we have side car containers to enclave - if node_type.input == "enclave": - kwargs["deployment_type"] = "single_container" - - compose_src_path = get_compose_src_path( - node_type=node_type, - node_name=snake_name, - template_location=kwargs["template"], - **kwargs, - ) - kwargs["compose_src_path"] = compose_src_path - - try: - update_repo(repo=GIT_REPO(), branch=str(kwargs["build_src"])) - except Exception as e: - print(f"Failed to update repo. 
{e}") - try: - cmds = create_launch_cmd(verb=verb, kwargs=kwargs) - cmds = [cmds] if isinstance(cmds, str) else cmds - except Exception as e: - print(f"Error: {e}\n\n") - return - - dry_run = bool(kwargs["cmd"]) - - health_checks = not bool(kwargs["no_health_checks"]) - render_only = bool(kwargs["render"]) - - try: - tail = bool(kwargs["tail"]) - verbose = bool(kwargs["verbose"]) - silent = not verbose - if tail: - silent = False - - if render_only: - print( - "Docker Compose Files Rendered: {}".format(kwargs["compose_src_path"]) - ) - return - - execute_commands( - cmds, - dry_run=dry_run, - silent=silent, - compose_src_path=kwargs["compose_src_path"], - node_type=node_type.input, - ) - - host_term = verb.get_named_term_hostgrammar(name="host") - run_health_checks = ( - health_checks and not dry_run and host_term.host == "docker" and silent - ) - - if run_health_checks: - docker_cmds = cast(dict[str, list[str]], cmds) - - # get the first command (cmd1) from docker_cmds which is of the form - # {"": [cmd1, cmd2], "": [cmd3, cmd4]} - (command, *_), *_ = docker_cmds.values() - - match_port = re.search("HTTP_PORT=[0-9]{1,5}", command) - if match_port: - rich.get_console().print( - "\n[bold green]⠋[bold blue] Checking node API [/bold blue]\t" - ) - port = match_port.group().replace("HTTP_PORT=", "") - - check_status("localhost" + ":" + port, node_name=node_name.snake_input) - - rich.get_console().print( - rich.panel.Panel.fit( - f"✨ To view container logs run [bold green]hagrid logs {node_name.snake_input}[/bold green]\t" - ) - ) - - except Exception as e: - print(f"Error: {e}\n\n") - return - - -def check_errors( - line: str, process: subprocess.Popen, cmd_name: str, progress_bar: Progress -) -> None: - task = progress_bar.tasks[0] - if "Error response from daemon: " in line: - if progress_bar: - progress_bar.update( - 0, - description=f"❌ [bold red]{cmd_name}[/bold red] [{task.completed} / {task.total}]", - refresh=True, - ) - progress_bar.update(0, visible=False) - progress_bar.console.clear_live() - progress_bar.console.quiet = True - progress_bar.stop() - console = rich.get_console() - progress_bar.console.quiet = False - console.print(f"\n\n [red] ERROR [/red]: [bold]{line}[/bold]\n") - process.terminate() - raise Exception - - -def check_pulling(line: str, cmd_name: str, progress_bar: Progress) -> None: - task = progress_bar.tasks[0] - if "Pulling" in line and "fs layer" not in line: - progress_bar.update( - 0, - description=f"[bold]{cmd_name} [{task.completed} / {task.total+1}]", - total=task.total + 1, - refresh=True, - ) - if "Pulled" in line: - progress_bar.update( - 0, - description=f"[bold]{cmd_name} [{task.completed + 1} / {task.total}]", - completed=task.completed + 1, - refresh=True, - ) - if progress_bar.finished: - progress_bar.update( - 0, - description=f"✅ [bold green]{cmd_name} [{task.completed} / {task.total}]", - refresh=True, - ) - - -def check_building(line: str, cmd_name: str, progress_bar: Progress) -> None: - load_pattern = re.compile( - r"^#.* load build definition from [A-Za-z0-9]+\.dockerfile$", re.IGNORECASE - ) - build_pattern = re.compile( - r"^#.* naming to docker\.io/openmined/.* done$", re.IGNORECASE - ) - task = progress_bar.tasks[0] - - if load_pattern.match(line): - progress_bar.update( - 0, - description=f"[bold]{cmd_name} [{task.completed} / {task.total +1}]", - total=task.total + 1, - refresh=True, - ) - if build_pattern.match(line): - progress_bar.update( - 0, - description=f"[bold]{cmd_name} [{task.completed+1} / {task.total}]", - 
completed=task.completed + 1, - refresh=True, - ) - - if progress_bar.finished: - progress_bar.update( - 0, - description=f"✅ [bold green]{cmd_name} [{task.completed} / {task.total}]", - refresh=True, - ) - - -def check_launching(line: str, cmd_name: str, progress_bar: Progress) -> None: - task = progress_bar.tasks[0] - if "Starting" in line: - progress_bar.update( - 0, - description=f" [bold]{cmd_name} [{task.completed} / {task.total+1}]", - total=task.total + 1, - refresh=True, - ) - if "Started" in line: - progress_bar.update( - 0, - description=f" [bold]{cmd_name} [{task.completed + 1} / {task.total}]", - completed=task.completed + 1, - refresh=True, - ) - if progress_bar.finished: - progress_bar.update( - 0, - description=f"✅ [bold green]{cmd_name} [{task.completed} / {task.total}]", - refresh=True, - ) - - -DOCKER_FUNC_MAP = { - "Pulling": check_pulling, - "Building": check_building, - "Launching": check_launching, -} - - -def read_thread_logs( - progress_bar: Progress, process: subprocess.Popen, queue: Queue, cmd_name: str -) -> None: - line = queue.get() - line = str(line, encoding="utf-8").strip() - - if progress_bar: - check_errors(line, process, cmd_name, progress_bar=progress_bar) - DOCKER_FUNC_MAP[cmd_name](line, cmd_name, progress_bar=progress_bar) - - -def create_thread_logs(process: subprocess.Popen) -> Queue: - def enqueue_output(out: Any, queue: Queue) -> None: - for line in iter(out.readline, b""): - queue.put(line) - out.close() - - queue: Queue = Queue() - thread_1 = Thread(target=enqueue_output, args=(process.stdout, queue)) - thread_2 = Thread(target=enqueue_output, args=(process.stderr, queue)) - - thread_1.daemon = True # thread dies with the program - thread_1.start() - thread_2.daemon = True # thread dies with the program - thread_2.start() - return queue - - -def process_cmd( - cmds: list[str], - node_type: str, - dry_run: bool, - silent: bool, - compose_src_path: str, - progress_bar: Progress | None = None, - cmd_name: str = "", -) -> None: - process_list: list = [] - cwd = compose_src_path - - username, password = ( - extract_username_and_pass(cmds[0]) if len(cmds) > 0 else ("-", "-") - ) - # display VM credentials - console = rich.get_console() - credentials = generate_user_table(username=username, password=password) - if credentials: - console.print(credentials) - - for cmd in cmds: - if dry_run: - print(f"\nRunning:\ncd {cwd}\n", hide_password(cmd=cmd)) - continue - - # use powershell if environment is Windows - cmd_to_exec = ["powershell.exe", "-Command", cmd] if is_windows() else cmd - - try: - if len(cmds) > 1: - process = subprocess.Popen( # nosec - cmd_to_exec, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=cwd, - shell=True, - ) - ip_address = extract_host_ip_from_cmd(cmd) - jupyter_token = extract_jupyter_token(cmd) - process_list.append((ip_address, process, jupyter_token)) - else: - display_jupyter_token(cmd) - if silent: - ON_POSIX = "posix" in sys.builtin_module_names - - process = subprocess.Popen( # nosec - cmd_to_exec, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=cwd, - close_fds=ON_POSIX, - shell=True, - ) - - # Creates two threads to get docker stdout and sterr - logs_queue = create_thread_logs(process=process) - - read_thread_logs(progress_bar, process, logs_queue, cmd_name) - while process.poll() != 0: - while not logs_queue.empty(): - # Read stdout and sterr to check errors or update progress bar. 
- read_thread_logs( - progress_bar, process, logs_queue, cmd_name - ) - else: - if progress_bar: - progress_bar.stop() - - subprocess.run( # nosec - cmd_to_exec, - shell=True, - cwd=cwd, - ) - except Exception as e: - print(f"Failed to run cmd: {cmd}. {e}") - - if dry_run is False and len(process_list) > 0: - # display VM launch status - display_vm_status(process_list) - - # save vm details as json - save_vm_details_as_json(username, password, process_list) - - -def execute_commands( - cmds: list[str] | dict[str, list[str]], - node_type: str, - compose_src_path: str, - dry_run: bool = False, - silent: bool = False, -) -> None: - """Execute the launch commands and display their status in realtime. - - Args: - cmds (list): list of commands to be executed - dry_run (bool, optional): If `True` only displays cmds to be executed. Defaults to False. - """ - console = rich.get_console() - if isinstance(cmds, dict): - console.print("[bold green]⠋[bold blue] Launching Containers [/bold blue]\t") - for cmd_name, cmd in cmds.items(): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:.2f}% "), - console=console, - auto_refresh=True, - ) as progress: - if silent: - progress.add_task( - f"[bold green]{cmd_name} Images", - total=0, - ) - process_cmd( - cmds=cmd, - node_type=node_type, - dry_run=dry_run, - silent=silent, - compose_src_path=compose_src_path, - progress_bar=progress, - cmd_name=cmd_name, - ) - else: - process_cmd( - cmds=cmds, - node_type=node_type, - dry_run=dry_run, - silent=silent, - compose_src_path=compose_src_path, - ) - - -def display_vm_status(process_list: list) -> None: - """Display the status of the processes being executed on the VM. - - Args: - process_list (list): list of processes executed. - """ - - # Generate the table showing the status of each process being executed - status_table, process_completed = generate_process_status_table(process_list) - - # Render the live table - with Live(status_table, refresh_per_second=1) as live: - # Loop till all processes have not completed executing - while not process_completed: - status_table, process_completed = generate_process_status_table( - process_list - ) - live.update(status_table) # Update the process status table - - -def display_jupyter_token(cmd: str) -> None: - token = extract_jupyter_token(cmd=cmd) - if token is not None: - print(f"Jupyter Token: {token}") - - -def extract_username_and_pass(cmd: str) -> tuple: - # Extract username - matcher = r"--user (.+?) 
" - username = re.findall(matcher, cmd) - username = username[0] if len(username) > 0 else None - - # Extract password - matcher = r"ansible_ssh_pass='(.+?)'" - password = re.findall(matcher, cmd) - password = password[0] if len(password) > 0 else None - - return username, password - - -def extract_jupyter_token(cmd: str) -> str | None: - matcher = r"jupyter_token='(.+?)'" - token = re.findall(matcher, cmd) - if len(token) == 1: - return token[0] - return None - - -def hide_password(cmd: str) -> str: - try: - matcher = r"ansible_ssh_pass='(.+?)'" - passwords = re.findall(matcher, cmd) - if len(passwords) > 0: - password = passwords[0] - stars = "*" * 4 - cmd = cmd.replace( - f"ansible_ssh_pass='{password}'", f"ansible_ssh_pass='{stars}'" - ) - return cmd - except Exception as e: - print("Failed to hide password.") - raise e - - -def hide_azure_vm_password(azure_cmd: str) -> str: - try: - matcher = r"admin-password '(.+?)'" - passwords = re.findall(matcher, azure_cmd) - if len(passwords) > 0: - password = passwords[0] - stars = "*" * 4 - azure_cmd = azure_cmd.replace( - f"admin-password '{password}'", f"admin-password '{stars}'" - ) - return azure_cmd - except Exception as e: - print("Failed to hide password.") - raise e - - -class QuestionInputError(Exception): - pass - - -class QuestionInputPathError(Exception): - pass - - -class Question: - def __init__( - self, - var_name: str, - question: str, - kind: str, - default: str | None = None, - cache: bool = False, - options: list[str] | None = None, - ) -> None: - self.var_name = var_name - self.question = question - self.default = default - self.kind = kind - self.cache = cache - self.options = options if options is not None else [] - - def validate(self, value: str) -> str: - value = value.strip() - if self.default is not None and value == "": - return self.default - - if self.kind == "path": - value = os.path.expanduser(value) - if not os.path.exists(value): - error = f"{value} is not a valid path." - if self.default is not None: - error += f" Try {self.default}" - raise QuestionInputPathError(f"{error}") - - if self.kind == "yesno": - if value.lower().startswith("y"): - return "y" - elif value.lower().startswith("n"): - return "n" - else: - raise QuestionInputError(f"{value} is not an yes or no answer") - - if self.kind == "options": - if value in self.options: - return value - first_letter = value.lower()[0] - for option in self.options: - if option.startswith(first_letter): - return option - - raise QuestionInputError( - f"{value} is not one of the options: {self.options}" - ) - - if self.kind == "password": - try: - return validate_password(password=value) - except Exception as e: - raise QuestionInputError(f"Invalid password. 
{e}") - return value - - -def ask(question: Question, kwargs: dict[str, str]) -> str: - if question.var_name in kwargs and kwargs[question.var_name] is not None: - value = kwargs[question.var_name] - else: - if question.default is not None: - value = click.prompt(question.question, type=str, default=question.default) - elif question.var_name == "password": - value = click.prompt( - question.question, type=str, hide_input=True, confirmation_prompt=True - ) - else: - value = click.prompt(question.question, type=str) - - try: - value = question.validate(value=value) - except QuestionInputError as e: - print(e) - return ask(question=question, kwargs=kwargs) - if question.cache: - arg_cache[question.var_name] = value - - return value - - -def fix_key_permission(private_key_path: str) -> None: - key_permission = oct(stat.S_IMODE(os.stat(private_key_path).st_mode)) - chmod_permission = "400" - octal_permission = f"0o{chmod_permission}" - if key_permission != octal_permission: - print( - f"Fixing key permission: {private_key_path}, setting to {chmod_permission}" - ) - try: - os.chmod(private_key_path, int(octal_permission, 8)) - except Exception as e: - print("Failed to fix key permission", e) - raise e - - -def private_to_public_key(private_key_path: str, temp_path: str, username: str) -> str: - # check key permission - fix_key_permission(private_key_path=private_key_path) - output_path = f"{temp_path}/hagrid_{username}_key.pub" - cmd = f"ssh-keygen -f {private_key_path} -y > {output_path}" - try: - subprocess.check_call(cmd, shell=True) # nosec - except Exception as e: - print("failed to make ssh key", e) - raise e - return output_path - - -def check_azure_authed() -> bool: - cmd = "az account show" - try: - subprocess.check_call(cmd, shell=True, stdout=subprocess.DEVNULL) # nosec - return True - except Exception: # nosec - pass - return False - - -def login_azure() -> bool: - cmd = "az login" - try: - subprocess.check_call(cmd, shell=True, stdout=subprocess.DEVNULL) # nosec - return True - except Exception: # nosec - pass - return False - - -def check_azure_cli_installed() -> bool: - try: - result = subprocess.run( # nosec - ["az", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT - ) - if result.returncode != 0: - raise FileNotFoundError("az not installed") - except Exception: # nosec - msg = "\nYou don't appear to have the Azure CLI installed!!! \n\n\ -Please install it and then retry your command.\ -\n\nInstallation Instructions: https://docs.microsoft.com/en-us/cli/azure/install-azure-cli\n" - raise FileNotFoundError(msg) - - return True - - -def check_gcloud_cli_installed() -> bool: - try: - subprocess.call(["gcloud", "version"]) # nosec - print("Gcloud cli installed!") - except FileNotFoundError: - msg = "\nYou don't appear to have the gcloud CLI tool installed! \n\n\ -Please install it and then retry again.\ -\n\nInstallation Instructions: https://cloud.google.com/sdk/docs/install-sdk \n" - raise FileNotFoundError(msg) - - return True - - -def check_aws_cli_installed() -> bool: - try: - result = subprocess.run( # nosec - ["aws", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT - ) - if result.returncode != 0: - raise FileNotFoundError("AWS CLI not installed") - except Exception: # nosec - msg = "\nYou don't appear to have the AWS CLI installed! 
\n\n\ -Please install it and then retry your command.\ -\n\nInstallation Instructions: https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html\n" - raise FileNotFoundError(msg) - - return True - - -def check_gcloud_authed() -> bool: - try: - result = subprocess.run( # nosec - ["gcloud", "auth", "print-identity-token"], stdout=subprocess.PIPE - ) - if result.returncode == 0: - return True - except Exception: # nosec - pass - return False - - -def login_gcloud() -> bool: - cmd = "gcloud auth login" - try: - subprocess.check_call(cmd, shell=True, stdout=subprocess.DEVNULL) # nosec - return True - except Exception: # nosec - pass - return False - - -def str_to_bool(bool_str: str | None) -> bool: - result = False - bool_str = str(bool_str).lower() - if bool_str == "true" or bool_str == "1": - result = True - return result - - -ART = str_to_bool(os.environ.get("HAGRID_ART", "True")) - - -def generate_gcloud_key_at_path(key_path: str) -> str: - key_path = os.path.expanduser(key_path) - if os.path.exists(key_path): - raise Exception(f"Can't generate key since path already exists. {key_path}") - else: - # triggers a key check - cmd = "gcloud compute ssh '' --dry-run" - try: - subprocess.check_call(cmd, shell=True) # nosec - except Exception: # nosec - pass - if not os.path.exists(key_path): - raise Exception(f"gcloud failed to generate ssh-key at: {key_path}") - - return key_path - - -def generate_aws_key_at_path(key_path: str, key_name: str) -> str: - key_path = os.path.expanduser(key_path) - if os.path.exists(key_path): - raise Exception(f"Can't generate key since path already exists. {key_path}") - else: - # TODO we need to do differently for powershell. - # Ex: aws ec2 create-key-pair --key-name MyKeyPair --query 'KeyMaterial' - # --output text | out-file -encoding ascii -filepath MyKeyPair.pem - - print(f"Creating AWS key pair with name {key_name} at path {key_path}..") - cmd = f"aws ec2 create-key-pair --key-name {key_name} --query 'KeyMaterial' --output text > {key_path}" - try: - subprocess.check_call(cmd, shell=True) # nosec - subprocess.check_call(f"chmod 400 {key_path}", shell=True) # nosec - except Exception as e: # nosec - print(f"Failed to create key: {e}") - if not os.path.exists(key_path): - raise Exception(f"AWS failed to generate key pair at: {key_path}") - - return key_path - - -def generate_key_at_path(key_path: str) -> str: - key_path = os.path.expanduser(key_path) - if os.path.exists(key_path): - raise Exception(f"Can't generate key since path already exists. {key_path}") - else: - cmd = f"ssh-keygen -N '' -f {key_path}" - try: - subprocess.check_call(cmd, shell=True) # nosec - if not os.path.exists(key_path): - raise Exception(f"Failed to generate ssh-key at: {key_path}") - except Exception as e: - raise e - - return key_path - - -def validate_password(password: str) -> str: - """Validate if the password entered by the user is valid. 
- - Password length should be between 12 - 123 characters - Passwords must also meet 3 out of the following 4 complexity requirements: - - Have lower characters - - Have upper characters - - Have a digit - - Have a special character - - Args: - password (str): password for the vm - - Returns: - str: password if it is valid - """ - # Validate password length - if len(password) < 12 or len(password) > 123: - raise ValueError("Password length should be between 12 - 123 characters") - - # Valid character types - character_types = { - "upper_case": False, - "lower_case": False, - "digit": False, - "special": False, - } - - for ch in password: - if ch.islower(): - character_types["lower_case"] = True - elif ch.isupper(): - character_types["upper_case"] = True - elif ch.isdigit(): - character_types["digit"] = True - elif ch.isascii(): - character_types["special"] = True - else: - raise ValueError(f"{ch} is not a valid character for password") - - # Validate characters in the password - required_character_type_count = sum( - [int(value) for value in character_types.values()] - ) - - if required_character_type_count >= 3: - return password - - absent_character_types = ", ".join( - char_type for char_type, value in character_types.items() if value is False - ).strip(", ") - - raise ValueError( - f"At least one {absent_character_types} character types must be present" - ) - - -def create_launch_cmd( - verb: GrammarVerb, - kwargs: dict[str, Any], - ignore_docker_version_check: bool | None = False, -) -> str | list[str] | dict[str, list[str]]: - parsed_kwargs: dict[str, Any] = {} - host_term = verb.get_named_term_hostgrammar(name="host") - - host = host_term.host - auth: AuthCredentials | None = None - - tail = bool(kwargs["tail"]) - - parsed_kwargs = {} - - parsed_kwargs["build"] = bool(kwargs["build"]) - - parsed_kwargs["use_blob_storage"] = not bool(kwargs["no_blob_storage"]) - - parsed_kwargs["in_mem_workers"] = bool(kwargs["in_mem_workers"]) - - if parsed_kwargs["use_blob_storage"]: - parsed_kwargs["set_s3_username"] = kwargs["set_s3_username"] - parsed_kwargs["set_s3_password"] = kwargs["set_s3_password"] - parsed_kwargs["set_volume_size_limit_mb"] = kwargs["set_volume_size_limit_mb"] - - parsed_kwargs["association_request_auto_approval"] = str( - kwargs["association_request_auto_approval"] - ) - - parsed_kwargs["node_count"] = ( - int(kwargs["node_count"]) if "node_count" in kwargs else 1 - ) - - if parsed_kwargs["node_count"] > 1 and host not in ["azure"]: - print("\nArgument `node_count` is only supported with `azure`.\n") - else: - # Default to detached mode if running more than one nodes - tail = False if parsed_kwargs["node_count"] > 1 else tail - - headless = bool(kwargs["headless"]) - parsed_kwargs["headless"] = headless - - parsed_kwargs["tls"] = bool(kwargs["tls"]) - parsed_kwargs["enable_rathole"] = bool(kwargs["rathole"]) - parsed_kwargs["test"] = bool(kwargs["test"]) - parsed_kwargs["dev"] = bool(kwargs["dev"]) - - parsed_kwargs["silent"] = not bool(kwargs["verbose"]) - - parsed_kwargs["trace"] = False - if ("trace" not in kwargs or kwargs["trace"] is None) and parsed_kwargs["dev"]: - # default to trace on in dev mode - parsed_kwargs["trace"] = False - elif "trace" in kwargs: - parsed_kwargs["trace"] = str_to_bool(cast(str, kwargs["trace"])) - - parsed_kwargs["release"] = "production" - if "release" in kwargs and kwargs["release"] != "production": - parsed_kwargs["release"] = kwargs["release"] - - # if we use --dev override it - if parsed_kwargs["dev"] is True: - 
parsed_kwargs["release"] = "development" - - # derive node type - if kwargs["low_side"]: - parsed_kwargs["node_side_type"] = NodeSideType.LOW_SIDE.value - else: - parsed_kwargs["node_side_type"] = NodeSideType.HIGH_SIDE.value - - parsed_kwargs["smtp_username"] = kwargs["smtp_username"] - parsed_kwargs["smtp_password"] = kwargs["smtp_password"] - parsed_kwargs["smtp_port"] = kwargs["smtp_port"] - parsed_kwargs["smtp_host"] = kwargs["smtp_host"] - parsed_kwargs["smtp_sender"] = kwargs["smtp_sender"] - - parsed_kwargs["enable_warnings"] = not kwargs["no_warnings"] - - # choosing deployment type - parsed_kwargs["deployment_type"] = "container_stack" - if "deployment_type" in kwargs and kwargs["deployment_type"] is not None: - parsed_kwargs["deployment_type"] = kwargs["deployment_type"] - - if "cert_store_path" in kwargs: - parsed_kwargs["cert_store_path"] = kwargs["cert_store_path"] - if "upload_tls_cert" in kwargs: - parsed_kwargs["upload_tls_cert"] = kwargs["upload_tls_cert"] - if "upload_tls_key" in kwargs: - parsed_kwargs["upload_tls_key"] = kwargs["upload_tls_key"] - - parsed_kwargs["provision"] = not bool(kwargs["no_provision"]) - - if "image_name" in kwargs and kwargs["image_name"] is not None: - parsed_kwargs["image_name"] = kwargs["image_name"] - else: - parsed_kwargs["image_name"] = "default" - - if parsed_kwargs["dev"] is True: - parsed_kwargs["tag"] = "local" - else: - if "tag" in kwargs and kwargs["tag"] is not None and kwargs["tag"] != "": - parsed_kwargs["tag"] = kwargs["tag"] - else: - parsed_kwargs["tag"] = "latest" - - if "jupyter" in kwargs and kwargs["jupyter"] is not None: - parsed_kwargs["jupyter"] = str_to_bool(cast(str, kwargs["jupyter"])) - else: - parsed_kwargs["jupyter"] = False - - # allows changing docker platform to other cpu architectures like arm64 - parsed_kwargs["platform"] = kwargs["platform"] if "platform" in kwargs else None - - parsed_kwargs["tail"] = tail - - parsed_kwargs["set_root_password"] = ( - kwargs["set_root_password"] if "set_root_password" in kwargs else None - ) - - parsed_kwargs["set_root_email"] = ( - kwargs["set_root_email"] if "set_root_email" in kwargs else None - ) - - parsed_kwargs["template"] = kwargs["template"] if "template" in kwargs else None - parsed_kwargs["template_overwrite"] = bool(kwargs["template_overwrite"]) - - parsed_kwargs["compose_src_path"] = kwargs["compose_src_path"] - - parsed_kwargs["enable_signup"] = str_to_bool(cast(str, kwargs["enable_signup"])) - - # Override template tag with user input tag - if ( - parsed_kwargs["tag"] is not None - and parsed_kwargs["template"] is None - and parsed_kwargs["tag"] not in ["local"] - ): - # third party - from packaging import version - - pattern = r"[0-9].[0-9].[0-9]" - input_tag = parsed_kwargs["tag"] - if ( - not re.match(pattern, input_tag) - and input_tag != "latest" - and input_tag != "beta" - and "b" not in input_tag - ): - raise Exception( - f"Not a valid tag: {parsed_kwargs['tag']}" - + "\nValid tags: latest, beta, beta version(ex: 0.8.2b35),[0-9].[0-9].[0-9]" - ) - - # TODO: we need to redo this so that pypi and docker mappings are in a single - # file inside dev - if parsed_kwargs["tag"] == "latest": - parsed_kwargs["template"] = LATEST_STABLE_SYFT - parsed_kwargs["tag"] = LATEST_STABLE_SYFT - elif parsed_kwargs["tag"] == "beta" or "b" in parsed_kwargs["tag"]: - tag = ( - LATEST_BETA_SYFT - if parsed_kwargs["tag"] == "beta" - else parsed_kwargs["tag"] - ) - - # Currently, manifest_template.yml is only supported for beta versions >= 0.8.2b34 - beta_version = 
version.parse(tag) - MINIMUM_BETA_VERSION = "0.8.2b34" - if beta_version < version.parse(MINIMUM_BETA_VERSION): - raise Exception( - f"Minimum beta version tag supported is {MINIMUM_BETA_VERSION}" - ) - - # Check if the beta version is available - template_url = f"https://github.com/OpenMined/PySyft/releases/download/v{str(beta_version)}/manifest_template.yml" - response = requests.get(template_url) # nosec - if response.status_code != 200: - raise Exception( - f"Tag {parsed_kwargs['tag']} is not available" - + " \n for download. Please check the available tags at: " - + "\n https://github.com/OpenMined/PySyft/releases" - ) - - parsed_kwargs["template"] = template_url - parsed_kwargs["tag"] = tag - else: - MINIMUM_TAG_VERSION = version.parse("0.8.0") - tag = version.parse(parsed_kwargs["tag"]) - if tag < MINIMUM_TAG_VERSION: - raise Exception( - f"Minimum supported stable tag version is {MINIMUM_TAG_VERSION}" - ) - parsed_kwargs["template"] = parsed_kwargs["tag"] - - if host in ["docker"] and parsed_kwargs["template"] and host is not None: - # Setup the files from the manifest_template.yml - kwargs = setup_from_manifest_template( - host_type=host, - deployment_type=parsed_kwargs["deployment_type"], - template_location=parsed_kwargs["template"], - overwrite=parsed_kwargs["template_overwrite"], - verbose=kwargs["verbose"], - ) - - parsed_kwargs.update(kwargs) - - if host in ["docker"]: - # Check docker service status - if not ignore_docker_version_check: - check_docker_service_status() - - # Check grid docker versions - if not ignore_docker_version_check: - check_grid_docker(display=True, output_in_text=True) - - if not ignore_docker_version_check: - version = check_docker_version() - else: - version = "n/a" - - if version: - # If the user is using docker desktop (OSX/Windows), check to make sure there's enough RAM. - # If the user is using Linux this isn't an issue because Docker scales to the avaialble RAM, - # but on Docker Desktop it defaults to 2GB which isn't enough. - dd_memory = docker_desktop_memory() - if dd_memory < 8192 and dd_memory != -1: - raise Exception( - "You appear to be using Docker Desktop but don't have " - "enough memory allocated. It appears you've configured " - f"Memory:{dd_memory} MB when 8192MB (8GB) is required. " - f"Please open Docker Desktop Preferences panel and set Memory" - f" to 8GB or higher. \n\n" - f"\tOSX Help: https://docs.docker.com/desktop/mac/\n" - f"\tWindows Help: https://docs.docker.com/desktop/windows/\n\n" - f"Then re-run your hagrid command.\n\n" - f"If you see this warning on Linux then something isn't right. " - f"Please file a Github Issue on PySyft's Github.\n\n" - f"Alternatively in case no more memory could be allocated, " - f"you can run hagrid on the cloud with GitPod by visiting " - f"https://gitpod.io/#https://github.com/OpenMined/PySyft." 
- ) - - if is_windows() and not DEPENDENCIES["wsl"]: - raise Exception( - "You must install wsl2 for Windows to use HAGrid.\n" - "In PowerShell or Command Prompt type:\n> wsl --install\n\n" - "Read more here: https://docs.microsoft.com/en-us/windows/wsl/install" - ) - - return create_launch_docker_cmd( - verb=verb, - docker_version=version, - tail=tail, - kwargs=parsed_kwargs, - silent=parsed_kwargs["silent"], - ) - - elif host in ["azure"]: - check_azure_cli_installed() - - while not check_azure_authed(): - print("You need to log into Azure") - login_azure() - - if DEPENDENCIES["ansible-playbook"]: - resource_group = ask( - question=Question( - var_name="azure_resource_group", - question="What resource group name do you want to use (or create)?", - default=arg_cache["azure_resource_group"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - location = ask( - question=Question( - var_name="azure_location", - question="If this is a new resource group what location?", - default=arg_cache["azure_location"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - size = ask( - question=Question( - var_name="azure_size", - question="What size machine?", - default=arg_cache["azure_size"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - username = ask( - question=Question( - var_name="azure_username", - question="What do you want the username for the VM to be?", - default=arg_cache["azure_username"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - parsed_kwargs["auth_type"] = ask( - question=Question( - var_name="auth_type", - question="Do you want to login with a key or password", - default=arg_cache["auth_type"], - kind="option", - options=["key", "password"], - cache=True, - ), - kwargs=kwargs, - ) - - key_path = None - if parsed_kwargs["auth_type"] == "key": - key_path_question = Question( - var_name="azure_key_path", - question=f"Absolute path of the private key to access {username}@{host}?", - default=arg_cache["azure_key_path"], - kind="path", - cache=True, - ) - try: - key_path = ask( - key_path_question, - kwargs=kwargs, - ) - except QuestionInputPathError as e: - print(e) - key_path = str(e).split("is not a valid path")[0].strip() - - create_key_question = Question( - var_name="azure_key_path", - question=f"Key {key_path} does not exist. Do you want to create it? (y/n)", - default="y", - kind="yesno", - ) - create_key = ask( - create_key_question, - kwargs=kwargs, - ) - if create_key == "y": - key_path = generate_key_at_path(key_path=key_path) - else: - raise QuestionInputError( - "Unable to create VM without a private key" - ) - elif parsed_kwargs["auth_type"] == "password": - auto_generate_password = ask( - question=Question( - var_name="auto_generate_password", - question="Do you want to auto-generate the password? 
(y/n)", - kind="yesno", - ), - kwargs=kwargs, - ) - if auto_generate_password == "y": # nosec - parsed_kwargs["password"] = generate_sec_random_password(length=16) - elif auto_generate_password == "n": # nosec - parsed_kwargs["password"] = ask( - question=Question( - var_name="password", - question=f"Password for {username}@{host}?", - kind="password", - ), - kwargs=kwargs, - ) - - repo = ask( - Question( - var_name="azure_repo", - question="Repo to fetch source from?", - default=arg_cache["azure_repo"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - branch = ask( - Question( - var_name="azure_branch", - question="Branch to monitor for updates?", - default=arg_cache["azure_branch"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - use_branch(branch=branch) - - password = parsed_kwargs.get("password") - - auth = AuthCredentials( - username=username, key_path=key_path, password=password - ) - - if not auth.valid: - raise Exception(f"Login Credentials are not valid. {auth}") - - return create_launch_azure_cmd( - verb=verb, - resource_group=resource_group, - location=location, - size=size, - username=username, - password=password, - key_path=key_path, - repo=repo, - branch=branch, - auth=auth, - ansible_extras=kwargs["ansible_extras"], - kwargs=parsed_kwargs, - ) - else: - errors = [] - if not DEPENDENCIES["ansible-playbook"]: - errors.append("ansible-playbook") - msg = "\nERROR!!! MISSING DEPENDENCY!!!" - msg += f"\n\nLaunching a Cloud VM requires: {' '.join(errors)}" - msg += "\n\nPlease follow installation instructions: " - msg += "https://docs.ansible.com/ansible/latest/installation_guide/intro_installation.html#" - msg += "\n\nNote: we've found the 'conda' based installation instructions to work best" - msg += " (e.g. something lke 'conda install -c conda-forge ansible'). " - msg += "The pip based instructions seem to be a bit buggy if you're using a conda environment" - msg += "\n" - raise MissingDependency(msg) - - elif host in ["gcp"]: - check_gcloud_cli_installed() - - while not check_gcloud_authed(): - print("You need to log into Google Cloud") - login_gcloud() - - if DEPENDENCIES["ansible-playbook"]: - project_id = ask( - question=Question( - var_name="gcp_project_id", - question="What PROJECT ID do you want to use?", - default=arg_cache["gcp_project_id"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - zone = ask( - question=Question( - var_name="gcp_zone", - question="What zone do you want your VM in?", - default=arg_cache["gcp_zone"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - machine_type = ask( - question=Question( - var_name="gcp_machine_type", - question="What size machine?", - default=arg_cache["gcp_machine_type"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - username = ask( - question=Question( - var_name="gcp_username", - question="What is your shell username?", - default=arg_cache["gcp_username"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - key_path_question = Question( - var_name="gcp_key_path", - question=f"Private key to access user@{host}?", - default=arg_cache["gcp_key_path"], - kind="path", - cache=True, - ) - try: - key_path = ask( - key_path_question, - kwargs=kwargs, - ) - except QuestionInputPathError as e: - print(e) - key_path = str(e).split("is not a valid path")[0].strip() - - create_key_question = Question( - var_name="gcp_key_path", - question=f"Key {key_path} does not exist. Do you want gcloud to make it? 
(y/n)", - default="y", - kind="yesno", - ) - create_key = ask( - create_key_question, - kwargs=kwargs, - ) - if create_key == "y": - key_path = generate_gcloud_key_at_path(key_path=key_path) - else: - raise QuestionInputError( - "Unable to create VM without a private key" - ) - - repo = ask( - Question( - var_name="gcp_repo", - question="Repo to fetch source from?", - default=arg_cache["gcp_repo"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - branch = ask( - Question( - var_name="gcp_branch", - question="Branch to monitor for updates?", - default=arg_cache["gcp_branch"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - use_branch(branch=branch) - - auth = AuthCredentials(username=username, key_path=key_path) - - return create_launch_gcp_cmd( - verb=verb, - project_id=project_id, - zone=zone, - machine_type=machine_type, - repo=repo, - auth=auth, - branch=branch, - ansible_extras=kwargs["ansible_extras"], - kwargs=parsed_kwargs, - ) - else: - errors = [] - if not DEPENDENCIES["ansible-playbook"]: - errors.append("ansible-playbook") - msg = "\nERROR!!! MISSING DEPENDENCY!!!" - msg += f"\n\nLaunching a Cloud VM requires: {' '.join(errors)}" - msg += "\n\nPlease follow installation instructions: " - msg += "https://docs.ansible.com/ansible/latest/installation_guide/intro_installation.html#" - msg += "\n\nNote: we've found the 'conda' based installation instructions to work best" - msg += " (e.g. something lke 'conda install -c conda-forge ansible'). " - msg += "The pip based instructions seem to be a bit buggy if you're using a conda environment" - msg += "\n" - raise MissingDependency(msg) - - elif host in ["aws"]: - check_aws_cli_installed() - - if DEPENDENCIES["ansible-playbook"]: - aws_region = ask( - question=Question( - var_name="aws_region", - question="In what region do you want to deploy the EC2 instance?", - default=arg_cache["aws_region"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - aws_security_group_name = ask( - question=Question( - var_name="aws_security_group_name", - question="Name of the security group to be created?", - default=arg_cache["aws_security_group_name"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - aws_security_group_cidr = ask( - question=Question( - var_name="aws_security_group_cidr", - question="What IP addresses to allow for incoming network traffic? 
Please use CIDR notation", - default=arg_cache["aws_security_group_cidr"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - ec2_instance_type = ask( - question=Question( - var_name="aws_ec2_instance_type", - question="What EC2 instance type do you want to deploy?", - default=arg_cache["aws_ec2_instance_type"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - aws_key_name = ask( - question=Question( - var_name="aws_key_name", - question="Enter the name of the key pair to use to connect to the EC2 instance", - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - key_path_qn_str = ( - "Please provide the path of the private key to connect to the instance" - ) - key_path_qn_str += " (if it does not exist, this path corresponds to " - key_path_qn_str += "where you want to store the key upon creation)" - key_path_question = Question( - var_name="aws_key_path", - question=key_path_qn_str, - kind="path", - cache=True, - ) - try: - key_path = ask( - key_path_question, - kwargs=kwargs, - ) - except QuestionInputPathError as e: - print(e) - key_path = str(e).split("is not a valid path")[0].strip() - - create_key_question = Question( - var_name="aws_key_path", - question=f"Key {key_path} does not exist. Do you want AWS to make it? (y/n)", - default="y", - kind="yesno", - ) - create_key = ask( - create_key_question, - kwargs=kwargs, - ) - if create_key == "y": - key_path = generate_aws_key_at_path( - key_path=key_path, key_name=aws_key_name - ) - else: - raise QuestionInputError( - "Unable to create EC2 instance without key" - ) - - repo = ask( - Question( - var_name="aws_repo", - question="Repo to fetch source from?", - default=arg_cache["aws_repo"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - branch = ask( - Question( - var_name="aws_branch", - question="Branch to monitor for updates?", - default=arg_cache["aws_branch"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - use_branch(branch=branch) - - username = arg_cache["aws_ec2_instance_username"] - auth = AuthCredentials(username=username, key_path=key_path) - - return create_launch_aws_cmd( - verb=verb, - region=aws_region, - ec2_instance_type=ec2_instance_type, - security_group_name=aws_security_group_name, - aws_security_group_cidr=aws_security_group_cidr, - key_path=key_path, - key_name=aws_key_name, - repo=repo, - branch=branch, - ansible_extras=kwargs["ansible_extras"], - kwargs=parsed_kwargs, - ami_id=arg_cache["aws_image_id"], - username=username, - auth=auth, - ) - - else: - errors = [] - if not DEPENDENCIES["ansible-playbook"]: - errors.append("ansible-playbook") - msg = "\nERROR!!! MISSING DEPENDENCY!!!" - msg += f"\n\nLaunching a Cloud VM requires: {' '.join(errors)}" - msg += "\n\nPlease follow installation instructions: " - msg += "https://docs.ansible.com/ansible/latest/installation_guide/intro_installation.html#" - msg += "\n\nNote: we've found the 'conda' based installation instructions to work best" - msg += " (e.g. something lke 'conda install -c conda-forge ansible'). 
" - msg += "The pip based instructions seem to be a bit buggy if you're using a conda environment" - msg += "\n" - raise MissingDependency(msg) - else: - if DEPENDENCIES["ansible-playbook"]: - if host != "localhost": - parsed_kwargs["username"] = ask( - question=Question( - var_name="username", - question=f"Username for {host} with sudo privledges?", - default=arg_cache["username"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - parsed_kwargs["auth_type"] = ask( - question=Question( - var_name="auth_type", - question="Do you want to login with a key or password", - default=arg_cache["auth_type"], - kind="option", - options=["key", "password"], - cache=True, - ), - kwargs=kwargs, - ) - if parsed_kwargs["auth_type"] == "key": - parsed_kwargs["key_path"] = ask( - question=Question( - var_name="key_path", - question=f"Private key to access {parsed_kwargs['username']}@{host}?", - default=arg_cache["key_path"], - kind="path", - cache=True, - ), - kwargs=kwargs, - ) - elif parsed_kwargs["auth_type"] == "password": - parsed_kwargs["password"] = ask( - question=Question( - var_name="password", - question=f"Password for {parsed_kwargs['username']}@{host}?", - kind="password", - ), - kwargs=kwargs, - ) - - parsed_kwargs["repo"] = ask( - question=Question( - var_name="repo", - question="Repo to fetch source from?", - default=arg_cache["repo"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - parsed_kwargs["branch"] = ask( - Question( - var_name="branch", - question="Branch to monitor for updates?", - default=arg_cache["branch"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - auth = None - if host != "localhost": - if parsed_kwargs["auth_type"] == "key": - auth = AuthCredentials( - username=parsed_kwargs["username"], - key_path=parsed_kwargs["key_path"], - ) - else: - auth = AuthCredentials( - username=parsed_kwargs["username"], - key_path=parsed_kwargs["password"], - ) - if not auth.valid: - raise Exception(f"Login Credentials are not valid. 
{auth}") - parsed_kwargs["ansible_extras"] = kwargs["ansible_extras"] - return create_launch_custom_cmd(verb=verb, auth=auth, kwargs=parsed_kwargs) - else: - errors = [] - if not DEPENDENCIES["ansible-playbook"]: - errors.append("ansible-playbook") - raise MissingDependency( - f"Launching a Custom VM requires: {' '.join(errors)}" - ) - - host_options = ", ".join(allowed_hosts) - raise MissingDependency( - f"Launch requires a correct host option, try: {host_options}" - ) - - -def pull_command(cmd: str, kwargs: dict[str, Any]) -> list[str]: - pull_cmd = str(cmd) - if kwargs["release"] == "production": - pull_cmd += " --file docker-compose.yml" - else: - pull_cmd += " --file docker-compose.pull.yml" - pull_cmd += " pull --ignore-pull-failures" # ignore missing version from Dockerhub - return [pull_cmd] - - -def build_command(cmd: str) -> list[str]: - build_cmd = str(cmd) - build_cmd += " --file docker-compose.build.yml" - build_cmd += " build" - return [build_cmd] - - -def deploy_command(cmd: str, tail: bool, dev_mode: bool) -> list[str]: - up_cmd = str(cmd) - up_cmd += " --file docker-compose.dev.yml" if dev_mode else "" - up_cmd += " up" - if not tail: - up_cmd += " -d" - return [up_cmd] - - -def create_launch_docker_cmd( - verb: GrammarVerb, - docker_version: str, - kwargs: dict[str, Any], - tail: bool = True, - silent: bool = False, -) -> dict[str, list[str]]: - host_term = verb.get_named_term_hostgrammar(name="host") - node_name = verb.get_named_term_type(name="node_name") - node_type = verb.get_named_term_type(name="node_type") - - snake_name = str(node_name.snake_input) - tag = name_tag(name=str(node_name.input)) - - if ART and not silent: - hagrid() - - print( - "Launching a PyGrid " - + str(node_type.input).capitalize() - + " node on port " - + str(host_term.free_port) - + "!\n" - ) - - version_string = kwargs["tag"] - version_hash = "dockerhub" - build = kwargs["build"] - - # if in development mode, generate a version_string which is either - # the one you inputed concatenated with -dev or the contents of the VERSION file - version = GRID_SRC_VERSION() - if "release" in kwargs and kwargs["release"] == "development": - # force version to have -dev at the end in dev mode - # during development we can use the latest beta version - if version_string is None: - version_string = version[0] - version_string += "-dev" - version_hash = version[1] - build = True - else: - # whereas if in production mode and tag == "local" use the local VERSION file - # or if its not set somehow, which should never happen, use stable - # otherwise use the kwargs["tag"] from above - - # during production the default would be stable - if version_string == "local": - # this can be used in VMs in production to auto update from src - version_string = version[0] - version_hash = version[1] - build = True - elif version_string is None: - version_string = "latest" - - if platform.uname().machine.lower() in ["x86_64", "amd64"]: - docker_platform = "linux/amd64" - else: - docker_platform = "linux/arm64" - - if "platform" in kwargs and kwargs["platform"] is not None: - docker_platform = kwargs["platform"] - - if kwargs["template"]: - _, template_hash = get_template_yml(kwargs["template"]) - template_dir = manifest_cache_path(template_hash) - template_grid_dir = f"{template_dir}/packages/grid" - else: - template_grid_dir = GRID_SRC_PATH() - - compose_src_path = kwargs["compose_src_path"] - if not compose_src_path: - compose_src_path = get_compose_src_path( - node_type=node_type, - node_name=snake_name, - 
template_location=kwargs["template"], - **kwargs, - ) - - default_env = f"{template_grid_dir}/default.env" - if not os.path.exists(default_env): - # old path - default_env = f"{template_grid_dir}/.env" - default_envs = {} - with open(default_env) as f: - for line in f.readlines(): - if "=" in line: - parts = line.strip().split("=") - key = parts[0] - value = "" - if len(parts) > 1: - value = parts[1] - default_envs[key] = value - - single_container_mode = kwargs["deployment_type"] == "single_container" - in_mem_workers = kwargs.get("in_mem_workers") - smtp_username = kwargs.get("smtp_username") - smtp_sender = kwargs.get("smtp_sender") - smtp_password = kwargs.get("smtp_password") - smtp_port = kwargs.get("smtp_port") - if smtp_port is None or smtp_port == "": - smtp_port = int(default_envs["SMTP_PORT"]) - smtp_host = kwargs.get("smtp_host") - - print(" - NAME: " + str(snake_name)) - print(" - TEMPLATE DIR: " + template_grid_dir) - if compose_src_path: - print(" - COMPOSE SOURCE: " + compose_src_path) - print(" - RELEASE: " + f'{kwargs["node_side_type"]}-{kwargs["release"]}') - print(" - DEPLOYMENT:", kwargs["deployment_type"]) - print(" - ARCH: " + docker_platform) - print(" - TYPE: " + str(node_type.input)) - print(" - DOCKER_TAG: " + version_string) - if version_hash != "dockerhub": - print(" - GIT_HASH: " + version_hash) - print(" - HAGRID_VERSION: " + get_version_string()) - if EDITABLE_MODE: - print(" - HAGRID_REPO_SHA: " + commit_hash()) - print(" - PORT: " + str(host_term.free_port)) - print(" - DOCKER COMPOSE: " + docker_version) - print(" - IN-MEMORY WORKERS: " + str(in_mem_workers)) - print("\n") - - use_blob_storage = ( - False - if str(node_type.input) in ["network", "gateway"] - else bool(kwargs["use_blob_storage"]) - ) - - enable_rathole = bool(kwargs.get("enable_rathole")) or str(node_type.input) in [ - "network", - "gateway", - ] - - # use a docker volume - host_path = "credentials-data" - - # # in development use a folder mount - # if kwargs.get("release", "") == "development": - # RELATIVE_PATH = "" - # # if EDITABLE_MODE: - # # RELATIVE_PATH = "../" - # # we might need to change this for the hagrid template mode - # host_path = f"{RELATIVE_PATH}./data/storage/{snake_name}" - - rathole_mode = ( - "client" if enable_rathole and str(node_type.input) in ["domain"] else "server" - ) - - envs = { - "RELEASE": "production", - "COMPOSE_DOCKER_CLI_BUILD": 1, - "DOCKER_BUILDKIT": 1, - "HTTP_PORT": int(host_term.free_port), - "HTTPS_PORT": int(host_term.free_port_tls), - "TRAEFIK_TAG": str(tag), - "NODE_NAME": str(snake_name), - "NODE_TYPE": str(node_type.input), - "TRAEFIK_PUBLIC_NETWORK_IS_EXTERNAL": "False", - "VERSION": version_string, - "VERSION_HASH": version_hash, - "USE_BLOB_STORAGE": str(use_blob_storage), - "FRONTEND_TARGET": "grid-ui-production", - "STACK_API_KEY": str( - generate_sec_random_password(length=48, special_chars=False) - ), - "CREDENTIALS_VOLUME": host_path, - "NODE_SIDE_TYPE": kwargs["node_side_type"], - "SINGLE_CONTAINER_MODE": single_container_mode, - "INMEMORY_WORKERS": in_mem_workers, - } - - if smtp_host and smtp_port and smtp_username and smtp_password: - envs["SMTP_HOST"] = smtp_host - envs["SMTP_PORT"] = smtp_port - envs["SMTP_USERNAME"] = smtp_username - envs["SMTP_PASSWORD"] = smtp_password - envs["EMAIL_SENDER"] = smtp_sender - - if "trace" in kwargs and kwargs["trace"] is True: - envs["TRACE"] = "True" - envs["JAEGER_HOST"] = "host.docker.internal" - envs["JAEGER_PORT"] = int( - find_available_port(host="localhost", port=14268, search=True) - ) 
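    # Illustrative sketch: the `envs` mapping assembled above and below is later
    # serialized into an environment prefix for the compose invocation. For a
    # hypothetical domain named "test_domain" launched with default flags on a
    # POSIX shell, the assembled command comes out roughly as:
    #
    #   RELEASE=production HTTP_PORT=<free_port> HTTPS_PORT=<free_port_tls> ... \
    #       docker compose -p test_domain --profile backend --profile proxy \
    #       --profile mongo --file docker-compose.yml up -d
    #
    # On Windows the variables are emitted as PowerShell assignments instead
    # (e.g. "$env:RELEASE='production'; ..."), and "-d" is dropped when tailing logs.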
- - if "association_request_auto_approval" in kwargs: - envs["ASSOCIATION_REQUEST_AUTO_APPROVAL"] = kwargs[ - "association_request_auto_approval" - ] - - if "enable_warnings" in kwargs: - envs["ENABLE_WARNINGS"] = kwargs["enable_warnings"] - - if "platform" in kwargs and kwargs["platform"] is not None: - envs["DOCKER_DEFAULT_PLATFORM"] = docker_platform - - if "tls" in kwargs and kwargs["tls"] is True and len(kwargs["cert_store_path"]) > 0: - envs["TRAEFIK_TLS_CERTS"] = kwargs["cert_store_path"] - - if ( - "tls" in kwargs - and kwargs["tls"] is True - and "test" in kwargs - and kwargs["test"] is True - ): - envs["IGNORE_TLS_ERRORS"] = "True" - - if "test" in kwargs and kwargs["test"] is True: - envs["SWFS_VOLUME_SIZE_LIMIT_MB"] = "100" # GitHub CI is small - - if kwargs.get("release", "") == "development": - envs["RABBITMQ_MANAGEMENT"] = "-management" - - # currently we only have a domain frontend for dev mode - if kwargs.get("release", "") == "development" and ( - str(node_type.input) not in ["network", "gateway"] - ): - envs["FRONTEND_TARGET"] = "grid-ui-development" - - if "set_root_password" in kwargs and kwargs["set_root_password"] is not None: - envs["DEFAULT_ROOT_PASSWORD"] = kwargs["set_root_password"] - - if "set_root_email" in kwargs and kwargs["set_root_email"] is not None: - envs["DEFAULT_ROOT_EMAIL"] = kwargs["set_root_email"] - - if "set_s3_username" in kwargs and kwargs["set_s3_username"] is not None: - envs["S3_ROOT_USER"] = kwargs["set_s3_username"] - - if "set_s3_password" in kwargs and kwargs["set_s3_password"] is not None: - envs["S3_ROOT_PWD"] = kwargs["set_s3_password"] - - if ( - "set_volume_size_limit_mb" in kwargs - and kwargs["set_volume_size_limit_mb"] is not None - ): - envs["SWFS_VOLUME_SIZE_LIMIT_MB"] = kwargs["set_volume_size_limit_mb"] - - if "release" in kwargs: - envs["RELEASE"] = kwargs["release"] - - if "enable_signup" in kwargs: - envs["ENABLE_SIGNUP"] = kwargs["enable_signup"] - - if enable_rathole: - envs["MODE"] = rathole_mode - - cmd = "" - args = [] - for k, v in envs.items(): - if is_windows(): - # powershell envs - quoted = f"'{v}'" if not isinstance(v, int) else v - args.append(f"$env:{k}={quoted}") - else: - args.append(f"{k}={v}") - if is_windows(): - cmd += "; ".join(args) - cmd += "; " - else: - cmd += " ".join(args) - - cmd += " docker compose -p " + snake_name - - # new docker compose regression work around - # default_env = os.path.expanduser("~/.hagrid/app/.env") - - default_envs.update(envs) - - # env file path - env_file_path = compose_src_path + "/.env" - - # Render templates if creating stack from the manifest_template.yml - if kwargs["template"] and host_term.host is not None: - # If release is development, update relative path - # if EDITABLE_MODE: - # default_envs["RELATIVE_PATH"] = "../" - - render_templates( - node_name=snake_name, - deployment_type=kwargs["deployment_type"], - template_location=kwargs["template"], - env_vars=default_envs, - host_type=host_term.host, - ) - - try: - env_file = "" - for k, v in default_envs.items(): - env_file += f"{k}={v}\n" - - with open(env_file_path, "w") as f: - f.write(env_file) - - # cmd += f" --env-file {env_file_path}" - except Exception: # nosec - pass - - if single_container_mode: - cmd += " --profile worker" - else: - cmd += " --profile backend" - cmd += " --profile proxy" - cmd += " --profile mongo" - - if str(node_type.input) in ["network", "gateway"]: - cmd += " --profile network" - - if use_blob_storage: - cmd += " --profile blob-storage" - - if enable_rathole: - cmd += " --profile 
rathole" - - # no frontend container so expect bad gateway on the / route - if not bool(kwargs["headless"]): - cmd += " --profile frontend" - - if "trace" in kwargs and kwargs["trace"]: - cmd += " --profile telemetry" - - final_commands = {} - final_commands["Pulling"] = pull_command(cmd, kwargs) - - cmd += " --file docker-compose.yml" - if "tls" in kwargs and kwargs["tls"] is True: - cmd += " --file docker-compose.tls.yml" - if "test" in kwargs and kwargs["test"] is True: - cmd += " --file docker-compose.test.yml" - - if build: - my_build_command = build_command(cmd) - final_commands["Building"] = my_build_command - - dev_mode = kwargs.get("dev", False) - final_commands["Launching"] = deploy_command(cmd, tail, dev_mode) - return final_commands - - -def create_launch_vagrant_cmd(verb: GrammarVerb) -> str: - host_term = verb.get_named_term_hostgrammar(name="host") - node_name = verb.get_named_term_type(name="node_name") - node_type = verb.get_named_term_type(name="node_type") - - snake_name = str(node_name.snake_input) - - if ART: - hagrid() - - print( - "Launching a " - + str(node_type.input) - + " PyGrid node on port " - + str(host_term.port) - + "!\n" - ) - - print(" - TYPE: " + str(node_type.input)) - print(" - NAME: " + str(snake_name)) - print(" - PORT: " + str(host_term.port)) - # print(" - VAGRANT: " + "1") - # print(" - VIRTUALBOX: " + "1") - print("\n") - - cmd = "" - cmd += 'ANSIBLE_ARGS="' - cmd += f"-e 'node_name={snake_name}'" - cmd += f"-e 'node_type={node_type.input}'" - cmd += '" ' - cmd += "vagrant up --provision" - cmd = "cd " + GRID_SRC_PATH() + ";" + cmd - return cmd - - -def get_or_make_resource_group(resource_group: str, location: str = "westus") -> None: - cmd = f"az group show --resource-group {resource_group}" - exists = True - try: - subprocess.check_call(cmd, shell=True) # nosec - except Exception: # nosec - # group doesn't exist so lets create it - exists = False - - if not exists: - cmd = f"az group create -l {location} -n {resource_group}" - try: - print(f"Creating resource group.\nRunning: {cmd}") - subprocess.check_call(cmd, shell=True) # nosec - except Exception as e: - raise Exception( - f"Unable to create resource group {resource_group} @ {location}. 
{e}" - ) - - -def extract_host_ip(stdout: bytes) -> str | None: - output = stdout.decode("utf-8") - - try: - j = json.loads(output) - if "publicIpAddress" in j: - return str(j["publicIpAddress"]) - except Exception: # nosec - matcher = r'publicIpAddress":\s+"(.+)"' - ips = re.findall(matcher, output) - if len(ips) > 0: - return ips[0] - - return None - - -def get_vm_host_ips(node_name: str, resource_group: str) -> list | None: - cmd = f"az vm list-ip-addresses -g {resource_group} --query " - cmd += f""""[?starts_with(virtualMachine.name, '{node_name}')]""" - cmd += '''.virtualMachine.network.publicIpAddresses[0].ipAddress"''' - output = subprocess.check_output(cmd, shell=True) # nosec - try: - host_ips = json.loads(output) - return host_ips - except Exception as e: - print(f"Failed to extract ips: {e}") - - return None - - -def is_valid_ip(host_or_ip: str) -> bool: - matcher = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" - ips = re.findall(matcher, host_or_ip.strip()) - if len(ips) == 1: - return True - return False - - -def extract_host_ip_gcp(stdout: bytes) -> str | None: - output = stdout.decode("utf-8") - - try: - matcher = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" - ips = re.findall(matcher, output) - if len(ips) == 2: - return ips[1] - except Exception: # nosec - pass - - return None - - -def extract_host_ip_from_cmd(cmd: str) -> str | None: - try: - matcher = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" - ips = re.findall(matcher, cmd) - if ips: - return ips[0] - except Exception: # nosec - pass - - return None - - -def check_ip_for_ssh( - host_ip: str, timeout: int = 600, wait_time: int = 5, silent: bool = False -) -> bool: - if not silent: - print(f"Checking VM at {host_ip} is up") - checks = int(timeout / wait_time) # 10 minutes in 5 second chunks - first_run = True - while checks > 0: - checks -= 1 - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(wait_time) - result = sock.connect_ex((host_ip, 22)) - sock.close() - if result == 0: - if not silent: - print(f"VM at {host_ip} is up!") - return True - else: - if first_run: - if not silent: - print("Waiting for VM to start", end="", flush=True) - first_run = False - else: - if not silent: - print(".", end="", flush=True) - except Exception: # nosec - pass - return False - - -def create_aws_security_group( - security_group_name: str, region: str, snake_name: str -) -> str: - sg_description = f"{snake_name} security group" - create_cmd = f"aws ec2 create-security-group --group-name {security_group_name} " - create_cmd += f'--region {region} --description "{sg_description}" ' - sg_output = subprocess.check_output( # nosec - create_cmd, - shell=True, - ) - sg_output_dict = json.loads(sg_output) - if "GroupId" in sg_output_dict: - return sg_output_dict["GroupId"] - - return "" - - -def open_port_aws( - security_group_name: str, port_no: int, cidr: str, region: str -) -> None: - cmd = f"aws ec2 authorize-security-group-ingress --group-name {security_group_name} --protocol tcp " - cmd += f"--port {port_no} --cidr {cidr} --region {region}" - subprocess.check_call( # nosec - cmd, - shell=True, - ) - - -def extract_instance_ids_aws(stdout: bytes) -> list: - output = stdout.decode("utf-8") - output_dict = json.loads(output) - instance_ids: list = [] - if "Instances" in output_dict: - for ec2_instance_metadata in output_dict["Instances"]: - if "InstanceId" in ec2_instance_metadata: - instance_ids.append(ec2_instance_metadata["InstanceId"]) - - return instance_ids - - -def get_host_ips_given_instance_ids( - instance_ids: list, timeout: int = 
600, wait_time: int = 10 -) -> list: - checks = int(timeout / wait_time) # 10 minutes in 10 second chunks - instance_ids_str = " ".join(instance_ids) - cmd = f"aws ec2 describe-instances --instance-ids {instance_ids_str}" - cmd += " --query 'Reservations[*].Instances[*].{StateName:State.Name,PublicIpAddress:PublicIpAddress}'" - cmd += " --output json" - while checks > 0: - checks -= 1 - time.sleep(wait_time) - desc_ec2_output = subprocess.check_output(cmd, shell=True) # nosec - instances_output_json = json.loads(desc_ec2_output.decode("utf-8")) - host_ips: list = [] - all_instances_running = True - for reservation in instances_output_json: - for instance_metadata in reservation: - if instance_metadata["StateName"] != "running": - all_instances_running = False - break - else: - host_ips.append(instance_metadata["PublicIpAddress"]) - if all_instances_running: - return host_ips - # else, wait another wait_time seconds and try again - - return [] - - -def make_aws_ec2_instance( - ami_id: str, ec2_instance_type: str, key_name: str, security_group_name: str -) -> list: - # From the docs: "For security groups in a nondefault VPC, you must specify the security group ID". - # Right now, since we're using default VPC, we can use security group name instead of ID. - - ebs_size = 200 # gb - cmd = f"aws ec2 run-instances --image-id {ami_id} --count 1 --instance-type {ec2_instance_type} " - cmd += f"--key-name {key_name} --security-groups {security_group_name} " - tmp_cmd = rf"[{{\"DeviceName\":\"/dev/sdf\",\"Ebs\":{{\"VolumeSize\":{ebs_size},\"DeleteOnTermination\":false}}}}]" - cmd += f'--block-device-mappings "{tmp_cmd}"' - - host_ips: list = [] - try: - print(f"Creating EC2 instance.\nRunning: {cmd}") - create_ec2_output = subprocess.check_output(cmd, shell=True) # nosec - instance_ids = extract_instance_ids_aws(create_ec2_output) - host_ips = get_host_ips_given_instance_ids(instance_ids=instance_ids) - except Exception as e: - print("failed", e) - - if not (host_ips): - raise Exception("Failed to create EC2 instance(s) or get public ip(s)") - - return host_ips - - -def create_launch_aws_cmd( - verb: GrammarVerb, - region: str, - ec2_instance_type: str, - security_group_name: str, - aws_security_group_cidr: str, - key_name: str, - key_path: str, - ansible_extras: str, - kwargs: dict[str, Any], - repo: str, - branch: str, - ami_id: str, - username: str, - auth: AuthCredentials, -) -> list[str]: - node_name = verb.get_named_term_type(name="node_name") - snake_name = str(node_name.snake_input) - create_aws_security_group(security_group_name, region, snake_name) - open_port_aws( - security_group_name=security_group_name, - port_no=80, - cidr=aws_security_group_cidr, - region=region, - ) # HTTP - open_port_aws( - security_group_name=security_group_name, - port_no=443, - cidr=aws_security_group_cidr, - region=region, - ) # HTTPS - open_port_aws( - security_group_name=security_group_name, - port_no=22, - cidr=aws_security_group_cidr, - region=region, - ) # SSH - if kwargs["jupyter"]: - open_port_aws( - security_group_name=security_group_name, - port_no=8888, - cidr=aws_security_group_cidr, - region=region, - ) # Jupyter - - host_ips = make_aws_ec2_instance( - ami_id=ami_id, - ec2_instance_type=ec2_instance_type, - key_name=key_name, - security_group_name=security_group_name, - ) - - launch_cmds: list[str] = [] - - for host_ip in host_ips: - # get old host - host_term = verb.get_named_term_hostgrammar(name="host") - - # replace - host_term.parse_input(host_ip) - verb.set_named_term_type(name="host", 
new_term=host_term) - - if not bool(kwargs["provision"]): - print("Skipping automatic provisioning.") - print("VM created with:") - print(f"IP: {host_ip}") - print(f"Key: {key_path}") - print("\nConnect with:") - print(f"ssh -i {key_path} {username}@{host_ip}") - - else: - extra_kwargs = { - "repo": repo, - "branch": branch, - "ansible_extras": ansible_extras, - } - kwargs.update(extra_kwargs) - - # provision - host_up = check_ip_for_ssh(host_ip=host_ip) - if not host_up: - print(f"Warning: {host_ip} ssh not available yet") - launch_cmd = create_launch_custom_cmd(verb=verb, auth=auth, kwargs=kwargs) - launch_cmds.append(launch_cmd) - - return launch_cmds - - -def make_vm_azure( - node_name: str, - resource_group: str, - username: str, - password: str | None, - key_path: str | None, - size: str, - image_name: str, - node_count: int, -) -> list: - disk_size_gb = "200" - try: - temp_dir = tempfile.TemporaryDirectory() - public_key_path = ( - private_to_public_key( - private_key_path=key_path, temp_path=temp_dir.name, username=username - ) - if key_path - else None - ) - except Exception: # nosec - temp_dir.cleanup() - - authentication_type = "ssh" if key_path else "password" - cmd = f"az vm create -n {node_name} -g {resource_group} --size {size} " - cmd += f"--image {image_name} --os-disk-size-gb {disk_size_gb} " - cmd += f"--public-ip-sku Standard --authentication-type {authentication_type} --admin-username {username} " - cmd += f"--ssh-key-values {public_key_path} " if public_key_path else "" - cmd += f"--admin-password '{password}' " if password else "" - cmd += f"--count {node_count} " if node_count > 1 else "" - - host_ips: list | None = [] - try: - print(f"Creating vm.\nRunning: {hide_azure_vm_password(cmd)}") - subprocess.check_output(cmd, shell=True) # nosec - host_ips = get_vm_host_ips(node_name=node_name, resource_group=resource_group) - except Exception as e: - print("failed", e) - finally: - temp_dir.cleanup() - - if not host_ips: - raise Exception("Failed to create vm or get VM public ip") - - try: - # clean up temp public key - if public_key_path: - os.unlink(public_key_path) - except Exception: # nosec - pass - - return host_ips - - -def open_port_vm_azure( - resource_group: str, node_name: str, port_name: str, port: int, priority: int -) -> None: - cmd = f"az network nsg rule create --resource-group {resource_group} " - cmd += f"--nsg-name {node_name}NSG --name {port_name} --destination-port-ranges {port} --priority {priority}" - try: - print(f"Creating {port_name} {port} ngs rule.\nRunning: {cmd}") - output = subprocess.check_call(cmd, shell=True) # nosec - print("output", output) - pass - except Exception as e: - print("failed", e) - - -def create_project(project_id: str) -> None: - cmd = f"gcloud projects create {project_id} --set-as-default" - try: - print(f"Creating project.\nRunning: {cmd}") - subprocess.check_call(cmd, shell=True) # nosec - except Exception as e: - print("failed", e) - - print("create project complete") - - -def create_launch_gcp_cmd( - verb: GrammarVerb, - project_id: str, - zone: str, - machine_type: str, - ansible_extras: str, - kwargs: dict[str, Any], - repo: str, - branch: str, - auth: AuthCredentials, -) -> str: - # create project if it doesn't exist - create_project(project_id) - # vm - node_name = verb.get_named_term_type(name="node_name") - kebab_name = str(node_name.kebab_input) - disk_size_gb = "200" - host_ip = make_gcp_vm( - vm_name=kebab_name, - project_id=project_id, - zone=zone, - machine_type=machine_type, - disk_size_gb=disk_size_gb, - 
) - - # get old host - host_term = verb.get_named_term_hostgrammar(name="host") - - host_up = check_ip_for_ssh(host_ip=host_ip) - if not host_up: - raise Exception(f"Something went wrong launching the VM at IP: {host_ip}.") - - if not bool(kwargs["provision"]): - print("Skipping automatic provisioning.") - print("VM created with:") - print(f"IP: {host_ip}") - print(f"User: {auth.username}") - print(f"Key: {auth.key_path}") - print("\nConnect with:") - print(f"ssh -i {auth.key_path} {auth.username}@{host_ip}") - sys.exit(0) - - # replace - host_term.parse_input(host_ip) - verb.set_named_term_type(name="host", new_term=host_term) - - extra_kwargs = { - "repo": repo, - "branch": branch, - "auth_type": "key", - "ansible_extras": ansible_extras, - } - kwargs.update(extra_kwargs) - - # provision - return create_launch_custom_cmd(verb=verb, auth=auth, kwargs=kwargs) - - -def make_gcp_vm( - vm_name: str, project_id: str, zone: str, machine_type: str, disk_size_gb: str -) -> str: - create_cmd = "gcloud compute instances create" - network_settings = "network=default,network-tier=PREMIUM" - maintenance_policy = "MIGRATE" - scopes = [ - "https://www.googleapis.com/auth/devstorage.read_only", - "https://www.googleapis.com/auth/logging.write", - "https://www.googleapis.com/auth/monitoring.write", - "https://www.googleapis.com/auth/servicecontrol", - "https://www.googleapis.com/auth/service.management.readonly", - "https://www.googleapis.com/auth/trace.append", - ] - tags = "http-server,https-server" - disk_image = "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230429" - disk = ( - f"auto-delete=yes,boot=yes,device-name={vm_name},image={disk_image}," - + f"mode=rw,size={disk_size_gb},type=pd-ssd" - ) - security_flags = ( - "--no-shielded-secure-boot --shielded-vtpm " - + "--shielded-integrity-monitoring --reservation-affinity=any" - ) - - cmd = ( - f"{create_cmd} {vm_name} " - + f"--project={project_id} " - + f"--zone={zone} " - + f"--machine-type={machine_type} " - + f"--create-disk={disk} " - + f"--network-interface={network_settings} " - + f"--maintenance-policy={maintenance_policy} " - + f"--scopes={','.join(scopes)} --tags={tags} " - + f"{security_flags}" - ) - - host_ip = None - try: - print(f"Creating vm.\nRunning: {cmd}") - output = subprocess.check_output(cmd, shell=True) # nosec - host_ip = extract_host_ip_gcp(stdout=output) - except Exception as e: - print("failed", e) - - if host_ip is None: - raise Exception("Failed to create vm or get VM public ip") - - return host_ip - - -def create_launch_azure_cmd( - verb: GrammarVerb, - resource_group: str, - location: str, - size: str, - username: str, - password: str | None, - key_path: str | None, - repo: str, - branch: str, - auth: AuthCredentials, - ansible_extras: str, - kwargs: dict[str, Any], -) -> list[str]: - get_or_make_resource_group(resource_group=resource_group, location=location) - - node_count = kwargs.get("node_count", 1) - print("Total VMs to create: ", node_count) - - # vm - node_name = verb.get_named_term_type(name="node_name") - snake_name = str(node_name.snake_input) - image_name = get_azure_image(kwargs["image_name"]) - host_ips = make_vm_azure( - snake_name, - resource_group, - username, - password, - key_path, - size, - image_name, - node_count, - ) - - # open port 80 - open_port_vm_azure( - resource_group=resource_group, - node_name=snake_name, - port_name="HTTP", - port=80, - priority=500, - ) - - # open port 443 - open_port_vm_azure( - resource_group=resource_group, - node_name=snake_name, - port_name="HTTPS", - 
port=443, - priority=501, - ) - - if kwargs["jupyter"]: - # open port 8888 - open_port_vm_azure( - resource_group=resource_group, - node_name=snake_name, - port_name="Jupyter", - port=8888, - priority=502, - ) - - launch_cmds: list[str] = [] - - for host_ip in host_ips: - # get old host - host_term = verb.get_named_term_hostgrammar(name="host") - - # replace - host_term.parse_input(host_ip) - verb.set_named_term_type(name="host", new_term=host_term) - - if not bool(kwargs["provision"]): - print("Skipping automatic provisioning.") - print("VM created with:") - print(f"Name: {snake_name}") - print(f"IP: {host_ip}") - print(f"User: {username}") - print(f"Password: {password}") - print(f"Key: {key_path}") - print("\nConnect with:") - if kwargs["auth_type"] == "key": - print(f"ssh -i {key_path} {username}@{host_ip}") - else: - print(f"ssh {username}@{host_ip}") - else: - extra_kwargs = { - "repo": repo, - "branch": branch, - "ansible_extras": ansible_extras, - } - kwargs.update(extra_kwargs) - - # provision - host_up = check_ip_for_ssh(host_ip=host_ip) - if not host_up: - print(f"Warning: {host_ip} ssh not available yet") - launch_cmd = create_launch_custom_cmd(verb=verb, auth=auth, kwargs=kwargs) - launch_cmds.append(launch_cmd) - - return launch_cmds - - -def create_ansible_land_cmd( - verb: GrammarVerb, auth: AuthCredentials | None, kwargs: dict[str, Any] -) -> str: - try: - host_term = verb.get_named_term_hostgrammar(name="host") - print("Landing PyGrid node on port " + str(host_term.port) + "!\n") - - print(" - PORT: " + str(host_term.port)) - print("\n") - - grid_path = GRID_SRC_PATH() - playbook_path = grid_path + "/ansible/site.yml" - ansible_cfg_path = grid_path + "/ansible.cfg" - auth = cast(AuthCredentials, auth) - - if not os.path.exists(playbook_path): - print(f"Can't find playbook site.yml at: {playbook_path}") - cmd = f"ANSIBLE_CONFIG={ansible_cfg_path} ansible-playbook " - if host_term.host == "localhost": - cmd += "--connection=local " - cmd += f"-i {host_term.host}, {playbook_path}" - if host_term.host != "localhost" and kwargs["auth_type"] == "key": - cmd += f" --private-key {auth.key_path} --user {auth.username}" - elif host_term.host != "localhost" and kwargs["auth_type"] == "password": - cmd += f" -c paramiko --user {auth.username}" - - ANSIBLE_ARGS = {"install": "false"} - - if host_term.host != "localhost" and kwargs["auth_type"] == "password": - ANSIBLE_ARGS["ansible_ssh_pass"] = kwargs["password"] - - if host_term.host == "localhost": - ANSIBLE_ARGS["local"] = "true" - - if "ansible_extras" in kwargs and kwargs["ansible_extras"] != "": - options = kwargs["ansible_extras"].split(",") - for option in options: - parts = option.strip().split("=") - if len(parts) == 2: - ANSIBLE_ARGS[parts[0]] = parts[1] - - for k, v in ANSIBLE_ARGS.items(): - cmd += f" -e \"{k}='{v}'\"" - - cmd = "cd " + grid_path + ";" + cmd - return cmd - except Exception as e: - print(f"Failed to construct custom deployment cmd: {cmd}. 
{e}") - raise e - - -def create_launch_custom_cmd( - verb: GrammarVerb, auth: AuthCredentials | None, kwargs: dict[str, Any] -) -> str: - try: - host_term = verb.get_named_term_hostgrammar(name="host") - node_name = verb.get_named_term_type(name="node_name") - node_type = verb.get_named_term_type(name="node_type") - # source_term = verb.get_named_term_type(name="source") - - snake_name = str(node_name.snake_input) - - if ART: - hagrid() - - print( - "Launching a " - + str(node_type.input) - + " PyGrid node on port " - + str(host_term.port) - + "!\n" - ) - - print(" - TYPE: " + str(node_type.input)) - print(" - NAME: " + str(snake_name)) - print(" - PORT: " + str(host_term.port)) - print("\n") - - grid_path = GRID_SRC_PATH() - playbook_path = grid_path + "/ansible/site.yml" - ansible_cfg_path = grid_path + "/ansible.cfg" - auth = cast(AuthCredentials, auth) - - if not os.path.exists(playbook_path): - print(f"Can't find playbook site.yml at: {playbook_path}") - cmd = f"ANSIBLE_CONFIG={ansible_cfg_path} ansible-playbook " - if host_term.host == "localhost": - cmd += "--connection=local " - cmd += f"-i {host_term.host}, {playbook_path}" - if host_term.host != "localhost" and kwargs["auth_type"] == "key": - cmd += f" --private-key {auth.key_path} --user {auth.username}" - elif host_term.host != "localhost" and kwargs["auth_type"] == "password": - cmd += f" -c paramiko --user {auth.username}" - - version_string = kwargs["tag"] - if version_string is None: - version_string = "local" - - ANSIBLE_ARGS = { - "node_type": node_type.input, - "node_name": snake_name, - "github_repo": kwargs["repo"], - "repo_branch": kwargs["branch"], - "docker_tag": version_string, - } - - if host_term.host != "localhost" and kwargs["auth_type"] == "password": - ANSIBLE_ARGS["ansible_ssh_pass"] = kwargs["password"] - - if host_term.host == "localhost": - ANSIBLE_ARGS["local"] = "true" - - if "node_side_type" in kwargs: - ANSIBLE_ARGS["node_side_type"] = kwargs["node_side_type"] - - if kwargs["tls"] is True: - ANSIBLE_ARGS["tls"] = "true" - - if "release" in kwargs: - ANSIBLE_ARGS["release"] = kwargs["release"] - - if "set_root_email" in kwargs and kwargs["set_root_email"] is not None: - ANSIBLE_ARGS["root_user_email"] = kwargs["set_root_email"] - - if "set_root_password" in kwargs and kwargs["set_root_password"] is not None: - ANSIBLE_ARGS["root_user_password"] = kwargs["set_root_password"] - - if ( - kwargs["tls"] is True - and "cert_store_path" in kwargs - and len(kwargs["cert_store_path"]) > 0 - ): - ANSIBLE_ARGS["cert_store_path"] = kwargs["cert_store_path"] - - if ( - kwargs["tls"] is True - and "upload_tls_key" in kwargs - and len(kwargs["upload_tls_key"]) > 0 - ): - ANSIBLE_ARGS["upload_tls_key"] = kwargs["upload_tls_key"] - - if ( - kwargs["tls"] is True - and "upload_tls_cert" in kwargs - and len(kwargs["upload_tls_cert"]) > 0 - ): - ANSIBLE_ARGS["upload_tls_cert"] = kwargs["upload_tls_cert"] - - if kwargs["jupyter"] is True: - ANSIBLE_ARGS["jupyter"] = "true" - ANSIBLE_ARGS["jupyter_token"] = generate_sec_random_password( - length=48, upper_case=False, special_chars=False - ) - - if "ansible_extras" in kwargs and kwargs["ansible_extras"] != "": - options = kwargs["ansible_extras"].split(",") - for option in options: - parts = option.strip().split("=") - if len(parts) == 2: - ANSIBLE_ARGS[parts[0]] = parts[1] - - # if mode == "deploy": - # ANSIBLE_ARGS["deploy"] = "true" - - for k, v in ANSIBLE_ARGS.items(): - cmd += f" -e \"{k}='{v}'\"" - - cmd = "cd " + grid_path + ";" + cmd - return cmd - except Exception 
as e: - print(f"Failed to construct custom deployment cmd: {cmd}. {e}") - raise e - - -def create_land_cmd(verb: GrammarVerb, kwargs: dict[str, Any]) -> str: - host_term = verb.get_named_term_hostgrammar(name="host") - host = host_term.host if host_term.host is not None else "" - - if host in ["docker"]: - target = verb.get_named_term_grammar("node_name").input - prune_volumes: bool = kwargs.get("prune_vol", False) - - if target == "all": - # land all syft nodes - if prune_volumes: - land_cmd = "docker rm `docker ps --filter label=orgs.openmined.syft -q` --force " - land_cmd += "&& docker volume rm " - land_cmd += "$(docker volume ls --filter label=orgs.openmined.syft -q)" - return land_cmd - else: - return "docker rm `docker ps --filter label=orgs.openmined.syft -q` --force" - - version = check_docker_version() - if version: - return create_land_docker_cmd(verb=verb, prune_volumes=prune_volumes) - - elif host == "localhost" or is_valid_ip(host): - parsed_kwargs = {} - if DEPENDENCIES["ansible-playbook"]: - if host != "localhost": - parsed_kwargs["username"] = ask( - question=Question( - var_name="username", - question=f"Username for {host} with sudo privledges?", - default=arg_cache["username"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - parsed_kwargs["auth_type"] = ask( - question=Question( - var_name="auth_type", - question="Do you want to login with a key or password", - default=arg_cache["auth_type"], - kind="option", - options=["key", "password"], - cache=True, - ), - kwargs=kwargs, - ) - if parsed_kwargs["auth_type"] == "key": - parsed_kwargs["key_path"] = ask( - question=Question( - var_name="key_path", - question=f"Private key to access {parsed_kwargs['username']}@{host}?", - default=arg_cache["key_path"], - kind="path", - cache=True, - ), - kwargs=kwargs, - ) - elif parsed_kwargs["auth_type"] == "password": - parsed_kwargs["password"] = ask( - question=Question( - var_name="password", - question=f"Password for {parsed_kwargs['username']}@{host}?", - kind="password", - ), - kwargs=kwargs, - ) - - auth = None - if host != "localhost": - if parsed_kwargs["auth_type"] == "key": - auth = AuthCredentials( - username=parsed_kwargs["username"], - key_path=parsed_kwargs["key_path"], - ) - else: - auth = AuthCredentials( - username=parsed_kwargs["username"], - key_path=parsed_kwargs["password"], - ) - if not auth.valid: - raise Exception(f"Login Credentials are not valid. 
{auth}") - parsed_kwargs["ansible_extras"] = kwargs["ansible_extras"] - return create_ansible_land_cmd(verb=verb, auth=auth, kwargs=parsed_kwargs) - else: - errors = [] - if not DEPENDENCIES["ansible-playbook"]: - errors.append("ansible-playbook") - raise MissingDependency( - f"Launching a Custom VM requires: {' '.join(errors)}" - ) - - host_options = ", ".join(allowed_hosts) - raise MissingDependency( - f"Launch requires a correct host option, try: {host_options}" - ) - - -def create_land_docker_cmd(verb: GrammarVerb, prune_volumes: bool = False) -> str: - """ - Create docker `land` command to remove containers when a node's name is specified - """ - node_name = verb.get_named_term_type(name="node_name") - snake_name = str(node_name.snake_input) - - path = GRID_SRC_PATH() - env_var = ";export $(cat .env | sed 's/#.*//g' | xargs);" - - cmd = "" - cmd += "docker compose" - cmd += ' --file "docker-compose.yml"' - cmd += ' --project-name "' + snake_name + '"' - cmd += " down --remove-orphans" - - if prune_volumes: - cmd += ( - f' && docker volume rm $(docker volume ls --filter name="{snake_name}" -q)' - ) - - cmd += f" && docker rm $(docker ps --filter name={snake_name} -q) --force" - - cmd = "cd " + path + env_var + cmd - return cmd - - -@click.command( - help="Stop a running PyGrid domain/network node.", - context_settings={"show_default": True}, -) -@click.argument("args", type=str, nargs=-1) -@click.option( - "--cmd", - is_flag=True, - help="Print the cmd without running it", -) -@click.option( - "--ansible-extras", - default="", - type=str, -) -@click.option( - "--build-src", - default=DEFAULT_BRANCH, - required=False, - type=str, - help="Git branch to use for launch / build operations", -) -@click.option( - "--silent", - is_flag=True, - help="Suppress extra outputs", -) -@click.option( - "--force", - is_flag=True, - help="Bypass the prompt during hagrid land", -) -@click.option( - "--prune-vol", - is_flag=True, - help="Prune docker volumes after land.", -) -def land(args: tuple[str], **kwargs: Any) -> None: - verb = get_land_verb() - silent = bool(kwargs["silent"]) - force = bool(kwargs["force"]) - try: - grammar = parse_grammar(args=args, verb=verb) - verb.load_grammar(grammar=grammar) - except BadGrammar as e: - print(e) - return - - try: - update_repo(repo=GIT_REPO(), branch=str(kwargs["build_src"])) - except Exception as e: - print(f"Failed to update repo. {e}") - - try: - cmd = create_land_cmd(verb=verb, kwargs=kwargs) - except Exception as e: - print(f"{e}") - return - - target = verb.get_named_term_grammar("node_name").input - - if not force: - _land_domain = ask( - Question( - var_name="_land_domain", - question=f"Are you sure you want to land {target} (y/n)", - kind="yesno", - ), - kwargs={}, - ) - - grid_path = GRID_SRC_PATH() - - if force or _land_domain == "y": - if not bool(kwargs["cmd"]): - if not silent: - print("Running: \n", cmd) - try: - if silent: - process = subprocess.Popen( # nosec - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=grid_path, - shell=True, - ) - process.communicate() - - print(f"HAGrid land {target} complete!") - else: - subprocess.call(cmd, shell=True, cwd=grid_path) # nosec - except Exception as e: - print(f"Failed to run cmd: {cmd}. 
{e}") - else: - print("Hagrid land aborted.") - - -cli.add_command(launch) -cli.add_command(land) -cli.add_command(clean) - - -@click.command( - help="Show HAGrid debug information", context_settings={"show_default": True} -) -@click.argument("args", type=str, nargs=-1) -def debug(args: tuple[str], **kwargs: Any) -> None: - debug_info = gather_debug() - print("\n\nWhen reporting bugs, please copy everything between the lines.") - print("==================================================================\n") - print(json.dumps(debug_info)) - print("\n=================================================================\n\n") - - -cli.add_command(debug) - - -DEFAULT_HEALTH_CHECKS = ["host", "UI (βeta)", "api", "ssh", "jupyter"] -HEALTH_CHECK_FUNCTIONS = { - "host": check_host, - "UI (βeta)": check_login_page, - "api": check_api_metadata, - "ssh": check_ip_for_ssh, - "jupyter": check_jupyter_server, -} - -HEALTH_CHECK_ICONS = { - "host": "🔌", - "UI (βeta)": "🖱", - "api": "⚙️", - "ssh": "🔐", - "jupyter": "📗", -} - -HEALTH_CHECK_URLS = { - "host": "{ip_address}", - "UI (βeta)": "http://{ip_address}/login", - "api": "http://{ip_address}/api/v2/openapi.json", - "ssh": "hagrid ssh {ip_address}", - "jupyter": "http://{ip_address}:8888", -} - - -def check_host_health(ip_address: str, keys: list[str]) -> dict[str, bool]: - status = {} - for key in keys: - func: Callable = HEALTH_CHECK_FUNCTIONS[key] # type: ignore - status[key] = func(ip_address, silent=True) - return status - - -def icon_status(status: bool) -> str: - return "✅" if status else "❌" - - -def get_health_checks(ip_address: str) -> tuple[bool, list[list[str]]]: - keys = list(DEFAULT_HEALTH_CHECKS) - if "localhost" in ip_address: - new_keys = [] - for key in keys: - if key not in ["host", "jupyter", "ssh"]: - new_keys.append(key) - keys = new_keys - - health_status = check_host_health(ip_address=ip_address, keys=keys) - complete_status = all(health_status.values()) - - # find port from ip_address - try: - port = int(ip_address.split(":")[1]) - except Exception: - # default to 80 - port = 80 - - # url to display based on running environment - display_url = gitpod_url(port).split("//")[1] if is_gitpod() else ip_address - - # figure out how to add this back? - # console.print("[bold magenta]Checking host:[/bold magenta]", ip_address, ":mage:") - table_contents = [] - for key, value in health_status.items(): - table_contents.append( - [ - HEALTH_CHECK_ICONS[key], - key, - HEALTH_CHECK_URLS[key].replace("{ip_address}", display_url), - icon_status(value), - ] - ) - - return complete_status, table_contents - - -def create_check_table( - table_contents: list[list[str]], time_left: int = 0 -) -> rich.table.Table: - table = rich.table.Table() - table.add_column("PyGrid", style="magenta") - table.add_column("Info", justify="left", overflow="fold") - time_left_str = "" if time_left == 0 else str(time_left) - table.add_column(time_left_str, justify="left") - for row in table_contents: - table.add_row(row[1], row[2], row[3]) - return table - - -def get_host_name(container_name: str, by_suffix: str) -> str: - # Assumption we always get proxy containers first. - # if users have old docker compose versios. 
- # the container names are _ instead of - - # canada_proxy_1 instead of canada-proxy-1 - try: - host_name = container_name[0 : container_name.find(by_suffix) - 1] # noqa: E203 - except Exception: - host_name = "" - return host_name - - -def get_docker_status( - ip_address: str, node_name: str | None -) -> tuple[bool, tuple[str, str]]: - url = from_url(ip_address) - port = url[2] - network_container = ( - shell( - "docker ps --format '{{.Names}} {{.Ports}}' | " + f"grep '0.0.0.0:{port}'" - ) - .strip() - .split(" ")[0] - ) - - # Second conditional handle the case when internal port of worker container - # matches with host port of launched Domain/Network Container - if not network_container or (node_name and node_name not in network_container): - # check if it is a worker container and an internal port was passed - worker_containers_output: str = shell( - "docker ps --format '{{.Names}} {{.Ports}}' | " + f"grep '{port}/tcp'" - ).strip() - if not worker_containers_output or not node_name: - return False, ("", "") - - # If there are worker containers with an internal port - # fetch the worker container with the launched worker name - worker_containers = worker_containers_output.split("\n") - for worker_container in worker_containers: - container_name = worker_container.split(" ")[0] - if node_name in container_name: - network_container = container_name - break - else: - # If the worker container is not created yet - return False, ("", "") - - if "proxy" in network_container: - host_name = get_host_name(network_container, by_suffix="proxy") - - backend_containers = shell( - "docker ps --format '{{.Names}}' | grep 'backend' " - ).split() - - _backend_exists = False - for container in backend_containers: - if host_name in container and "stream" not in container: - _backend_exists = True - break - if not _backend_exists: - return False, ("", "") - - node_type = "Domain" - - # TODO: Identify if node_type is Gateway - # for container in headscale_containers: - # if host_name in container: - # node_type = "Gateway" - # break - - return True, (host_name, node_type) - else: - # health check for worker node - host_name = get_host_name(network_container, by_suffix="worker") - return True, (host_name, "Worker") - - -def get_syft_install_status(host_name: str, node_type: str) -> bool: - container_search = "backend" if node_type != "Worker" else "worker" - search_containers = shell( - "docker ps --format '{{.Names}}' | " + f"grep '{container_search}' " - ).split() - - context_container = None - for container in search_containers: - # stream keyword is for our old container stack - if host_name in container and "stream" not in container: - context_container = container - break - - if not context_container: - print(f"❌ {container_search} Docker Stack for: {host_name} not found") - exit(0) - else: - container_log = shell(f"docker logs {context_container}") - if "Application startup complete" not in container_log: - return False - return True - - -@click.command( - help="Check health of an IP address/addresses or a resource group", - context_settings={"show_default": True}, -) -@click.argument("ip_addresses", type=str, nargs=-1) -@click.option( - "--timeout", - default=300, - help="Timeout for hagrid check command", -) -@click.option( - "--verbose", - is_flag=True, - help="Refresh output", -) -def check( - ip_addresses: list[str], verbose: bool = False, timeout: int | str = 300 -) -> None: - check_status(ip_addresses=ip_addresses, silent=not verbose, timeout=timeout) - - -def _check_status( - ip_addresses: 
str | list[str], - silent: bool = True, - signal: Event | None = None, - node_name: str | None = None, -) -> None: - OK_EMOJI = RichEmoji("white_heavy_check_mark").to_str() - # Check if ip_addresses is str, then convert to list - if ip_addresses and isinstance(ip_addresses, str): - ip_addresses = [ip_addresses] - console = Console() - node_info = None - if len(ip_addresses) == 0: - headers = {"User-Agent": "curl/7.79.1"} - print("Detecting External IP...") - ip_res = requests.get("https://ifconfig.co", headers=headers) # nosec - ip_address = ip_res.text.strip() - ip_addresses = [ip_address] - - if len(ip_addresses) == 1: - ip_address = ip_addresses[0] - status, table_contents = get_health_checks(ip_address=ip_address) - table = create_check_table(table_contents=table_contents) - max_timeout = 600 - if not status: - table = create_check_table( - table_contents=table_contents, time_left=max_timeout - ) - if silent: - with console.status("Gathering Node information") as console_status: - console_status.update( - "[bold orange_red1]Waiting for Container Creation" - ) - docker_status, node_info = get_docker_status(ip_address, node_name) - while not docker_status: - docker_status, node_info = get_docker_status( - ip_address, node_name - ) - time.sleep(1) - if ( - signal and signal.is_set() - ): # Stop execution if timeout is triggered - return - console.print( - f"{OK_EMOJI} {node_info[0]} {node_info[1]} Containers Created" - ) - - console_status.update("[bold orange_red1]Starting Backend") - syft_install_status = get_syft_install_status( - node_info[0], node_info[1] - ) - while not syft_install_status: - syft_install_status = get_syft_install_status( - node_info[0], node_info[1] - ) - time.sleep(1) - # Stop execution if timeout is triggered - if signal and signal.is_set(): - return - console.print(f"{OK_EMOJI} Backend") - console.print(f"{OK_EMOJI} Startup Complete") - - status, table_contents = get_health_checks(ip_address) - table = create_check_table( - table_contents=table_contents, time_left=max_timeout - ) - else: - while not status: - # Stop execution if timeout is triggered - if signal is not None and signal.is_set(): - return - with Live( - table, refresh_per_second=2, screen=True, auto_refresh=False - ) as live: - max_timeout -= 1 - if max_timeout % 5 == 0: - status, table_contents = get_health_checks(ip_address) - table = create_check_table( - table_contents=table_contents, time_left=max_timeout - ) - live.update(table) - if status: - break - time.sleep(1) - - # TODO: Create new health checks table for Worker Container - if (node_info and node_info[1] != "Worker") or not node_info: - console.print(table) - else: - for ip_address in ip_addresses: - _, table_contents = get_health_checks(ip_address) - table = create_check_table(table_contents=table_contents) - console.print(table) - - -def check_status( - ip_addresses: str | list[str], - silent: bool = True, - timeout: int | str = 300, - node_name: str | None = None, -) -> None: - timeout = int(timeout) - # third party - from rich import print - - signal = Event() - - t = Thread( - target=_check_status, - kwargs={ - "ip_addresses": ip_addresses, - "silent": silent, - "signal": signal, - "node_name": node_name, - }, - ) - t.start() - t.join(timeout=timeout) - - if t.is_alive(): - signal.set() - t.join() - - print(f"Hagrid check command timed out after: {timeout} seconds 🕛") - print( - "Please try increasing the timeout or kindly check the docker containers for error logs." 
- ) - print("You can view your container logs using the following tool:") - print("Tool: [link=https://ctop.sh]Ctop[/link]") - print("Video Explanation: https://youtu.be/BJhlCxerQP4 \n") - - -cli.add_command(check) - - -# add Hagrid info to the cli -@click.command(help="Show HAGrid info", context_settings={"show_default": True}) -def version() -> None: - print(f"HAGRID_VERSION: {get_version_string()}") - if EDITABLE_MODE: - print(f"HAGRID_REPO_SHA: {commit_hash()}") - - -cli.add_command(version) - - -def run_quickstart( - url: str | None = None, - syft: str = "latest", - reset: bool = False, - quiet: bool = False, - pre: bool = False, - test: bool = False, - repo: str = DEFAULT_REPO, - branch: str = DEFAULT_BRANCH, - commit: str | None = None, - python: str | None = None, - zip_file: str | None = None, -) -> None: - try: - quickstart_art() - directory = os.path.expanduser("~/.hagrid/quickstart/") - confirm_reset = None - if reset: - if not quiet: - confirm_reset = click.confirm( - "This will create a new quickstart virtualenv and reinstall Syft and " - "Jupyter. Are you sure you want to continue?" - ) - else: - confirm_reset = True - if confirm_reset is False: - return - - if reset and confirm_reset or not os.path.isdir(directory): - quickstart_setup( - directory=directory, - syft_version=syft, - reset=reset, - pre=pre, - python=python, - ) - downloaded_files = [] - if zip_file: - downloaded_files = fetch_notebooks_from_zipfile( - zip_file, - directory=directory, - reset=reset, - ) - elif url: - downloaded_files = fetch_notebooks_for_url( - url=url, - directory=directory, - reset=reset, - repo=repo, - branch=branch, - commit=commit, - ) - else: - file_path = add_intro_notebook(directory=directory, reset=reset) - downloaded_files.append(file_path) - - if len(downloaded_files) == 0: - raise Exception(f"Unable to find files at: {url}") - file_path = sorted(downloaded_files)[0] - - # add virtualenv path - environ = os.environ.copy() - os_bin_path = "Scripts" if is_windows() else "bin" - venv_dir = directory + ".venv" - environ["PATH"] = venv_dir + os.sep + os_bin_path + os.pathsep + environ["PATH"] - jupyter_binary = "jupyter.exe" if is_windows() else "jupyter" - - if is_windows(): - env_activate_cmd = ( - "(Powershell): " - + "cd " - + venv_dir - + "; " - + os_bin_path - + os.sep - + "activate" - ) - else: - env_activate_cmd = ( - "(Linux): source " + venv_dir + os.sep + os_bin_path + "/activate" - ) - - print(f"To activate your virtualenv {env_activate_cmd}") - - try: - allow_browser = " --no-browser" if is_gitpod() else "" - cmd = ( - venv_dir - + os.sep - + os_bin_path - + os.sep - + f"{jupyter_binary} lab{allow_browser} --ip 0.0.0.0 --notebook-dir={directory} {file_path}" - ) - if test: - jupyter_path = venv_dir + os.sep + os_bin_path + os.sep + jupyter_binary - if not os.path.exists(jupyter_path): - print(f"Failed to install Jupyter in path: {jupyter_path}") - sys.exit(1) - print(f"Jupyter exists at: {jupyter_path}. 
CI Test mode exiting.") - sys.exit(0) - - disable_toolbar_extension = ( - venv_dir - + os.sep - + os_bin_path - + os.sep - + f"{jupyter_binary} labextension disable @jupyterlab/cell-toolbar-extension" - ) - - subprocess.run( # nosec - disable_toolbar_extension.split(" "), cwd=directory, env=environ - ) - - ON_POSIX = "posix" in sys.builtin_module_names - - def enqueue_output(out: Any, queue: Queue) -> None: - for line in iter(out.readline, b""): - queue.put(line) - out.close() - - proc = subprocess.Popen( # nosec - cmd.split(" "), - cwd=directory, - env=environ, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - close_fds=ON_POSIX, - ) - queue: Queue = Queue() - thread_1 = Thread(target=enqueue_output, args=(proc.stdout, queue)) - thread_2 = Thread(target=enqueue_output, args=(proc.stderr, queue)) - thread_1.daemon = True # thread dies with the program - thread_1.start() - thread_2.daemon = True # thread dies with the program - thread_2.start() - - display_url = None - console = rich.get_console() - - # keepn reading the queue of stdout + stderr - while True: - try: - if not display_url: - # try to read the line and extract a jupyter url: - with console.status( - "Starting Jupyter service" - ) as console_status: - line = queue.get() - display_url = extract_jupyter_url(line.decode("utf-8")) - if display_url: - display_jupyter_url(url_parts=display_url) - console_status.stop() - except KeyboardInterrupt: - proc.kill() # make sure jupyter gets killed - sys.exit(1) - except Exception: # nosec - pass # nosec - except KeyboardInterrupt: - proc.kill() # make sure jupyter gets killed - sys.exit(1) - except Exception as e: - print(f"Error running quickstart: {e}") - raise e - - -@click.command( - help="Launch a Syft + Jupyter Session with a Notebook URL / Path", - context_settings={"show_default": True}, -) -@click.argument("url", type=str, required=False) -@click.option( - "--reset", - is_flag=True, - default=False, - help="Force hagrid quickstart to setup a fresh virtualenv", -) -@click.option( - "--syft", - default="latest", - help="Choose a syft version or just use latest", -) -@click.option( - "--quiet", - is_flag=True, - help="Silence confirmation prompts", -) -@click.option( - "--pre", - is_flag=True, - help="Install pre-release versions of syft", -) -@click.option( - "--python", - default=None, - help="Specify the path to which python to use", -) -@click.option( - "--test", - is_flag=True, - help="CI Test Mode, don't hang on Jupyter", -) -@click.option( - "--repo", - default=DEFAULT_REPO, - help="Choose a repo to fetch the notebook from or just use OpenMined/PySyft", -) -@click.option( - "--branch", - default=DEFAULT_BRANCH, - help="Choose a branch to fetch from or just use dev", -) -@click.option( - "--commit", - help="Choose a specific commit to fetch the notebook from", -) -def quickstart_cli( - url: str | None = None, - syft: str = "latest", - reset: bool = False, - quiet: bool = False, - pre: bool = False, - test: bool = False, - repo: str = DEFAULT_REPO, - branch: str = DEFAULT_BRANCH, - commit: str | None = None, - python: str | None = None, -) -> None: - return run_quickstart( - url=url, - syft=syft, - reset=reset, - quiet=quiet, - pre=pre, - test=test, - repo=repo, - branch=branch, - commit=commit, - python=python, - ) - - -cli.add_command(quickstart_cli, "quickstart") - - -def display_jupyter_url(url_parts: tuple[str, str, int]) -> None: - url = url_parts[0] - if is_gitpod(): - parts = urlparse(url) - query = getattr(parts, "query", "") - url = gitpod_url(port=url_parts[2]) + 
"?" + query - - console = rich.get_console() - - tick_emoji = RichEmoji("white_heavy_check_mark").to_str() - link_emoji = RichEmoji("link").to_str() - - console.print( - f"[bold white]{tick_emoji} Jupyter Server is running at:\n{link_emoji} [bold blue]{url}\n" - + "[bold white]Use Control-C to stop this server and shut down all kernels.", - new_line_start=True, - ) - - # if is_gitpod(): - # open_browser_with_url(url=url) - - -def open_browser_with_url(url: str) -> None: - webbrowser.open(url) - - -def extract_jupyter_url(line: str) -> tuple[str, str, int] | None: - jupyter_regex = r"^.*(http.*127.*)" - try: - matches = re.match(jupyter_regex, line) - if matches is not None: - url = matches.group(1).strip() - parts = urlparse(url) - host_or_ip_parts = parts.netloc.split(":") - # netloc is host:port - port = 8888 - if len(host_or_ip_parts) > 1: - port = int(host_or_ip_parts[1]) - host_or_ip = host_or_ip_parts[0] - return (url, host_or_ip, port) - except Exception as e: - print("failed to parse jupyter url", e) - return None - - -def quickstart_setup( - directory: str, - syft_version: str, - reset: bool = False, - pre: bool = False, - python: str | None = None, -) -> None: - console = rich.get_console() - OK_EMOJI = RichEmoji("white_heavy_check_mark").to_str() - - try: - with console.status( - "[bold blue]Setting up Quickstart Environment" - ) as console_status: - os.makedirs(directory, exist_ok=True) - virtual_env_dir = os.path.abspath(directory + ".venv/") - if reset and os.path.exists(virtual_env_dir): - shutil.rmtree(virtual_env_dir) - env = VirtualEnvironment(virtual_env_dir, python=python) - console.print( - f"{OK_EMOJI} Created Virtual Environment {RichEmoji('evergreen_tree').to_str()}" - ) - - # upgrade pip - console_status.update("[bold blue]Installing pip") - env.install("pip", options=["-U"]) - console.print(f"{OK_EMOJI} pip") - - # upgrade packaging - console_status.update("[bold blue]Installing packaging") - env.install("packaging", options=["-U"]) - console.print(f"{OK_EMOJI} packaging") - - # Install jupyter lab - console_status.update("[bold blue]Installing Jupyter Lab") - env.install("jupyterlab") - env.install("ipywidgets") - console.print(f"{OK_EMOJI} Jupyter Lab") - - # Install hagrid - if EDITABLE_MODE: - local_hagrid_dir = Path( - os.path.abspath(Path(hagrid_root()) / "../hagrid") - ) - console_status.update( - f"[bold blue]Installing HAGrid in Editable Mode: {str(local_hagrid_dir)}" - ) - env.install("-e " + str(local_hagrid_dir)) - console.print( - f"{OK_EMOJI} HAGrid in Editable Mode: {str(local_hagrid_dir)}" - ) - else: - console_status.update("[bold blue]Installing hagrid") - env.install("hagrid", options=["-U"]) - console.print(f"{OK_EMOJI} HAGrid") - except Exception as e: - print(e) - raise e - - -def add_intro_notebook(directory: str, reset: bool = False) -> str: - filenames = ["00-quickstart.ipynb", "01-install-wizard.ipynb"] - - files = os.listdir(directory) - try: - files.remove(".venv") - except Exception: # nosec - pass - - existing = 0 - for file in files: - if file in filenames: - existing += 1 - - if existing != len(filenames) or reset: - if EDITABLE_MODE: - local_src_dir = Path(os.path.abspath(Path(hagrid_root()) / "../../")) - for filename in filenames: - file_path = os.path.abspath(f"{directory}/{filename}") - shutil.copyfile( - local_src_dir / f"notebooks/quickstart/{filename}", - file_path, - ) - else: - for filename in filenames: - url = ( - "https://raw.githubusercontent.com/OpenMined/PySyft/dev/" - + f"notebooks/quickstart/{filename}" - ) - 
file_path, _, _ = quickstart_download_notebook( - url=url, directory=directory, reset=reset - ) - if arg_cache["install_wizard_complete"]: - filename = filenames[0] - else: - filename = filenames[1] - return os.path.abspath(f"{directory}/{filename}") - - -@click.command(help="Walk the Path", context_settings={"show_default": True}) -@click.argument("zip_file", type=str, default="padawan.zip", metavar="ZIPFILE") -def dagobah(zip_file: str) -> None: - if not os.path.exists(zip_file): - for text in ( - f"{zip_file} does not exists.", - "Please specify the path to the zip file containing the notebooks.", - "hagrid dagobah [ZIPFILE]", - ): - print(text, file=sys.stderr) - sys.exit(1) - - return run_quickstart(zip_file=zip_file) - - -cli.add_command(dagobah) - - -def ssh_into_remote_machine( - host_ip: str, - username: str, - auth_type: str, - private_key_path: str | None, - cmd: str = "", -) -> None: - """Access or execute command on the remote machine. - - Args: - host_ip (str): ip address of the VM - private_key_path (str): private key of the VM - username (str): username on the VM - cmd (str, optional): Command to execute on the remote machine. Defaults to "". - """ - try: - if auth_type == "key": - subprocess.call( # nosec - ["ssh", "-i", f"{private_key_path}", f"{username}@{host_ip}", cmd] - ) - elif auth_type == "password": - subprocess.call(["ssh", f"{username}@{host_ip}", cmd]) # nosec - except Exception as e: - raise e - - -@click.command( - help="SSH into the IP address or a resource group", - context_settings={"show_default": True}, -) -@click.argument("ip_address", type=str) -@click.option( - "--cmd", - type=str, - required=False, - default="", - help="Optional: command to execute on the remote machine.", -) -def ssh(ip_address: str, cmd: str) -> None: - kwargs: dict = {} - key_path: str | None = None - - if check_ip_for_ssh(ip_address, timeout=10, silent=False): - username = ask( - question=Question( - var_name="azure_username", - question="What is the username for the VM?", - default=arg_cache["azure_username"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - auth_type = ask( - question=Question( - var_name="auth_type", - question="Do you want to login with a key or password", - default=arg_cache["auth_type"], - kind="option", - options=["key", "password"], - cache=True, - ), - kwargs=kwargs, - ) - - if auth_type == "key": - key_path = ask( - question=Question( - var_name="azure_key_path", - question="Absolute path to the private key of the VM?", - default=arg_cache["azure_key_path"], - kind="string", - cache=True, - ), - kwargs=kwargs, - ) - - # SSH into the remote and execute the command - ssh_into_remote_machine( - host_ip=ip_address, - username=username, - auth_type=auth_type, - private_key_path=key_path, - cmd=cmd, - ) - - -cli.add_command(ssh) - - -# Add hagrid logs command to the CLI -@click.command( - help="Get the logs of the HAGrid node", context_settings={"show_default": True} -) -@click.argument("domain_name", type=str) -def logs(domain_name: str) -> None: # nosec - container_ids = ( - subprocess.check_output( # nosec - f"docker ps -qf name=^{domain_name}-*", shell=True - ) - .decode("utf-8") - .split() - ) - Container = namedtuple("Container", "id name logs") - container_names = [] - for container in container_ids: - container_name = ( - subprocess.check_output( # nosec - "docker inspect --format '{{.Name}}' " + container, shell=True - ) - .decode("utf-8") - .strip() - .replace("/", "") - ) - log_command = "docker logs -f " + container_name - 
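The logs command above collects container IDs with docker ps, resolves each name via docker inspect, and keeps a ready-made docker logs -f string per container. Roughly the same lookup can be done in one pass with docker ps --format; a sketch that assumes only the docker CLI on PATH and the <domain>-<service> naming convention:

    # One-pass variant of the container lookup behind `hagrid logs`.
    import subprocess
    from collections import namedtuple

    Container = namedtuple("Container", "id name logs")

    def list_containers(domain_name: str) -> list[Container]:
        out = subprocess.check_output(  # nosec
            ["docker", "ps", "--format", "{{.ID}} {{.Names}}"]
        ).decode("utf-8")
        containers = []
        for row in out.splitlines():
            container_id, _, name = row.partition(" ")
            if name.startswith(f"{domain_name}-"):
                containers.append(
                    Container(id=container_id, name=name, logs=f"docker logs -f {name}")
                )
        return containers
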
container_names.append( - Container(id=container, name=container_name, logs=log_command) - ) - # Generate a table of the containers and their logs with Rich - table = rich.table.Table(title="Container Logs") - table.add_column("Container ID", justify="center", style="cyan", no_wrap=True) - table.add_column("Container Name", justify="right", style="cyan", no_wrap=True) - table.add_column("Log Command", justify="right", style="cyan", no_wrap=True) - for container in container_names: # type: ignore - table.add_row(container.id, container.name, container.logs) # type: ignore - console = rich.console.Console() - console.print(table) - # Print instructions on how to view the logs - console.print( - rich.panel.Panel( - long_string, - title="How to view logs", - border_style="white", - expand=False, - padding=1, - highlight=True, - ) - ) - - -long_string = ( - "ℹ [bold green]To view the live logs of a container,copy the log command and paste it into your terminal.[/bold green]\n" # noqa: E501 - + "\n" - + "ℹ [bold green]The logs will be streamed to your terminal until you exit the command.[/bold green]\n" - + "\n" - + "ℹ [bold green]To exit the logs, press CTRL+C.[/bold green]\n" - + "\n" - + "🚨 The [bold white]backend,backend_stream & celery[/bold white] [bold green]containers are the most important to monitor for debugging.[/bold green]\n" # noqa: E501 - + "\n" - + " [bold white]--------------- Ctop 🦾 -------------------------[/bold white]\n" - + "\n" - + "🧠 To learn about using [bold white]ctop[/bold white] to monitor your containers,visit https://www.youtube.com/watch?v=BJhlCxerQP4n \n" # noqa: E501 - + "\n" - + " [bold white]----------------- How to view this. 🙂 ---------------[/bold white]\n" - + "\n" - + """ℹ [bold green]To view this panel again, run the command [bold white]hagrid logs {{NODE_NAME}}[/bold white] [/bold green]\n""" # noqa: E501 - + "\n" - + """🚨 NODE_NAME above is the name of your Hagrid deployment,without the curly braces. E.g hagrid logs canada [bold green]\n""" # noqa: E501 - + "\n" - + " [bold green]HAPPY DEBUGGING! 
🐛🐞🦗🦟🦠🦠🦠[/bold green]\n " -) - -cli.add_command(logs) From 2df1656944aefb630f328e0799a538b1c1a04475 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 18:31:06 +0530 Subject: [PATCH 092/309] lint notebook --- notebooks/Experimental/Network.ipynb | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/notebooks/Experimental/Network.ipynb b/notebooks/Experimental/Network.ipynb index 88240421fb3..d12b95f7129 100644 --- a/notebooks/Experimental/Network.ipynb +++ b/notebooks/Experimental/Network.ipynb @@ -7,6 +7,7 @@ "metadata": {}, "outputs": [], "source": [ + "# syft absolute\n", "import syft as sy" ] }, @@ -37,7 +38,9 @@ } ], "source": [ - "gateway_client = sy.login(url=\"http://localhost\", port=9081, email=\"info@openmined.org\", password=\"changethis\")" + "gateway_client = sy.login(\n", + " url=\"http://localhost\", port=9081, email=\"info@openmined.org\", password=\"changethis\"\n", + ")" ] }, { @@ -67,7 +70,9 @@ } ], "source": [ - "domain_client = sy.login(url=\"http://localhost\", port=9082, email=\"info@openmined.org\", password=\"changethis\")" + "domain_client = sy.login(\n", + " url=\"http://localhost\", port=9082, email=\"info@openmined.org\", password=\"changethis\"\n", + ")" ] }, { From cec992b993c5fbac0a7226d01d3f9f19d60608d1 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 10 Jun 2024 18:32:45 +0530 Subject: [PATCH 093/309] remove output from Network notebook --- notebooks/Experimental/Network.ipynb | 7888 +------------------------- 1 file changed, 28 insertions(+), 7860 deletions(-) diff --git a/notebooks/Experimental/Network.ipynb b/notebooks/Experimental/Network.ipynb index d12b95f7129..7a1f3f257dc 100644 --- a/notebooks/Experimental/Network.ipynb +++ b/notebooks/Experimental/Network.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "bd9a2226-3e53-4f27-9213-75a8c3ff9176", "metadata": {}, "outputs": [], @@ -13,30 +13,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "fddf8d07-d154-4284-a27b-d74e35d3f851", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logged into as \n" - ] - }, - { - "data": { - "text/html": [ - "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" - ], - "text/plain": [ - "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "gateway_client = sy.login(\n", " url=\"http://localhost\", port=9081, email=\"info@openmined.org\", password=\"changethis\"\n", @@ -45,30 +25,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "8f7b106d-b784-45d8-b54d-4ce2de2da453", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logged into as \n" - ] - }, - { - "data": { - "text/html": [ - "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" - ], - "text/plain": [ - "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "domain_client = sy.login(\n", " url=\"http://localhost\", port=9082, email=\"info@openmined.org\", password=\"changethis\"\n", @@ -77,2556 +37,27 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "ff504949-620d-4e26-beee-0d39e0e502eb", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
SyftSuccess: Connected domain 'syft-dev-node' to gateway 'syft-dev-node'. Routes Exchanged

" - ], - "text/plain": [ - "SyftSuccess: Connected domain 'syft-dev-node' to gateway 'syft-dev-node'. Routes Exchanged" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "domain_client.connect_to_gateway(gateway_client, reverse_tunnel=True)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "ba7bc71a-4e6a-4429-9588-7b3d0ed19e27", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "\n", - "
\n", - "
\n", - " \n", - "
\n", - "

Request List

\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

Total: 0

\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gateway_client.api.services.request" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "5b4984e1-331e-4fd8-b012-768fc613f48a", "metadata": {}, "outputs": [], @@ -2636,2525 +67,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "90dc44bd", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "\n", - "
\n", - "
\n", - " \n", - "
\n", - "

NodePeer List

\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

Total: 0

\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "" - ], - "text/plain": [ - "[syft.service.network.node_peer.NodePeer]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "node_peers = gateway_client.api.network.get_all_peers()\n", "node_peers" @@ -5162,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "8c06aaa6-4157-42d1-959f-9d47722a3420", "metadata": {}, "outputs": [], @@ -5172,88 +88,37 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "cb63a77b", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[syft.service.network.routes.HTTPNodeRoute]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "node_peer.node_routes" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "61882e86", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'syft_node_location': ,\n", - " 'syft_client_verify_key': cc82ec7abd5d516e6972e787ddaafe8b04e223228436b09c026be97b80ad6246,\n", - " 'id': None,\n", - " 'host_or_ip': 'syft-dev-node.syft.local',\n", - " 'private': False,\n", - " 'protocol': 'http',\n", - " 'port': 9082,\n", - " 'proxy_target_uid': None,\n", - " 'priority': 1,\n", - " 'rathole_token': 'b95e8d239d563e6fcc3a4f44a5292177e608a7b0b1194e6106adc1998a1b68a1'}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "node_peer.node_routes[0].__dict__" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "fb19dbc6-869b-46dc-92e3-5e75ee6d0b06", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'syft_node_location': ,\n", - " 'syft_client_verify_key': 5ed40db70e275e30e9001cda808922c4542c84f7472a106c00158795b9388b0a,\n", - " 'id': None,\n", - " 'host_or_ip': 'host.k3d.internal',\n", - " 'private': False,\n", - " 'protocol': 'http',\n", - " 'port': 9081,\n", - " 'proxy_target_uid': None,\n", - " 'priority': 1,\n", - " 'rathole_token': None}" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "domain_client.api.network.get_all_peers()[0].node_routes[0].__dict__" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "32d09a51", "metadata": {}, "outputs": [], @@ -5263,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "b7d9e41d", "metadata": {}, "outputs": [], @@ -5273,2706 +138,27 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "8fa24ec7", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \"Logo\"\n",\n", - "

Welcome to syft-dev-node

\n", - "
\n", - " URL: http://localhost:9081
Node Type: Gateway
Node Side Type: High Side
Syft Version: 0.8.7-beta.10
\n", - "
\n", - "
\n", - " ⓘ \n", - " This node is run by the library PySyft to learn more about how it works visit\n", - " github.com/OpenMined/PySyft.\n", - "
\n", - "

Commands to Get Started

\n", - " \n", - "
    \n", - " \n", - "
  • <your_client>\n", - " .domains - list domains connected to this gateway
  • \n", - "
  • <your_client>\n", - " .proxy_client_for - get a connection to a listed domain
  • \n", - "
  • <your_client>\n", - " .login - log into the gateway
  • \n", - " \n", - "
\n", - " \n", - "

\n", - " " - ], - "text/plain": [ - ": HTTPConnection: http://localhost:9081>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gateway_client" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "3a081250-abc3-43a3-9e06-ff0c3a362ebf", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "\n", - "
\n", - "
\n", - " \n", - "
\n", - "

NodePeer List

\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

Total: 0

\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "" - ], - "text/markdown": [ - "```python\n", - "class ProxyClient:\n", - " id: str = 37073e9151ce4fa9b665501ec03924c8\n", - "\n", - "```" - ], - "text/plain": [ - "syft.client.gateway_client.ProxyClient" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "gateway_client.peers" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "b6fedfe4-9362-47c9-9342-5cf6eacde8ab", "metadata": {}, "outputs": [], @@ -7982,28 +168,10 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "f1940e00-0337-4b56-88c2-d70f397a7016", "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "```python\n", - "class HTTPConnection:\n", - " id: str = None\n", - "\n", - "```" - ], - "text/plain": [ - "HTTPConnection: http://localhost:9081" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "domain_client_proxy.connection" ] From ecf3f04ef42a2a7f646fce74698bd73e02fa3613 Mon Sep 17 00:00:00 2001 From: dk Date: Tue, 11 Jun 2024 10:33:34 +0700 Subject: [PATCH 094/309] [syft/action_obj] refactor `convert_to_pointers` - getting size of twin object before client uploading dataset --- .../syft/src/syft/client/domain_client.py | 7 +++ .../src/syft/service/action/action_object.py | 57 +++++++------------ 2 files changed, 26 insertions(+), 38 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 1c768710040..84040b588bb 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -5,6 +5,7 @@ from pathlib import Path import re from string import Template +import sys from typing import TYPE_CHECKING from typing import cast @@ -17,6 +18,7 @@ # relative from ..abstract_node import NodeSideType from ..serde.serializable import serializable +from ..serde.serialize import _serialize as serialize from ..service.action.action_object import ActionObject from ..service.code_history.code_history import CodeHistoriesDict from ..service.code_history.code_history import UsersCodeHistoriesDict @@ -134,6 +136,11 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: syft_node_location=self.id, syft_client_verify_key=self.verify_key, ) + serialized: bytes = serialize(twin, to_bytes=True) + size_mb: float = sys.getsizeof(serialized) / 1024 / 1024 + if size_mb < 16: + print(f"object's size = {size_mb} (MB), less than 16 MB") + # TODO: if less than 16 MB, save without using blob storage twin._save_to_blob_storage() except Exception as e: tqdm.write(f"Failed to create twin for {asset.name}. 
{e}") diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index dffa3d3d9de..05324ffdb1f 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -498,44 +498,25 @@ def convert_to_pointers( # relative from ..dataset.dataset import Asset - arg_list = [] - kwarg_dict = {} - if args is not None: - for arg in args: - if ( - not isinstance(arg, ActionObject | Asset | UID) - and api.signing_key is not None # type: ignore[unreachable] - ): - arg = ActionObject.from_obj( # type: ignore[unreachable] - syft_action_data=arg, - syft_client_verify_key=api.signing_key.verify_key, - syft_node_location=api.node_uid, - ) - arg.syft_node_uid = node_uid - r = arg._save_to_blob_storage() - if isinstance(r, SyftError): - print(r.message) - arg = api.services.action.set(arg) - arg_list.append(arg) - - if kwargs is not None: - for k, arg in kwargs.items(): - if ( - not isinstance(arg, ActionObject | Asset | UID) - and api.signing_key is not None # type: ignore[unreachable] - ): - arg = ActionObject.from_obj( # type: ignore[unreachable] - syft_action_data=arg, - syft_client_verify_key=api.signing_key.verify_key, - syft_node_location=api.node_uid, - ) - arg.syft_node_uid = node_uid - r = arg._save_to_blob_storage() - if isinstance(r, SyftError): - print(r.message) - arg = api.services.action.set(arg) - - kwarg_dict[k] = arg + def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: + if ( + not isinstance(arg, ActionObject | Asset | UID) + and api.signing_key is not None # type: ignore[unreachable] + ): + arg = ActionObject.from_obj( # type: ignore[unreachable] + syft_action_data=arg, + syft_client_verify_key=api.signing_key.verify_key, + syft_node_location=api.node_uid, + ) + arg.syft_node_uid = node_uid + r = arg._save_to_blob_storage() + if isinstance(r, SyftError): + print(r.message) + arg = api.services.action.set(arg) + return arg + + arg_list = [process_arg(arg) for arg in args] if args else [] + kwarg_dict = {k: process_arg(v) for k, v in kwargs.items()} if kwargs else {} return arg_list, kwarg_dict From ed094f4e8e8488dd60dee60d1198a8848e4f7a7a Mon Sep 17 00:00:00 2001 From: khoaguin Date: Tue, 11 Jun 2024 15:38:41 +0700 Subject: [PATCH 095/309] [syft/action_obj] skip saving to blob storage if `syft_action_data` is less than `min_size_mb` (16 Mb by default) --- .../syft/src/syft/client/domain_client.py | 7 ------ .../src/syft/service/action/action_object.py | 24 +++++++++++++++---- packages/syft/src/syft/util/util.py | 5 ++++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 84040b588bb..1c768710040 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -5,7 +5,6 @@ from pathlib import Path import re from string import Template -import sys from typing import TYPE_CHECKING from typing import cast @@ -18,7 +17,6 @@ # relative from ..abstract_node import NodeSideType from ..serde.serializable import serializable -from ..serde.serialize import _serialize as serialize from ..service.action.action_object import ActionObject from ..service.code_history.code_history import CodeHistoriesDict from ..service.code_history.code_history import UsersCodeHistoriesDict @@ -136,11 +134,6 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: syft_node_location=self.id, 
syft_client_verify_key=self.verify_key, ) - serialized: bytes = serialize(twin, to_bytes=True) - size_mb: float = sys.getsizeof(serialized) / 1024 / 1024 - if size_mb < 16: - print(f"object's size = {size_mb} (MB), less than 16 MB") - # TODO: if less than 16 MB, save without using blob storage twin._save_to_blob_storage() except Exception as e: tqdm.write(f"Failed to create twin for {asset.name}. {e}") diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 05324ffdb1f..34673bbd224 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -47,6 +47,7 @@ from ...types.uid import LineageID from ...types.uid import UID from ...util.logger import debug +from ...util.util import get_mb_serialized_size from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..response import SyftException @@ -828,17 +829,30 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: return None - def _save_to_blob_storage(self) -> SyftError | None: + def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: + """ " + If less than min_size_mb, skip saving to blob storage + TODO: min_size_mb shoulb be passed as a env var + """ data = self.syft_action_data if isinstance(data, SyftError): return data if isinstance(data, ActionDataEmpty): - return SyftError(message=f"cannot store empty object {self.id}") - result = self._save_to_blob_storage_(data) - if isinstance(result, SyftError): - return result + return SyftError( + message=f"cannot store empty object {self.id} to the blob storage" + ) if not TraceResultRegistry.current_thread_is_tracing(): self.syft_action_data_cache = self.as_empty_data() + action_data_size_mb: float = get_mb_serialized_size(data) + if action_data_size_mb > min_size_mb: + result = self._save_to_blob_storage_(data) + if isinstance(result, SyftError): + return result + else: + debug( + f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), less than {min_size_mb} (MB). " + f"Skip saving to blob storage." 
+ ) return None @property diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index b0affa2b1a0..99aba539cd9 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -38,6 +38,7 @@ import requests # relative +from ..serde.serialize import _serialize as serialize from .logger import critical from .logger import debug from .logger import error @@ -92,6 +93,10 @@ def get_mb_size(data: Any) -> float: return sys.getsizeof(data) / (1024 * 1024) +def get_mb_serialized_size(data: Any) -> float: + return sys.getsizeof(serialize(data)) / (1024 * 1024) + + def extract_name(klass: type) -> str: name_regex = r".+class.+?([\w\._]+).+" regex2 = r"([\w\.]+)" From a6334ab8fc3ee12c9935e4d020df3e166c362505 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Tue, 11 Jun 2024 17:27:51 +0700 Subject: [PATCH 096/309] [syft/action_obj] set `syft_action_data_cache` for `ActionObject` if data is less than the min size to save to blob storage --- packages/syft/src/syft/service/action/action_object.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 34673bbd224..e77e3498cc6 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -841,8 +841,6 @@ def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) - if not TraceResultRegistry.current_thread_is_tracing(): - self.syft_action_data_cache = self.as_empty_data() action_data_size_mb: float = get_mb_serialized_size(data) if action_data_size_mb > min_size_mb: result = self._save_to_blob_storage_(data) @@ -853,6 +851,9 @@ def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), less than {min_size_mb} (MB). " f"Skip saving to blob storage." 
) + self.syft_action_data_cache = data + # if not TraceResultRegistry.current_thread_is_tracing(): + # self.syft_action_data_cache = self.as_empty_data() return None @property From 0db775f4969c02eb126543fc83d459d6f750dc06 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 11 Jun 2024 15:58:42 +0530 Subject: [PATCH 097/309] ignore prefix for rathole in blob store path --- packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 748dfee78c4..705c153442a 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -43,7 +43,7 @@ data: - "web" service: "backend" blob-storage: - rule: "PathPrefix(`/blob`)" + rule: "PathPrefix(`/blob`) && !PathPrefix(`/rathole`)" entryPoints: - "web" service: "seaweedfs" From 7335c567cfd6a63da15fa3eaec740d2a3153d788 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 12 Jun 2024 09:13:07 +0700 Subject: [PATCH 098/309] [syft/action_obj] remove redundant setting syft_action_data_cache to ActionDataEmpty when current thread is not tracing --- packages/syft/src/syft/service/action/action_object.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index e77e3498cc6..a9e68d68da6 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -785,8 +785,8 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: size = sys.getsizeof(serialized) storage_entry = CreateBlobStorageEntry.from_obj(data, file_size=size) - if not TraceResultRegistry.current_thread_is_tracing(): - self.syft_action_data_cache = self.as_empty_data() + # if not TraceResultRegistry.current_thread_is_tracing(): + # self.syft_action_data_cache = self.as_empty_data() if self.syft_blob_storage_entry_id is not None: # TODO: check if it already exists storage_entry.id = self.syft_blob_storage_entry_id @@ -846,14 +846,14 @@ def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: result = self._save_to_blob_storage_(data) if isinstance(result, SyftError): return result + if not TraceResultRegistry.current_thread_is_tracing(): + self.syft_action_data_cache = self.as_empty_data() else: debug( f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), less than {min_size_mb} (MB). " f"Skip saving to blob storage." 
) self.syft_action_data_cache = data - # if not TraceResultRegistry.current_thread_is_tracing(): - # self.syft_action_data_cache = self.as_empty_data() return None @property From 1505f22bccc5c8098361672d861b715b5224e078 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 12 Jun 2024 09:51:52 +0700 Subject: [PATCH 099/309] [syft/blob_storage] print out on-disk blob store path in dev mode --- packages/syft/src/syft/node/node.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 1e2c00c6f24..65a2cad07f3 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -497,6 +497,15 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: remote_profile.profile_name ] = remote_profile + if ( + isinstance(self.blob_store_config, OnDiskBlobStorageConfig) + and self.dev_mode + ): + print( + f"Using on-disk blob storage with path: " + f"{self.blob_store_config.client_config.base_directory}" + ) + def run_peer_health_checks(self, context: AuthedServiceContext) -> None: self.peer_health_manager = PeerHealthCheckTask() self.peer_health_manager.run(context=context) From 94793e5942b7339658e0b7cffd8c6e49691ebf9b Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 12 Jun 2024 11:00:51 +0700 Subject: [PATCH 100/309] [syft/blob_storage] pass obj's min size to save to the blob storage in an env varialbe - print out this min size in dev mode when launching nodes --- packages/grid/backend/grid/core/config.py | 1 + packages/grid/backend/grid/core/node.py | 1 + packages/grid/default.env | 1 + packages/syft/src/syft/node/node.py | 30 ++++++++++++------- packages/syft/src/syft/orchestra.py | 4 +++ .../src/syft/service/action/action_object.py | 4 +-- .../src/syft/store/blob_storage/__init__.py | 1 + 7 files changed, 29 insertions(+), 13 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 8c55b8cd3f7..86f37533977 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,6 +155,7 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) + MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index cde36f8c5fe..45e33bfa669 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -105,5 +105,6 @@ def seaweedfs_config() -> SeaweedFSConfig: smtp_port=settings.SMTP_PORT, smtp_host=settings.SMTP_HOST, association_request_auto_approval=settings.ASSOCIATION_REQUEST_AUTO_APPROVAL, + min_size_blob_storage_mb=settings.MIN_SIZE_BLOB_STORAGE_MB, background_tasks=True, ) diff --git a/packages/grid/default.env b/packages/grid/default.env index 6ae9748bfef..755f53317d1 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -54,6 +54,7 @@ CREATE_PRODUCER=False N_CONSUMERS=1 INMEMORY_WORKERS=True ASSOCIATION_REQUEST_AUTO_APPROVAL=False +MIN_SIZE_BLOB_STORAGE_MB=16 # New Service Flag USE_NEW_SERVICE=False diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 65a2cad07f3..e11886a31de 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -345,6 +345,7 @@ def __init__( smtp_port: int | None 
= None, smtp_host: str | None = None, association_request_auto_approval: bool = False, + min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this @@ -432,6 +433,7 @@ def __init__( self.init_queue_manager(queue_config=self.queue_config) + self.min_size_blob_storage_mb = min_size_blob_storage_mb self.init_blob_storage(config=blob_storage_config) context = AuthedServiceContext( @@ -464,7 +466,7 @@ def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig: path = self.get_temp_dir("db") file_name: str = f"{self.id}.sqlite" if self.dev_mode: - print(f"{store_type}'s SQLite DB path: {path/file_name}") + print(f"{store_type}'s SQLite DB path: {path/file_name}.") return SQLiteStoreConfig( client_config=SQLiteStoreClientConfig( filename=file_name, @@ -478,7 +480,9 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: client_config = OnDiskBlobStorageClientConfig( base_directory=self.get_temp_dir("blob") ) - config_ = OnDiskBlobStorageConfig(client_config=client_config) + config_ = OnDiskBlobStorageConfig( + client_config=client_config, min_size_mb=self.min_size_blob_storage_mb + ) else: config_ = config self.blob_store_config = config_ @@ -497,13 +501,16 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: remote_profile.profile_name ] = remote_profile - if ( - isinstance(self.blob_store_config, OnDiskBlobStorageConfig) - and self.dev_mode - ): + if self.dev_mode: + if isinstance(self.blob_store_config, OnDiskBlobStorageConfig): + print( + f"Using on-disk blob storage with path: " + f"{self.blob_store_config.client_config.base_directory}", + end=". ", + ) print( - f"Using on-disk blob storage with path: " - f"{self.blob_store_config.client_config.base_directory}" + f"Minimum object size to be saved to the blob storage: " + f"{self.blob_store_config.min_size_mb} (MB)." ) def run_peer_health_checks(self, context: AuthedServiceContext) -> None: @@ -1754,7 +1761,7 @@ def create_default_worker_pool(node: Node) -> SyftError | None: ) return default_worker_pool - print(f"Creating default worker image with tag='{default_worker_tag}'") + print(f"Creating default worker image with tag='{default_worker_tag}'", end=". ") # Get/Create a default worker SyftWorkerImage default_image = create_default_image( credentials=credentials, @@ -1767,7 +1774,7 @@ def create_default_worker_pool(node: Node) -> SyftError | None: return default_image if not default_image.is_built: - print(f"Building default worker image with tag={default_worker_tag}") + print(f"Building default worker image with tag={default_worker_tag}", end=". ") image_build_method = node.get_service_method(SyftWorkerImageService.build) # Build the Image for given tag result = image_build_method( @@ -1787,7 +1794,8 @@ def create_default_worker_pool(node: Node) -> SyftError | None: f"name={default_pool_name} " f"workers={worker_count} " f"image_uid={default_image.id} " - f"in_memory={node.in_memory_workers}" + f"in_memory={node.in_memory_workers}", + end=". 
", ) if default_worker_pool is None: worker_to_add_ = worker_count diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 1a08f594aa2..6f34a872099 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -165,6 +165,7 @@ def deploy_to_python( create_producer: bool = False, queue_port: int | None = None, association_request_auto_approval: bool = False, + min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ) -> NodeHandle: worker_classes = { @@ -192,6 +193,7 @@ def deploy_to_python( "n_consumers": n_consumers, "create_producer": create_producer, "association_request_auto_approval": association_request_auto_approval, + "min_size_blob_storage_mb": min_size_blob_storage_mb, "background_tasks": background_tasks, } @@ -281,6 +283,7 @@ def launch( create_producer: bool = False, queue_port: int | None = None, association_request_auto_approval: bool = False, + min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ) -> NodeHandle: if dev_mode is True: @@ -317,6 +320,7 @@ def launch( create_producer=create_producer, queue_port=queue_port, association_request_auto_approval=association_request_auto_approval, + min_size_blob_storage_mb=min_size_blob_storage_mb, background_tasks=background_tasks, ) elif deployment_type_enum == DeploymentType.REMOTE: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index a9e68d68da6..981809df47e 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -850,8 +850,8 @@ def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: self.syft_action_data_cache = self.as_empty_data() else: debug( - f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), less than {min_size_mb} (MB). " - f"Skip saving to blob storage." + f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), " + f"less than {min_size_mb} (MB). Skip saving to blob storage." 
) self.syft_action_data_cache = data return None diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index 663660e777c..0f52ebed642 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -273,6 +273,7 @@ def connect(self) -> BlobStorageConnection: class BlobStorageConfig(SyftBaseModel): client_type: type[BlobStorageClient] client_config: BlobStorageClientConfig + min_size_mb: int @migrate(BlobRetrievalByURLV4, BlobRetrievalByURL) From 44fee3adc36638fddd38b0354343371f16f70bfb Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 12 Jun 2024 16:27:16 +0700 Subject: [PATCH 101/309] [syft/blob_storage] stop passing min size to upload obj to blob storage to node and `BlobStorageConfig` - only pass the min size as an env varialbe so it can be read by both the client / server and can easily be configured by the client Co-authored-by: Shubham Gupta --- packages/grid/backend/grid/core/config.py | 1 - packages/grid/backend/grid/core/node.py | 1 - .../backend/backend-statefulset.yaml | 2 ++ packages/grid/helm/syft/values.yaml | 2 ++ packages/syft/src/syft/node/node.py | 11 ++++---- packages/syft/src/syft/orchestra.py | 4 --- .../src/syft/service/action/action_object.py | 27 ++++++++----------- .../src/syft/store/blob_storage/__init__.py | 1 - packages/syft/src/syft/util/util.py | 11 ++++++++ 9 files changed, 31 insertions(+), 29 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 86f37533977..8c55b8cd3f7 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,7 +155,6 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) - MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index 45e33bfa669..cde36f8c5fe 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -105,6 +105,5 @@ def seaweedfs_config() -> SeaweedFSConfig: smtp_port=settings.SMTP_PORT, smtp_host=settings.SMTP_HOST, association_request_auto_approval=settings.ASSOCIATION_REQUEST_AUTO_APPROVAL, - min_size_blob_storage_mb=settings.MIN_SIZE_BLOB_STORAGE_MB, background_tasks=True, ) diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 106d2fee893..13c5ef82523 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -122,6 +122,8 @@ spec: name: {{ .Values.seaweedfs.secretKeyName | required "seaweedfs.secretKeyName is required" }} key: s3RootPassword {{- end }} + - name: MIN_SIZE_BLOB_STORAGE_MB + value: {{ .Values.seaweedfs.minSizeBlobStorageMB | quote }} # Tracing - name: TRACE value: "false" diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 2644eac26e4..c2d13fd3cbe 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -66,6 +66,8 @@ seaweedfs: s3: rootUser: admin + minSizeBlobStorageMB: 16 + # Mount API mountApi: # automount: diff --git a/packages/syft/src/syft/node/node.py 
b/packages/syft/src/syft/node/node.py index e11886a31de..fb70f71d23b 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -345,7 +345,6 @@ def __init__( smtp_port: int | None = None, smtp_host: str | None = None, association_request_auto_approval: bool = False, - min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this @@ -433,7 +432,6 @@ def __init__( self.init_queue_manager(queue_config=self.queue_config) - self.min_size_blob_storage_mb = min_size_blob_storage_mb self.init_blob_storage(config=blob_storage_config) context = AuthedServiceContext( @@ -480,9 +478,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: client_config = OnDiskBlobStorageClientConfig( base_directory=self.get_temp_dir("blob") ) - config_ = OnDiskBlobStorageConfig( - client_config=client_config, min_size_mb=self.min_size_blob_storage_mb - ) + config_ = OnDiskBlobStorageConfig(client_config=client_config) else: config_ = config self.blob_store_config = config_ @@ -502,6 +498,9 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ] = remote_profile if self.dev_mode: + # relative + from ..util.util import min_size_for_blob_storage_upload + if isinstance(self.blob_store_config, OnDiskBlobStorageConfig): print( f"Using on-disk blob storage with path: " @@ -510,7 +509,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ) print( f"Minimum object size to be saved to the blob storage: " - f"{self.blob_store_config.min_size_mb} (MB)." + f"{min_size_for_blob_storage_upload()} (MB)." ) def run_peer_health_checks(self, context: AuthedServiceContext) -> None: diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 6f34a872099..1a08f594aa2 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -165,7 +165,6 @@ def deploy_to_python( create_producer: bool = False, queue_port: int | None = None, association_request_auto_approval: bool = False, - min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ) -> NodeHandle: worker_classes = { @@ -193,7 +192,6 @@ def deploy_to_python( "n_consumers": n_consumers, "create_producer": create_producer, "association_request_auto_approval": association_request_auto_approval, - "min_size_blob_storage_mb": min_size_blob_storage_mb, "background_tasks": background_tasks, } @@ -283,7 +281,6 @@ def launch( create_producer: bool = False, queue_port: int | None = None, association_request_auto_approval: bool = False, - min_size_blob_storage_mb: int = 16, background_tasks: bool = False, ) -> NodeHandle: if dev_mode is True: @@ -320,7 +317,6 @@ def launch( create_producer=create_producer, queue_port=queue_port, association_request_auto_approval=association_request_auto_approval, - min_size_blob_storage_mb=min_size_blob_storage_mb, background_tasks=background_tasks, ) elif deployment_type_enum == DeploymentType.REMOTE: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 981809df47e..41fe70f4b95 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -47,7 +47,7 @@ from ...types.uid import LineageID from ...types.uid import UID from ...util.logger import debug -from ...util.util import get_mb_serialized_size +from ...util.util import 
can_upload_to_blob_storage from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..response import SyftException @@ -785,8 +785,8 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: size = sys.getsizeof(serialized) storage_entry = CreateBlobStorageEntry.from_obj(data, file_size=size) - # if not TraceResultRegistry.current_thread_is_tracing(): - # self.syft_action_data_cache = self.as_empty_data() + if not TraceResultRegistry.current_thread_is_tracing(): + self.syft_action_data_cache = self.as_empty_data() if self.syft_blob_storage_entry_id is not None: # TODO: check if it already exists storage_entry.id = self.syft_blob_storage_entry_id @@ -829,7 +829,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: return None - def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: + def _save_to_blob_storage(self) -> SyftError | None: """ " If less than min_size_mb, skip saving to blob storage TODO: min_size_mb shoulb be passed as a env var @@ -841,19 +841,14 @@ def _save_to_blob_storage(self, min_size_mb: int = 16) -> SyftError | None: return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) - action_data_size_mb: float = get_mb_serialized_size(data) - if action_data_size_mb > min_size_mb: - result = self._save_to_blob_storage_(data) - if isinstance(result, SyftError): - return result - if not TraceResultRegistry.current_thread_is_tracing(): - self.syft_action_data_cache = self.as_empty_data() - else: - debug( - f"self.syft_action_data's size = {action_data_size_mb:4f} (MB), " - f"less than {min_size_mb} (MB). Skip saving to blob storage." - ) + if not can_upload_to_blob_storage(data): self.syft_action_data_cache = data + return None + result = self._save_to_blob_storage_(data) + if isinstance(result, SyftError): + return result + if not TraceResultRegistry.current_thread_is_tracing(): + self.syft_action_data_cache = self.as_empty_data() return None @property diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index 0f52ebed642..663660e777c 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -273,7 +273,6 @@ def connect(self) -> BlobStorageConnection: class BlobStorageConfig(SyftBaseModel): client_type: type[BlobStorageClient] client_config: BlobStorageClientConfig - min_size_mb: int @migrate(BlobRetrievalByURLV4, BlobRetrievalByURL) diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 99aba539cd9..0785e01f8c3 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -97,6 +97,17 @@ def get_mb_serialized_size(data: Any) -> float: return sys.getsizeof(serialize(data)) / (1024 * 1024) +def min_size_for_blob_storage_upload() -> int: + """ + Return the minimum size in MB for a blob storage upload. 
Default to 16 MB for now + """ + return int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) + + +def can_upload_to_blob_storage(data: Any) -> bool: + return get_mb_size(data) >= min_size_for_blob_storage_upload() + + def extract_name(klass: type) -> str: name_regex = r".+class.+?([\w\._]+).+" regex2 = r"([\w\.]+)" From df3c70340bc455231bda3993cb2bde0dfeeae932 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 12 Jun 2024 10:39:19 -0400 Subject: [PATCH 102/309] fix pip install commands to add uv and test import succeeds in syft function --- notebooks/admin/Custom API + Custom Worker.ipynb | 2 +- notebooks/api/0.8/10-container-images.ipynb | 10 +++++----- notebooks/api/0.8/11-container-images-k8s.ipynb | 7 ++++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/notebooks/admin/Custom API + Custom Worker.ipynb b/notebooks/admin/Custom API + Custom Worker.ipynb index 94c99a2873e..d50c1f1b4f2 100644 --- a/notebooks/admin/Custom API + Custom Worker.ipynb +++ b/notebooks/admin/Custom API + Custom Worker.ipynb @@ -91,7 +91,7 @@ "custom_dockerfile_str = f\"\"\"\n", "FROM {registry}/openmined/grid-backend:{backend_version}\n", "\n", - "RUN pip install google-cloud-bigquery[all]==3.20.1 db-dtypes==1.2.0\n", + "RUN uv pip install google-cloud-bigquery[all]==3.20.1 db-dtypes==1.2.0\n", "\n", "\"\"\".strip()\n", "print(custom_dockerfile_str)" diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index b0bacf0295f..c91b625566b 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -171,7 +171,7 @@ "opendp_dockerfile_str = f\"\"\"\n", "FROM openmined/grid-backend:{syft_base_worker_tag}\n", "\n", - "RUN pip install opendp\n", + "RUN uv pip install opendp\n", "\n", "\"\"\".strip()\n", "\n", @@ -849,7 +849,7 @@ ")\n", "def custom_worker_func(x):\n", " # third party\n", - "\n", + " import opendp\n", " return {\"y\": x + 1}" ] }, @@ -1082,7 +1082,7 @@ "custom_dockerfile_str_2 = f\"\"\"\n", "FROM openmined/grid-backend:{syft_base_worker_tag}\n", "\n", - "RUN pip install opendp\n", + "RUN uv pip install opendp\n", "\"\"\".strip()\n", "\n", "docker_config_2 = sy.DockerWorkerConfig(dockerfile=custom_dockerfile_str_2)" @@ -1298,7 +1298,7 @@ "custom_dockerfile_str_3 = f\"\"\"\n", "FROM openmined/grid-backend:{syft_base_worker_tag}\n", "\n", - "RUN pip install recordlinkage\n", + "RUN uv pip install recordlinkage\n", "\"\"\".strip()\n", "\n", "docker_config_3 = sy.DockerWorkerConfig(dockerfile=custom_dockerfile_str_3)\n", @@ -1320,7 +1320,7 @@ " num_workers=2,\n", " tag=docker_tag_3,\n", " config=docker_config_3,\n", - " reason=\"I want to do some more cool data science with PySyft and OpenDP\",\n", + " reason=\"I want to do some more cool data science with PySyft and recordlinkage\",\n", " pull_image=pull,\n", " )\n", ")\n", diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb index c685e6f8d7d..e8f30ce2440 100644 --- a/notebooks/api/0.8/11-container-images-k8s.ipynb +++ b/notebooks/api/0.8/11-container-images-k8s.ipynb @@ -266,7 +266,7 @@ "custom_dockerfile_str = f\"\"\"\n", "FROM {registry}/{repo}:{tag}\n", "\n", - "RUN pip install pydicom\n", + "RUN uv pip install pydicom\n", "\n", "\"\"\".strip()" ] @@ -812,6 +812,7 @@ " worker_pool_name=worker_pool_name,\n", ")\n", "def custom_worker_func(x):\n", + " import pydicom\n", " return {\"y\": x + 1}" ] }, @@ -985,7 +986,7 @@ "dockerfile_opendp = f\"\"\"\n", "FROM {registry}/{repo}:{tag}\n", 
"\n", - "RUN pip install opendp\n", + "RUN uv pip install opendp\n", "\"\"\".strip()\n", "\n", "docker_config_opendp = sy.DockerWorkerConfig(dockerfile=dockerfile_opendp)" @@ -1280,7 +1281,7 @@ "dockerfile_recordlinkage = f\"\"\"\n", "FROM {registry}/{repo}:{tag}\n", "\n", - "RUN pip install recordlinkage\n", + "RUN uv pip install recordlinkage\n", "\"\"\".strip()\n", "\n", "docker_config_recordlinkage = sy.DockerWorkerConfig(dockerfile=dockerfile_recordlinkage)\n", From e46e0df6d7b6ba66bcd5afc9bb05987a4c1883ef Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 12 Jun 2024 10:56:13 -0400 Subject: [PATCH 103/309] add version print to avoid unused import --- notebooks/api/0.8/10-container-images.ipynb | 1 + notebooks/api/0.8/11-container-images-k8s.ipynb | 9 +++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index c91b625566b..ee302227110 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -850,6 +850,7 @@ "def custom_worker_func(x):\n", " # third party\n", " import opendp\n", + " print(opendp.__version__)\n", " return {\"y\": x + 1}" ] }, diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb index e8f30ce2440..dc6198ccd2f 100644 --- a/notebooks/api/0.8/11-container-images-k8s.ipynb +++ b/notebooks/api/0.8/11-container-images-k8s.ipynb @@ -813,6 +813,7 @@ ")\n", "def custom_worker_func(x):\n", " import pydicom\n", + " print(pydicom.__version__)\n", " return {\"y\": x + 1}" ] }, @@ -1272,11 +1273,9 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "101", "metadata": {}, - "outputs": [], "source": [ "dockerfile_recordlinkage = f\"\"\"\n", "FROM {registry}/{repo}:{tag}\n", @@ -1290,11 +1289,9 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "102", "metadata": {}, - "outputs": [], "source": [ "pool_name_recordlinkage = \"recordlinkage-pool\"\n", "recordlinkage_pod_annotations = {\n", From ba7bb8fe1d7b2fc26b6c7988483359bc0569cc14 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 12 Jun 2024 16:28:44 -0400 Subject: [PATCH 104/309] fix linting --- notebooks/api/0.8/10-container-images.ipynb | 1 + notebooks/api/0.8/11-container-images-k8s.ipynb | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index ee302227110..a97b5fcde2c 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -850,6 +850,7 @@ "def custom_worker_func(x):\n", " # third party\n", " import opendp\n", + "\n", " print(opendp.__version__)\n", " return {\"y\": x + 1}" ] diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb index dc6198ccd2f..64a856cae3e 100644 --- a/notebooks/api/0.8/11-container-images-k8s.ipynb +++ b/notebooks/api/0.8/11-container-images-k8s.ipynb @@ -812,7 +812,9 @@ " worker_pool_name=worker_pool_name,\n", ")\n", "def custom_worker_func(x):\n", + " # third party\n", " import pydicom\n", + "\n", " print(pydicom.__version__)\n", " return {\"y\": x + 1}" ] @@ -1273,9 +1275,11 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "101", "metadata": {}, + "outputs": [], "source": [ "dockerfile_recordlinkage = f\"\"\"\n", "FROM 
{registry}/{repo}:{tag}\n", @@ -1289,9 +1293,11 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "102", "metadata": {}, + "outputs": [], "source": [ "pool_name_recordlinkage = \"recordlinkage-pool\"\n", "recordlinkage_pod_annotations = {\n", From dbd079e09b94c61bdb6478f0da27b63dc9342946 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 12 Jun 2024 17:56:57 -0400 Subject: [PATCH 105/309] remove function third party import in notebook 10 since not installed in thread worker --- notebooks/api/0.8/10-container-images.ipynb | 2 -- 1 file changed, 2 deletions(-) diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index a97b5fcde2c..25f620f476c 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -849,9 +849,7 @@ ")\n", "def custom_worker_func(x):\n", " # third party\n", - " import opendp\n", "\n", - " print(opendp.__version__)\n", " return {\"y\": x + 1}" ] }, From 495357b62b3972ff8cb98a607a27be80894fec3a Mon Sep 17 00:00:00 2001 From: dk Date: Thu, 13 Jun 2024 11:13:30 +0700 Subject: [PATCH 106/309] [syft/test] test action object's saving to blob storage behavior when using `.send` --- .../src/syft/service/action/action_object.py | 4 --- .../syft/blob_storage/blob_storage_test.py | 36 +++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 41fe70f4b95..52f9ab0f097 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -830,10 +830,6 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: return None def _save_to_blob_storage(self) -> SyftError | None: - """ " - If less than min_size_mb, skip saving to blob storage - TODO: min_size_mb shoulb be passed as a env var - """ data = self.syft_action_data if isinstance(data, SyftError): return data diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 11942815529..3efd8fb5e18 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -1,18 +1,23 @@ # stdlib import io +import os import random # third party +import numpy as np import pytest # syft absolute import syft as sy +from syft import ActionObject +from syft.client.domain_client import DomainClient from syft.service.context import AuthedServiceContext from syft.service.response import SyftSuccess from syft.service.user.user import UserCreate from syft.store.blob_storage import BlobDeposit from syft.store.blob_storage import SyftObjectRetrieval from syft.types.blob_storage import CreateBlobStorageEntry +from syft.util.util import min_size_for_blob_storage_upload raw_data = {"test": "test"} data = sy.serialize(raw_data, to_bytes=True) @@ -99,3 +104,34 @@ def test_blob_storage_delete(authed_context, blob_storage): with pytest.raises(FileNotFoundError): blob_storage.read(authed_context, blob_deposit.blob_storage_entry_id) + + +def test_action_obj_send_save_to_blob_storage(worker): + # set this so we will always save action objects to blob storage + os.environ["MIN_SIZE_BLOB_STORAGE_MB"] = "0" + + orig_obj: np.ndarray = np.array([1, 2, 3]) + action_obj = ActionObject.from_obj(orig_obj) + assert action_obj.dtype == orig_obj.dtype + + root_client: 
DomainClient = worker.root_client + action_obj.send(root_client) + assert isinstance(action_obj.syft_blob_storage_entry_id, sy.UID) + root_authed_ctx = AuthedServiceContext( + node=worker, credentials=root_client.verify_key + ) + + blob_storage = worker.get_service("BlobStorageService") + syft_retrieved_data = blob_storage.read( + root_authed_ctx, action_obj.syft_blob_storage_entry_id + ) + assert isinstance(syft_retrieved_data, SyftObjectRetrieval) + assert all(syft_retrieved_data.read() == orig_obj) + + # stop saving small action objects to blob storage + del os.environ["MIN_SIZE_BLOB_STORAGE_MB"] + assert min_size_for_blob_storage_upload() == 16 + orig_obj_2: np.ndarray = np.array([1, 2, 4]) + action_obj_2 = ActionObject.from_obj(orig_obj_2) + action_obj_2.send(root_client) + assert action_obj_2.syft_blob_storage_entry_id is None From 7410732e0c5e268b4da09229b3a9b2098e0b0ce7 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 3 Jun 2024 13:57:49 +0800 Subject: [PATCH 107/309] Remove redundant code --- packages/syft/src/syft/service/queue/zmq_queue.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 6cf4cf3794f..fc47098555c 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -109,8 +109,6 @@ class Worker(SyftBaseModel): syft_worker_id: UID | None = None expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @field_validator("syft_worker_id", mode="before") @classmethod def set_syft_worker_id(cls, v: Any) -> Any: @@ -366,7 +364,7 @@ def purge_workers(self) -> None: Workers are oldest to most recent, so we stop at the first alive worker. 
""" # work on a copy of the iterator - for worker in list(self.waiting): + for worker in self.waiting: if worker.has_expired(): logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", From f77a124050d862608e08d8940193d7bd67750b64 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 3 Jun 2024 13:56:15 +0800 Subject: [PATCH 108/309] Mark worker requested to be deleted as _to_be_deleted --- packages/syft/src/syft/service/queue/zmq_queue.py | 12 +++++++++++- packages/syft/src/syft/service/worker/worker_pool.py | 1 + .../syft/src/syft/service/worker/worker_service.py | 2 ++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index fc47098555c..675e4657295 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -11,6 +11,7 @@ # third party from loguru import logger from pydantic import field_validator +from result import Result import zmq from zmq import Frame from zmq import LINGER @@ -31,6 +32,7 @@ from ..response import SyftSuccess from ..service import AbstractService from ..worker.worker_pool import ConsumerState +from ..worker.worker_pool import SyftWorker from ..worker.worker_stash import WorkerStash from .base_queue import AbstractMessageHandler from .base_queue import QueueClient @@ -125,6 +127,14 @@ def get_expiry(self) -> float: def reset_expiry(self) -> None: self.expiry_t.reset() + def _syft_worker(self, stash: WorkerStash) -> Result[SyftWorker | None, str]: + return stash.get_by_uid(self.syft_worker_id) + + def _to_be_deleted(self, stash: WorkerStash) -> bool: + return self._syft_worker(stash).map_or( + False, lambda x: x is not None and x._to_be_deleted + ) + @serializable() class ZMQProducer(QueueProducer): @@ -365,7 +375,7 @@ def purge_workers(self) -> None: """ # work on a copy of the iterator for worker in self.waiting: - if worker.has_expired(): + if worker.has_expired() or worker._to_be_deleted(self.worker_stash): logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", worker.identity, diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index 461c4f61b7d..c591e494431 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -73,6 +73,7 @@ class SyftWorker(SyftObject): worker_pool_name: str consumer_state: ConsumerState = ConsumerState.DETACHED job_id: UID | None = None + _to_be_deleted: bool = False @property def logs(self) -> str | SyftError: diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 6c574b99735..e8fe18e4bb0 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -161,6 +161,8 @@ def delete( if isinstance(worker, SyftError): return worker + worker._to_be_deleted = True + worker_pool_name = worker.worker_pool_name # relative From 9499eecd7cb637224dfc168ae7c3d4ea55685dd6 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 14:30:18 +0800 Subject: [PATCH 109/309] Delete worker in queue --- .../syft/src/syft/service/queue/zmq_queue.py | 29 +++++++++++--- .../src/syft/service/worker/worker_service.py | 38 +++++++++++-------- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py 
b/packages/syft/src/syft/service/queue/zmq_queue.py index 675e4657295..62d9d4a8925 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -7,6 +7,7 @@ import time from time import sleep from typing import Any +from typing import cast # third party from loguru import logger @@ -18,6 +19,7 @@ from zmq.error import ContextTerminated # relative +from ...node.credentials import SyftVerifyKey from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize @@ -127,11 +129,13 @@ def get_expiry(self) -> float: def reset_expiry(self) -> None: self.expiry_t.reset() - def _syft_worker(self, stash: WorkerStash) -> Result[SyftWorker | None, str]: - return stash.get_by_uid(self.syft_worker_id) + def _syft_worker( + self, stash: WorkerStash, credentials: SyftVerifyKey + ) -> Result[SyftWorker | None, str]: + return stash.get_by_uid(credentials == credentials, uid=self.syft_worker_id) - def _to_be_deleted(self, stash: WorkerStash) -> bool: - return self._syft_worker(stash).map_or( + def _to_be_deleted(self, stash: WorkerStash, credentials: SyftVerifyKey) -> bool: + return self._syft_worker(stash, credentials).map_or( False, lambda x: x is not None and x._to_be_deleted ) @@ -375,7 +379,9 @@ def purge_workers(self) -> None: """ # work on a copy of the iterator for worker in self.waiting: - if worker.has_expired() or worker._to_be_deleted(self.worker_stash): + if worker.has_expired() or worker._to_be_deleted( + self.worker_stash, self.auth_context.syft_client_verify_key + ): logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", worker.identity, @@ -578,6 +584,19 @@ def process_worker(self, address: bytes, msg: list[bytes]) -> None: def delete_worker(self, worker: Worker, disconnect: bool) -> None: """Deletes worker from all data structures, and deletes worker.""" + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.node.get_service("WorkerService") + ) + worker_service._delete( + self.auth_context, + worker._syft_worker( + self.worker_stash, self.auth_context.syft_client_verify_key + ), + ) + if disconnect: self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT, None, None) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index e8fe18e4bb0..9c64e0f403d 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -146,23 +146,12 @@ def logs( return logs if raw else logs.decode(errors="ignore") - @service_method( - path="worker.delete", - name="delete", - roles=DATA_OWNER_ROLE_LEVEL, - ) - def delete( + def _delete( self, context: AuthedServiceContext, - uid: UID, - force: bool = False, + worker: SyftWorker, ) -> SyftSuccess | SyftError: - worker = self._get_worker(context=context, uid=uid) - if isinstance(worker, SyftError): - return worker - - worker._to_be_deleted = True - + uid = SyftWorker.id worker_pool_name = worker.worker_pool_name # relative @@ -207,7 +196,7 @@ def delete( if isinstance(docker_container, SyftError): return docker_container - stopped = _stop_worker_container(worker, docker_container, force) + stopped = _stop_worker_container(worker, docker_container, force=True) if stopped is not None: return stopped else: @@ -237,6 +226,25 @@ def delete( message=f"Worker with id: {uid} deleted successfully 
from pool: {worker_pool.name}" ) + @service_method( + path="worker.delete", + name="delete", + roles=DATA_OWNER_ROLE_LEVEL, + ) + def delete( + self, + context: AuthedServiceContext, + uid: UID, + force: bool = False, + ) -> SyftSuccess | SyftError: + worker = self._get_worker(context=context, uid=uid) + + if not force: + worker._to_be_deleted = True + return SyftSuccess(f"Worker {uid} has been marked for deletion.") + + return self._delete(context, worker) + def _get_worker( self, context: AuthedServiceContext, uid: UID ) -> SyftWorker | SyftError: From 5e47cac99790e15e426321fb9b27c9a52e71799c Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 14:59:04 +0800 Subject: [PATCH 110/309] Kill associated job when deleting worker --- packages/syft/src/syft/service/worker/worker_service.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 9c64e0f403d..6b8a49d934a 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -155,8 +155,12 @@ def _delete( worker_pool_name = worker.worker_pool_name # relative + from ...service.job.job_service import JobService from .worker_pool_service import SyftWorkerPoolService + job_service = cast(JobService, context.node.get_service("JobService")) + job_service._kill(job_service.get(context=context, uid=worker.job_id)) + worker_pool_service: AbstractService = context.node.get_service( SyftWorkerPoolService ) From 29ccd97bcc5cb1d4130b19f5b5c559d3ecdf4acc Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 15:00:28 +0800 Subject: [PATCH 111/309] Get service by class instead of name --- packages/syft/src/syft/service/queue/zmq_queue.py | 2 +- packages/syft/src/syft/service/worker/worker_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 62d9d4a8925..7ec18cbaafb 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -588,7 +588,7 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: from ...service.worker.worker_service import WorkerService worker_service = cast( - WorkerService, self.auth_context.node.get_service("WorkerService") + WorkerService, self.auth_context.node.get_service(WorkerService) ) worker_service._delete( self.auth_context, diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 6b8a49d934a..c7b3987ddd1 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -158,7 +158,7 @@ def _delete( from ...service.job.job_service import JobService from .worker_pool_service import SyftWorkerPoolService - job_service = cast(JobService, context.node.get_service("JobService")) + job_service = cast(JobService, context.node.get_service(JobService)) job_service._kill(job_service.get(context=context, uid=worker.job_id)) worker_pool_service: AbstractService = context.node.get_service( From 217c097820590c5f63cd222ae0622336d1155807 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 13 Jun 2024 17:50:53 +0700 Subject: [PATCH 112/309] [syft/blob_storage] pass `MIN_SIZE_BLOB_STORAGE_MB` env var into `Settings` and `BlobStorageConfig` which then is passed into a `Node` - `Node` then pass 
this information around by `NodeMetadata` / `NodeMetadataJSON` - Update utility methods to see if a client can upload an object to the blob storage - Modify test in `blob_storage_test` to reflect these changes Co-authored-by: Shubham Gupta --- packages/grid/backend/grid/core/config.py | 1 + packages/grid/backend/grid/core/node.py | 5 ++- packages/syft/src/syft/node/node.py | 11 ++--- .../src/syft/service/action/action_object.py | 5 ++- .../src/syft/service/blob_storage/util.py | 21 ++++++++++ .../syft/service/metadata/node_metadata.py | 2 + .../src/syft/store/blob_storage/__init__.py | 1 + packages/syft/src/syft/util/util.py | 13 +----- .../syft/blob_storage/blob_storage_test.py | 40 +++++++++---------- 9 files changed, 58 insertions(+), 41 deletions(-) create mode 100644 packages/syft/src/syft/service/blob_storage/util.py diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 8c55b8cd3f7..86f37533977 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,6 +155,7 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) + MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index cde36f8c5fe..926cbbc5556 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -66,7 +66,10 @@ def seaweedfs_config() -> SeaweedFSConfig: mount_port=settings.SEAWEED_MOUNT_PORT, ) - return SeaweedFSConfig(client_config=seaweed_client_config) + return SeaweedFSConfig( + client_config=seaweed_client_config, + min_blob_size=settings.MIN_SIZE_BLOB_STORAGE_MB, + ) node_type = NodeType(get_node_type()) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index fb70f71d23b..3d0c8a376f7 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -478,7 +478,10 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: client_config = OnDiskBlobStorageClientConfig( base_directory=self.get_temp_dir("blob") ) - config_ = OnDiskBlobStorageConfig(client_config=client_config) + config_ = OnDiskBlobStorageConfig( + client_config=client_config, + min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16), + ) else: config_ = config self.blob_store_config = config_ @@ -498,9 +501,6 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ] = remote_profile if self.dev_mode: - # relative - from ..util.util import min_size_for_blob_storage_upload - if isinstance(self.blob_store_config, OnDiskBlobStorageConfig): print( f"Using on-disk blob storage with path: " @@ -509,7 +509,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ) print( f"Minimum object size to be saved to the blob storage: " - f"{min_size_for_blob_storage_upload()} (MB)." + f"{self.blob_store_config.min_blob_size} (MB)." 
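Taken together, these changes turn the blob-storage threshold into an environment-driven setting: the server reads MIN_SIZE_BLOB_STORAGE_MB once (defaulting to 16 MB) and advertises it to clients through the node metadata, and the client compares the serialized payload size against that value before uploading. A rough, self-contained sketch of that check follows; the helper names and the `metadata` argument are illustrative stand-ins, not the exact symbols in the diff.

import os
import sys

import syft as sy


def min_size_for_blob_storage_upload_mb() -> int:
    # Server side: read the threshold from the environment, defaulting to 16 MB;
    # it is then surfaced to clients as metadata.min_size_blob_storage_mb.
    return int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16))


def should_upload_to_blob_storage(data, metadata) -> bool:
    # Client side: serialize the payload and compare its size in MB against
    # the threshold advertised by the node it is being sent to.
    size_mb = sys.getsizeof(sy.serialize(data, to_bytes=True)) / (1024 * 1024)
    return size_mb >= metadata.min_size_blob_storage_mb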
) def run_peer_health_checks(self, context: AuthedServiceContext) -> None: @@ -1083,6 +1083,7 @@ def metadata(self) -> NodeMetadata: node_side_type=node_side_type, show_warnings=show_warnings, eager_execution_enabled=eager_execution_enabled, + min_size_blob_storage_mb=self.blob_store_config.min_blob_size, ) @property diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 52f9ab0f097..8419ff8c3e8 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -35,6 +35,7 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize +from ...service.blob_storage.util import can_upload_to_blob_storage from ...service.response import SyftError from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel @@ -47,7 +48,6 @@ from ...types.uid import LineageID from ...types.uid import UID from ...util.logger import debug -from ...util.util import can_upload_to_blob_storage from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..response import SyftException @@ -837,7 +837,8 @@ def _save_to_blob_storage(self) -> SyftError | None: return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) - if not can_upload_to_blob_storage(data): + api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + if not can_upload_to_blob_storage(data, api.metadata): self.syft_action_data_cache = data return None result = self._save_to_blob_storage_(data) diff --git a/packages/syft/src/syft/service/blob_storage/util.py b/packages/syft/src/syft/service/blob_storage/util.py new file mode 100644 index 00000000000..11d62f5eb97 --- /dev/null +++ b/packages/syft/src/syft/service/blob_storage/util.py @@ -0,0 +1,21 @@ +# stdlib +from typing import Any + +# relative +from ...util.util import get_mb_serialized_size +from ..metadata.node_metadata import NodeMetadata +from ..metadata.node_metadata import NodeMetadataJSON + + +def min_size_for_blob_storage_upload(metadata: NodeMetadata | NodeMetadataJSON) -> int: + if not isinstance(metadata, (NodeMetadata | NodeMetadataJSON)): + raise ValueError( + f"argument `metadata` is type {type(metadata)}, but it should be of type NodeMetadata or NodeMetadataJSON" + ) + return metadata.min_size_blob_storage_mb + + +def can_upload_to_blob_storage( + data: Any, metadata: NodeMetadata | NodeMetadataJSON +) -> bool: + return get_mb_serialized_size(data) >= min_size_for_blob_storage_upload(metadata) diff --git a/packages/syft/src/syft/service/metadata/node_metadata.py b/packages/syft/src/syft/service/metadata/node_metadata.py index de60b90a412..6aaae3fe69c 100644 --- a/packages/syft/src/syft/service/metadata/node_metadata.py +++ b/packages/syft/src/syft/service/metadata/node_metadata.py @@ -60,6 +60,7 @@ class NodeMetadata(SyftObject): node_side_type: str show_warnings: bool eager_execution_enabled: bool + min_size_blob_storage_mb: int def check_version(self, client_version: str) -> bool: return check_version( @@ -112,6 +113,7 @@ class NodeMetadataJSON(BaseModel, StorableObjectType): node_side_type: str show_warnings: bool supported_protocols: list = [] + min_size_blob_storage_mb: int @model_validator(mode="before") @classmethod diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index 
663660e777c..8056be60207 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -273,6 +273,7 @@ def connect(self) -> BlobStorageConnection: class BlobStorageConfig(SyftBaseModel): client_type: type[BlobStorageClient] client_config: BlobStorageClientConfig + min_blob_size: int # in MB @migrate(BlobRetrievalByURLV4, BlobRetrievalByURL) diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 0785e01f8c3..7d13e348f67 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -94,18 +94,7 @@ def get_mb_size(data: Any) -> float: def get_mb_serialized_size(data: Any) -> float: - return sys.getsizeof(serialize(data)) / (1024 * 1024) - - -def min_size_for_blob_storage_upload() -> int: - """ - Return the minimum size in MB for a blob storage upload. Default to 16 MB for now - """ - return int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) - - -def can_upload_to_blob_storage(data: Any) -> bool: - return get_mb_size(data) >= min_size_for_blob_storage_upload() + return sys.getsizeof(serialize(data, to_bytes=True)) / (1024 * 1024) def extract_name(klass: type) -> str: diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 3efd8fb5e18..944b25628f2 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -1,6 +1,5 @@ # stdlib import io -import os import random # third party @@ -11,13 +10,14 @@ import syft as sy from syft import ActionObject from syft.client.domain_client import DomainClient +from syft.service.blob_storage.util import can_upload_to_blob_storage +from syft.service.blob_storage.util import min_size_for_blob_storage_upload from syft.service.context import AuthedServiceContext from syft.service.response import SyftSuccess from syft.service.user.user import UserCreate from syft.store.blob_storage import BlobDeposit from syft.store.blob_storage import SyftObjectRetrieval from syft.types.blob_storage import CreateBlobStorageEntry -from syft.util.util import min_size_for_blob_storage_upload raw_data = {"test": "test"} data = sy.serialize(raw_data, to_bytes=True) @@ -107,31 +107,29 @@ def test_blob_storage_delete(authed_context, blob_storage): def test_action_obj_send_save_to_blob_storage(worker): - # set this so we will always save action objects to blob storage - os.environ["MIN_SIZE_BLOB_STORAGE_MB"] = "0" - - orig_obj: np.ndarray = np.array([1, 2, 3]) - action_obj = ActionObject.from_obj(orig_obj) - assert action_obj.dtype == orig_obj.dtype - + # this small object should not be saved to blob storage + data_small: np.ndarray = np.array([1, 2, 3]) + action_obj = ActionObject.from_obj(data_small) + assert action_obj.dtype == data_small.dtype root_client: DomainClient = worker.root_client action_obj.send(root_client) - assert isinstance(action_obj.syft_blob_storage_entry_id, sy.UID) + assert action_obj.syft_blob_storage_entry_id is None + + # big object that should be saved to blob storage + assert min_size_for_blob_storage_upload(root_client.api.metadata) == 16 + num_elements = 50 * 1024 * 1024 + data_big = np.random.randint(0, 100, size=num_elements) # 4 bytes per int32 + action_obj_2 = ActionObject.from_obj(data_big) + assert can_upload_to_blob_storage(action_obj_2, root_client.api.metadata) + action_obj_2.send(root_client) + assert isinstance(action_obj_2.syft_blob_storage_entry_id, 
sy.UID) + # get back the object from blob storage to check if it is the same root_authed_ctx = AuthedServiceContext( node=worker, credentials=root_client.verify_key ) - blob_storage = worker.get_service("BlobStorageService") syft_retrieved_data = blob_storage.read( - root_authed_ctx, action_obj.syft_blob_storage_entry_id + root_authed_ctx, action_obj_2.syft_blob_storage_entry_id ) assert isinstance(syft_retrieved_data, SyftObjectRetrieval) - assert all(syft_retrieved_data.read() == orig_obj) - - # stop saving small action objects to blob storage - del os.environ["MIN_SIZE_BLOB_STORAGE_MB"] - assert min_size_for_blob_storage_upload() == 16 - orig_obj_2: np.ndarray = np.array([1, 2, 4]) - action_obj_2 = ActionObject.from_obj(orig_obj_2) - action_obj_2.send(root_client) - assert action_obj_2.syft_blob_storage_entry_id is None + assert all(syft_retrieved_data.read() == data_big) From fbcb4ea8a75b28de364644ba393c9bc6ae5a85f3 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 13 Jun 2024 20:46:04 +0700 Subject: [PATCH 113/309] [syft/tests] add `min_size_blob_storage_mb` to `NodeMetadataJSON` in fixture --- packages/syft/tests/syft/settings/fixtures.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/tests/syft/settings/fixtures.py b/packages/syft/tests/syft/settings/fixtures.py index e083b93bd95..65cbf18deca 100644 --- a/packages/syft/tests/syft/settings/fixtures.py +++ b/packages/syft/tests/syft/settings/fixtures.py @@ -66,6 +66,7 @@ def metadata_json(faker) -> NodeMetadataJSON: node_side_type=NodeSideType.LOW_SIDE.value, show_warnings=False, node_type=NodeType.DOMAIN.value, + min_size_blob_storage_mb=16, ) From eaefaffb99fde1bb8c635920308bd2892318213d Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 15:01:03 +0800 Subject: [PATCH 114/309] Fix type --- packages/syft/src/syft/service/worker/worker_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index c7b3987ddd1..f1f1464c6cc 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -161,8 +161,8 @@ def _delete( job_service = cast(JobService, context.node.get_service(JobService)) job_service._kill(job_service.get(context=context, uid=worker.job_id)) - worker_pool_service: AbstractService = context.node.get_service( - SyftWorkerPoolService + worker_pool_service = cast( + SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService) ) worker_pool_stash = worker_pool_service.stash result = worker_pool_stash.get_by_name( From 45c3d77df5671e06c3ebc04386d7c439f6be62fc Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 15:03:47 +0800 Subject: [PATCH 115/309] Use auth_context.credentials --- packages/syft/src/syft/service/queue/zmq_queue.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 7ec18cbaafb..4da1b758488 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -380,7 +380,7 @@ def purge_workers(self) -> None: # work on a copy of the iterator for worker in self.waiting: if worker.has_expired() or worker._to_be_deleted( - self.worker_stash, self.auth_context.syft_client_verify_key + self.worker_stash, self.auth_context.credentials ): logger.info( "Deleting expired Worker id={} uid={} 
expiry={} now={}", @@ -592,9 +592,7 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: ) worker_service._delete( self.auth_context, - worker._syft_worker( - self.worker_stash, self.auth_context.syft_client_verify_key - ), + worker._syft_worker(self.worker_stash, self.auth_context.credentials), ) if disconnect: From 7578157f1e5695bc1694f940fef90045be5b20ed Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 22:20:05 +0800 Subject: [PATCH 116/309] Fix call to JobService._kill --- packages/syft/src/syft/service/worker/worker_service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index f1f1464c6cc..314253aabef 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -159,7 +159,10 @@ def _delete( from .worker_pool_service import SyftWorkerPoolService job_service = cast(JobService, context.node.get_service(JobService)) - job_service._kill(job_service.get(context=context, uid=worker.job_id)) + job_service._kill( + context=context, + job=job_service.get(context=context, uid=worker.job_id), + ) worker_pool_service = cast( SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService) From 7fb8d5f252c3118aebc48d72e4311f6cf3f7ed56 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 22:21:21 +0800 Subject: [PATCH 117/309] Move deleting worker to the last step --- .../syft/src/syft/service/queue/zmq_queue.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 4da1b758488..a1b7798846a 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -584,17 +584,6 @@ def process_worker(self, address: bytes, msg: list[bytes]) -> None: def delete_worker(self, worker: Worker, disconnect: bool) -> None: """Deletes worker from all data structures, and deletes worker.""" - # relative - from ...service.worker.worker_service import WorkerService - - worker_service = cast( - WorkerService, self.auth_context.node.get_service(WorkerService) - ) - worker_service._delete( - self.auth_context, - worker._syft_worker(self.worker_stash, self.auth_context.credentials), - ) - if disconnect: self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT, None, None) @@ -611,6 +600,17 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: worker.syft_worker_id, ConsumerState.DETACHED ) + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.node.get_service(WorkerService) + ) + worker_service._delete( + self.auth_context, + worker._syft_worker(self.worker_stash, self.auth_context.credentials), + ) + @property def alive(self) -> bool: return not self.socket.closed From 44a40555eece425eb772798cfaef2678814c1ce3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 22:23:55 +0800 Subject: [PATCH 118/309] Disconnect worker if it's marked for deletion --- packages/syft/src/syft/service/queue/zmq_queue.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index a1b7798846a..b99327b37a5 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py 
+++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -379,8 +379,10 @@ def purge_workers(self) -> None: """ # work on a copy of the iterator for worker in self.waiting: - if worker.has_expired() or worker._to_be_deleted( - self.worker_stash, self.auth_context.credentials + if worker.has_expired() or ( + to_be_deleted := worker._to_be_deleted( + self.worker_stash, self.auth_context.credentials + ) ): logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", @@ -389,7 +391,7 @@ def purge_workers(self) -> None: worker.get_expiry(), Timeout.now(), ) - self.delete_worker(worker, False) + self.delete_worker(worker, to_be_deleted) def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState From 6c1aaa18e6adc829a8b4b3f9b5d5ea205a4be273 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 13 Jun 2024 22:31:59 +0800 Subject: [PATCH 119/309] Fix SyftSuccess call --- packages/syft/src/syft/service/worker/worker_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 314253aabef..b4bc5f36ba4 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -248,7 +248,7 @@ def delete( if not force: worker._to_be_deleted = True - return SyftSuccess(f"Worker {uid} has been marked for deletion.") + return SyftSuccess(message=f"Worker {uid} has been marked for deletion.") return self._delete(context, worker) From 63ee3ec010fa8e3a4116853c91c8e24b09429f68 Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 14 Jun 2024 10:28:07 +0700 Subject: [PATCH 120/309] [syft/blob_storage] fix lint, add some debugging statements --- .../src/syft/protocol/protocol_version.json | 2 +- .../src/syft/service/action/action_object.py | 17 ++++++++++++++++- .../src/syft/service/action/action_service.py | 5 ++++- .../syft/src/syft/service/blob_storage/util.py | 3 ++- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 375aa1af66b..6dfcf62f3a2 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -16,7 +16,7 @@ "NodeMetadata": { "5": { "version": 5, - "hash": "70197b4725dbdea0560ed8388e4d20b76808bee988f3630c5f916ee8f48761f8", + "hash": "f3927d167073a4db369a07e3bbbf756075bbb29e9addec324b8cd2c3597b75a1", "action": "add" } }, diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 8419ff8c3e8..22038f47d2f 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -837,7 +837,22 @@ def _save_to_blob_storage(self) -> SyftError | None: return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) - api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + + api = APIRegistry.api_for( + node_uid=self.syft_node_location, + user_verify_key=self.syft_client_verify_key, + ) + print() + print() + print("---- inside ActionObject._save_to_blob_storage() ----") + print(f"{self.syft_node_location = }") + print(f"{self.syft_client_verify_key = }") + print(f"{APIRegistry = }") + print(f"{api = }") + if api is None: + raise ValueError( + f"api is None. 
You must login to {self.syft_node_location}" + ) if not can_upload_to_blob_storage(data, api.metadata): self.syft_action_data_cache = data return None diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 00f1414b247..ab92f29592f 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -448,6 +448,10 @@ def set_result_to_store( context.node.id, context.credentials, ) + print("---- inside ActionService.set_result_to_store() ----") + print(f"{result_action_object.__dict__ = }") + print("calling result_action_object._save_to_blob_storage()") + blob_store_result = result_action_object._save_to_blob_storage() if isinstance(blob_store_result, SyftError): return Err(blob_store_result.message) @@ -723,7 +727,6 @@ def execute( context.node.id, context.credentials, ) - blob_store_result = result_action_object._save_to_blob_storage() if isinstance(blob_store_result, SyftError): return blob_store_result diff --git a/packages/syft/src/syft/service/blob_storage/util.py b/packages/syft/src/syft/service/blob_storage/util.py index 11d62f5eb97..68f3250035c 100644 --- a/packages/syft/src/syft/service/blob_storage/util.py +++ b/packages/syft/src/syft/service/blob_storage/util.py @@ -10,7 +10,8 @@ def min_size_for_blob_storage_upload(metadata: NodeMetadata | NodeMetadataJSON) -> int: if not isinstance(metadata, (NodeMetadata | NodeMetadataJSON)): raise ValueError( - f"argument `metadata` is type {type(metadata)}, but it should be of type NodeMetadata or NodeMetadataJSON" + f"argument `metadata` is type {type(metadata)}, " + f"but it should be of type NodeMetadata or NodeMetadataJSON" ) return metadata.min_size_blob_storage_mb From d37606cf919419b262ce453b35ad9b1381844dd6 Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 14 Jun 2024 15:00:25 +0700 Subject: [PATCH 121/309] [syft/action_obj] put action object's upload to blob storage code in a try-catch - if can't save the data to blob store, save the data to `syft_action_data_cache` --- .../src/syft/service/action/action_object.py | 44 +++++++++---------- .../src/syft/service/action/action_service.py | 4 -- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 87266358815..3453c86f4af 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -837,30 +837,28 @@ def _save_to_blob_storage(self) -> SyftError | None: return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) - - api = APIRegistry.api_for( - node_uid=self.syft_node_location, - user_verify_key=self.syft_client_verify_key, - ) - print() - print() - print("---- inside ActionObject._save_to_blob_storage() ----") - print(f"{self.syft_node_location = }") - print(f"{self.syft_client_verify_key = }") - print(f"{APIRegistry = }") - print(f"{api = }") - if api is None: - raise ValueError( - f"api is None. 
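The try/except introduced here gives the save path a soft fallback: if the client API is unavailable, the metadata lookup fails, or the blob store itself errors out, the action object keeps its payload in the in-memory cache rather than raising. A condensed sketch of that control flow, using assumed stand-in names (`can_upload`, `upload`, `as_empty`, `cache` are placeholders, not the real attributes):

def save_or_cache(obj, data):
    try:
        if obj.can_upload(data):          # size check against the node's threshold
            error = obj.upload(data)      # returns an error object on failure
            if error is not None:
                return error              # surface storage errors to the caller
            obj.cache = obj.as_empty()    # uploaded: keep only a lightweight stub
            return None
    except Exception as e:
        print(f"Failed to save to the blob store. Error: {e}")
    obj.cache = data                      # too small, no API, or an exception: keep the payload local
    return None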
You must login to {self.syft_node_location}" + try: + api = APIRegistry.api_for( + node_uid=self.syft_node_location, + user_verify_key=self.syft_client_verify_key, ) - if not can_upload_to_blob_storage(data, api.metadata): - self.syft_action_data_cache = data - return None - result = self._save_to_blob_storage_(data) - if isinstance(result, SyftError): - return result - if not TraceResultRegistry.current_thread_is_tracing(): - self.syft_action_data_cache = self.as_empty_data() + if api is None: + raise ValueError( + f"api is None. You must login to {self.syft_node_location}" + ) + if can_upload_to_blob_storage(data, api.metadata): + result = self._save_to_blob_storage_(data) + if isinstance(result, SyftError): + return result + if not TraceResultRegistry.current_thread_is_tracing(): + self.syft_action_data_cache = self.as_empty_data() + return None + except Exception as e: + print( + f"Failed to save action object {self.id} to the blob store. Error: {e}" + ) + + self.syft_action_data_cache = data return None @property diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index ab92f29592f..a7b908c99fa 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -448,10 +448,6 @@ def set_result_to_store( context.node.id, context.credentials, ) - print("---- inside ActionService.set_result_to_store() ----") - print(f"{result_action_object.__dict__ = }") - print("calling result_action_object._save_to_blob_storage()") - blob_store_result = result_action_object._save_to_blob_storage() if isinstance(blob_store_result, SyftError): return Err(blob_store_result.message) From 759f66b68744d5272136aa41808b91c0773a7f13 Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 14 Jun 2024 15:56:55 +0700 Subject: [PATCH 122/309] [syft/tests] test saving big objects to blob storage when uploading big datasets --- .../syft/blob_storage/blob_storage_test.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 944b25628f2..1889004a47e 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -133,3 +133,46 @@ def test_action_obj_send_save_to_blob_storage(worker): ) assert isinstance(syft_retrieved_data, SyftObjectRetrieval) assert all(syft_retrieved_data.read() == data_big) + + +def test_upload_dataset_save_to_blob_storage(worker): + root_client: DomainClient = worker.root_client + root_authed_ctx = AuthedServiceContext( + node=worker, credentials=root_client.verify_key + ) + dataset = sy.Dataset( + name="small_dataset", + asset_list=[ + sy.Asset( + name="small_dataset", + data=np.array([1, 2, 3]), + mock=np.array([1, 1, 1]), + ) + ], + ) + root_client.upload_dataset(dataset) + blob_storage = worker.get_service("BlobStorageService") + assert len(blob_storage.get_all_blob_storage_entries(context=root_authed_ctx)) == 0 + + num_elements = 50 * 1024 * 1024 + data_big = np.random.randint(0, 100, size=num_elements) + dataset_big = sy.Dataset( + name="big_dataset", + asset_list=[ + sy.Asset( + name="big_dataset", + data=data_big, + mock=np.array([1, 1, 1]), + ) + ], + ) + root_client.upload_dataset(dataset_big) + # the private data should be saved to the blob storage + blob_entries: list = blob_storage.get_all_blob_storage_entries( + context=root_authed_ctx + ) + 
assert len(blob_entries) == 1 + data_big_retrieved: SyftObjectRetrieval = blob_storage.read( + context=root_authed_ctx, uid=blob_entries[0].id + ) + assert all(data_big_retrieved.read() == data_big) From abcc997a4de04843679ba917e968d460f9b7297e Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Sun, 16 Jun 2024 22:40:32 +0530 Subject: [PATCH 123/309] add a stream upload API for blob storage - integrate proxy uid to stream upload data from client to domain via gateway --- packages/syft/src/syft/client/client.py | 30 ++++++++++++++++ packages/syft/src/syft/node/node.py | 3 +- packages/syft/src/syft/node/routes.py | 35 ++++++++++++++++++- .../src/syft/protocol/protocol_version.json | 12 +++++++ .../src/syft/store/blob_storage/seaweedfs.py | 17 ++++++--- 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index dc3fa447d58..0969c6cf4d9 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -4,6 +4,7 @@ # stdlib import base64 from collections.abc import Callable +from collections.abc import Generator from collections.abc import Iterator from copy import deepcopy from enum import Enum @@ -216,6 +217,35 @@ def _make_get( return response.content + def _make_put( + self, path: str, data: bytes | Generator, stream: bool = False + ) -> Response: + headers = {} + url = self.url + + if self.rathole_token: + url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) + headers = {"Host": self.url.host_or_ip} + + url = url.with_path(path) + response = self.session.put( + str(url), + verify=verify_tls(), + proxies={}, + data=data, + headers=headers, + stream=stream, + ) + if response.status_code != 200: + raise requests.ConnectionError( + f"Failed to fetch {url}. 
Response returned with code {response.status_code}" + ) + + # upgrade to tls if available + self.url = upgrade_tls(self.url, response) + + return response + def _make_post( self, path: str, diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 1e2c00c6f24..15508cdf184 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -117,6 +117,7 @@ from ..store.blob_storage import BlobStorageConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig +from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig from ..store.linked_obj import LinkedObject @@ -1181,7 +1182,7 @@ def forward_message( # relative from ..store.blob_storage import BlobRetrievalByURL - if isinstance(result, BlobRetrievalByURL): + if isinstance(result, BlobRetrievalByURL | SeaweedFSBlobDeposit): result.proxy_node_uid = peer.id return result diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index f32aa5c3bd9..8b91786af2c 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -1,6 +1,7 @@ # stdlib import base64 import binascii +from collections.abc import AsyncGenerator from typing import Annotated # third party @@ -62,7 +63,7 @@ def _get_node_connection(peer_uid: UID) -> NodeConnection: return connection @router.get("/stream/{peer_uid}/{url_path}/", name="stream") - async def stream(peer_uid: str, url_path: str) -> StreamingResponse: + async def stream_download(peer_uid: str, url_path: str) -> StreamingResponse: try: url_path_parsed = base64.urlsafe_b64decode(url_path.encode()).decode() except binascii.Error: @@ -79,6 +80,38 @@ async def stream(peer_uid: str, url_path: str) -> StreamingResponse: return StreamingResponse(stream_response, media_type="text/event-stream") + async def read_request_body_in_chunks( + request: Request, + ) -> AsyncGenerator[bytes, None]: + async for chunk in request.stream(): + yield chunk + + @router.put("/stream/{peer_uid}/{url_path}/", name="stream") + async def stream_upload(peer_uid: str, url_path: str, request: Request) -> Response: + try: + url_path_parsed = base64.urlsafe_b64decode(url_path.encode()).decode() + except binascii.Error: + raise HTTPException(404, "Invalid `url_path`.") + + data = await request.body() + + peer_uid_parsed = UID.from_string(peer_uid) + + try: + peer_connection = _get_node_connection(peer_uid_parsed) + url = peer_connection.to_blob_route(url_path_parsed) + + print("Url on stream", url.path) + response = peer_connection._make_put(url.path, data=data, stream=True) + except requests.RequestException: + raise HTTPException(404, "Failed to upload data to domain") + + return Response( + content=response.content, + headers=response.headers, + media_type="application/octet-stream", + ) + @router.get( "/", name="healthcheck", diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 98c6b4576ba..02ca42bc84d 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -308,6 +308,18 @@ "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", "action": "add" } + }, + "SeaweedFSBlobDeposit": { + "3": { + "version": 3, + "hash": 
"05e61e6328b085b738e5d41c0781d87852d44d218894cb3008f5be46e337f6d8", + "action": "remove" + }, + "4": { + "version": 4, + "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 1d88fedda37..b3667e61251 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -37,7 +37,8 @@ from ...types.blob_storage import SeaweedSecureFilePathLocation from ...types.blob_storage import SecureFilePathLocation from ...types.grid_url import GridURL -from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.uid import UID from ...util.constants import DEFAULT_TIMEOUT MAX_QUEUE_SIZE = 100 @@ -49,10 +50,11 @@ @serializable() class SeaweedFSBlobDeposit(BlobDeposit): __canonical_name__ = "SeaweedFSBlobDeposit" - __version__ = SYFT_OBJECT_VERSION_3 + __version__ = SYFT_OBJECT_VERSION_4 urls: list[GridURL] size: int + proxy_node_uid: UID | None = None def write(self, data: BytesIO) -> SyftSuccess | SyftError: # relative @@ -87,9 +89,14 @@ def write(self, data: BytesIO) -> SyftSuccess | SyftError: start=1, ): if api is not None and api.connection is not None: - blob_url = api.connection.to_blob_route( - url.url_path, host=url.host_or_ip - ) + if self.proxy_node_uid is None: + blob_url = api.connection.to_blob_route( + url.url_path, host=url.host_or_ip + ) + else: + blob_url = api.connection.stream_via( + self.proxy_node_uid, url.url_path + ) else: blob_url = url From b1f2fb0acd750a4baebc07a60723440abfcd5e99 Mon Sep 17 00:00:00 2001 From: teo Date: Tue, 18 Jun 2024 16:37:24 +0300 Subject: [PATCH 124/309] fix response repr --- packages/syft/src/syft/service/response.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 0d5c04c64ad..3eb6ff81fe7 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -40,9 +40,10 @@ def _repr_html_class_(self) -> str: return "alert-info" def _repr_html_(self) -> str: + msg = self.message.replace('\n', '
') return ( f'
' - + f"{type(self).__name__}: {self.message.replace("\n", "
")}

" + + f"{type(self).__name__}: {msg}
" ) From 4597798f8f280bb9b3a8f9a054c21c8fb0712330 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 19 Jun 2024 11:20:09 +0300 Subject: [PATCH 125/309] proper check on adding workers on node and settings migrations --- packages/syft/src/syft/node/node.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index eb7b6be0bbb..f06a5abb325 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -120,7 +120,7 @@ from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig -from ..types.syft_object import SYFT_OBJECT_VERSION_2 +from ..types.syft_object import SYFT_OBJECT_VERSION_2, Context from ..types.syft_object import SyftObject from ..types.uid import UID from ..util.experimental_flags import flags @@ -1591,6 +1591,15 @@ def create_initial_settings(self, admin_email: str) -> NodeSettings | None: settings_exists = settings_stash.get_all(self.signing_key.verify_key).ok() if settings_exists: node_settings = settings_exists[0] + if node_settings.__version__ != NodeSettings.__version__: + context = Context() + node_settings = node_settings.migrate_to(NodeSettings.__version__, context) + res = settings_stash.delete_by_uid(self.signing_key.verify_key, node_settings.id) + if res.is_err(): + raise Exception(res.value) + res = settings_stash.set(self.signing_key.verify_key, node_settings) + if res.is_err(): + raise Exception(res.value) self.name = node_settings.name self.association_request_auto_approval = ( node_settings.association_request_auto_approval @@ -1775,12 +1784,15 @@ def create_default_worker_pool(node: Node) -> SyftError | None: worker_to_add_ = max(default_worker_pool.max_count, worker_count) - len( default_worker_pool.worker_list ) - add_worker_method = node.get_service_method(SyftWorkerPoolService.add_workers) - result = add_worker_method( - context=context, - number=worker_to_add_, - pool_name=default_pool_name, - ) + if worker_to_add_ > 0: + add_worker_method = node.get_service_method(SyftWorkerPoolService.add_workers) + result = add_worker_method( + context=context, + number=worker_to_add_, + pool_name=default_pool_name, + ) + else: + return None if isinstance(result, SyftError): print(f"Default worker pool error. 
{result.message}") From 59de87223b8973188a5c3a51902509eeeaf598c1 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 19 Jun 2024 11:20:32 +0300 Subject: [PATCH 126/309] proper migration on settings --- packages/syft/src/syft/service/settings/settings.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index bfc9800c6bc..6fa54fc8f67 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -2,6 +2,8 @@ from collections.abc import Callable from typing import Any +from syft.util.util import get_env + # relative from ...abstract_node import NodeSideType from ...abstract_node import NodeType @@ -189,7 +191,10 @@ class NodeSettingsV2(SyftObject): @migrate(NodeSettingsV2, NodeSettings) def upgrade_node_settings() -> list[Callable]: - return [make_set_default("association_request_auto_approval", False)] + return [ + make_set_default("association_request_auto_approval", False), + make_set_default("default_worker_pool", get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME)) + ] @migrate(NodeSettings, NodeSettingsV2) @@ -204,4 +209,4 @@ def upgrade_node_settings_update() -> list[Callable]: @migrate(NodeSettings, NodeSettingsV2) def downgrade_node_settings_update() -> list[Callable]: - return [drop(["association_request_auto_approval"])] + return [drop(["association_request_auto_approval"]), drop(["default_worker_pool"])] From f059b496a7ec617beb90492f406a4045c072dc01 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 19 Jun 2024 11:21:03 +0300 Subject: [PATCH 127/309] added changes to migration service --- .../service/migration/migration_service.py | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 80a64ec8762..07e94564196 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -5,6 +5,7 @@ # stdlib # third party +from collections import defaultdict from result import Err from result import Ok from result import Result @@ -123,8 +124,9 @@ def get_migration_objects( self, context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, + get_all: bool = False, ) -> dict | SyftError: - res = self._get_migration_objects(context, document_store_object_types) + res = self._get_migration_objects(context, document_store_object_types, get_all) if res.is_err(): return SyftError(message=res.value) else: @@ -134,6 +136,7 @@ def _get_migration_objects( self, context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, + get_all: bool = False, ) -> Result[dict, str]: if document_store_object_types is None: document_store_object_types = [ @@ -141,16 +144,19 @@ def _get_migration_objects( for partition in self.store.partitions.values() ] - klasses_to_migrate = self._find_klasses_pending_for_migration( - context=context, object_types=document_store_object_types - ) + if get_all: + klasses_to_migrate = document_store_object_types + else: + klasses_to_migrate = self._find_klasses_pending_for_migration( + context=context, object_types=document_store_object_types + ) if klasses_to_migrate: print( f"Classes in Document Store that need migration: {klasses_to_migrate}" ) - result = {} + result = 
defaultdict(list) for klass in klasses_to_migrate: canonical_name = klass.__canonical_name__ @@ -163,7 +169,10 @@ def _get_migration_objects( if objects_result.is_err(): return objects_result objects = objects_result.ok() - result[klass] = objects + for object in objects: + actual_klass = type(object) + use_klass = klass if actual_klass.__canonical_name__ == klass.__canonical_name__ else actual_klass + result[use_klass].append(object) return Ok(result) @service_method( @@ -185,10 +194,23 @@ def _update_migrated_objects( ) -> Result[str, str]: for migrated_object in migrated_objects: klass = type(migrated_object) - canonical_name = klass.__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) + mro = klass.__mro__ + class_index = 0 + while len(mro) > class_index: + canonical_name = mro[class_index].__canonical_name__ + object_partition = self.store.partitions.get(canonical_name) + if object_partition is not None: + break + class_index += 1 + + # canonical_name = mro[class_index].__canonical_name__ + # object_partition = self.store.partitions.get(canonical_name) + + # print(klass, canonical_name, object_partition) qk = object_partition.settings.store_key.with_obj(migrated_object.id) - + # print(migrated_object) + import sys + result = object_partition._update( context.credentials, qk=qk, @@ -197,9 +219,12 @@ def _update_migrated_objects( overwrite=True, allow_missing_keys=True, ) - + if result.is_err(): - return result + print("ERR:", result.value, file=sys.stderr) + print("ERR:", klass, file=sys.stderr) + print("ERR:", migrated_object, file=sys.stderr) + # return result return Ok(value="success") @service_method( From f4a4b97c550864e6c021cc16eb973ab96a0a6c49 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 19 Jun 2024 13:50:22 +0300 Subject: [PATCH 128/309] helper methods to save to file --- .../syft/src/syft/client/domain_client.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 2a377485a3c..489455f0663 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -5,12 +5,14 @@ from pathlib import Path import re from string import Template -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List from typing import cast # third party from loguru import logger import markdown +from syft.serde import deserialize, serialize +from syft.types.syft_object import Context from tqdm import tqdm # relative @@ -384,6 +386,29 @@ def code_status(self) -> APIModule | None: @property def output(self) -> APIModule | None: return self._get_service_by_name_if_exists("output") + + def save_migration_objects_to_file(self, filename: str, get_all: bool = False) -> Dict[Any, Any] | SyftError: + migration_dict = self.api.services.migration.get_migration_objects(get_all=get_all) + if isinstance(migration_dict, SyftError): + return migration_dict + ser_bytes = serialize(migration_dict, to_bytes=True) + with open(filename, 'wb') as f: + f.write(ser_bytes) + return migration_dict + + + def migrate_objects_from_file(self, filename: str) -> SyftSuccess | SyftError: + with open(filename, 'rb') as f: + ser_bytes = f.read() + migration_dict = deserialize(ser_bytes, from_bytes=True) + context = Context() + migrated_objects = [] + for klass, objects in migration_dict.items(): + for obj in objects: + migrated_obj = obj.migrate_to(klass.__version__, context) + 
migrated_objects.append(migrated_obj) + res = self.api.services.migration.update_migrated_objects(migrated_objects) + return res def get_project( self, From 8d7f3894d042b56c31e69d224f52c38efa40059b Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 20 Jun 2024 14:37:48 +0800 Subject: [PATCH 129/309] Force deleting workers in notebook tests --- notebooks/api/0.8/10-container-images.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index b0bacf0295f..c77fce374ba 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -762,7 +762,7 @@ "outputs": [], "source": [ "worker_delete_res = domain_client.api.services.worker.delete(\n", - " uid=second_worker.id,\n", + " uid=second_worker.id, force=True\n", ")" ] }, From d6399c15673b649efe75d041c163029bcceaccad Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 20 Jun 2024 14:47:34 +0800 Subject: [PATCH 130/309] Fix using SyftWorker class instead of instance by mistake --- packages/syft/src/syft/service/worker/worker_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index b4bc5f36ba4..074871dd702 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -151,7 +151,7 @@ def _delete( context: AuthedServiceContext, worker: SyftWorker, ) -> SyftSuccess | SyftError: - uid = SyftWorker.id + uid = worker.id worker_pool_name = worker.worker_pool_name # relative From cfb9e7cfe866b037bf6e078c2f2cba80fd820af3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 20 Jun 2024 17:21:40 +0800 Subject: [PATCH 131/309] Only kill worker job if exists --- .../syft/src/syft/service/worker/worker_service.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 074871dd702..ac973f1f257 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -158,11 +158,12 @@ def _delete( from ...service.job.job_service import JobService from .worker_pool_service import SyftWorkerPoolService - job_service = cast(JobService, context.node.get_service(JobService)) - job_service._kill( - context=context, - job=job_service.get(context=context, uid=worker.job_id), - ) + if worker.job_id is not None: + job_service = cast(JobService, context.node.get_service(JobService)) + job_service._kill( + context=context, + job=job_service.get(context=context, uid=worker.job_id), + ) worker_pool_service = cast( SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService) From 0860713d5bf7ce635911b93eb74c9c01aa4b07a9 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 20 Jun 2024 17:22:33 +0800 Subject: [PATCH 132/309] Set SyftWorker._to_be_deleted whether or not deletion is forced --- packages/syft/src/syft/service/worker/worker_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index ac973f1f257..5395f0ca7cd 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -246,9 +246,9 @@ def delete( 
force: bool = False, ) -> SyftSuccess | SyftError: worker = self._get_worker(context=context, uid=uid) + worker._to_be_deleted = True if not force: - worker._to_be_deleted = True return SyftSuccess(message=f"Worker {uid} has been marked for deletion.") return self._delete(context, worker) From 801522902007967b70c431f1427542619c2d40f4 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 20 Jun 2024 18:20:07 +0800 Subject: [PATCH 133/309] Use JobService.kill instead of _kill --- packages/syft/src/syft/service/worker/worker_service.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 5395f0ca7cd..49ceaeea401 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -160,10 +160,7 @@ def _delete( if worker.job_id is not None: job_service = cast(JobService, context.node.get_service(JobService)) - job_service._kill( - context=context, - job=job_service.get(context=context, uid=worker.job_id), - ) + job_service.kill(context=context, id=worker.job_id) worker_pool_service = cast( SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService) From 52e904a5c06e81f4519f4591270117f24ed8ee4f Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 21 Jun 2024 15:43:27 +0800 Subject: [PATCH 134/309] Add a test --- .../syft/tests/syft/users/user_code_test.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 333d246e37f..9bc26995034 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -10,12 +10,20 @@ # syft absolute import syft as sy from syft.client.domain_client import DomainClient +from syft.node.worker import Worker from syft.service.action.action_object import ActionObject from syft.service.request.request import Request from syft.service.request.request import UserCodeStatusChange from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.service.user.user import User +from syft.service.user.user import UserUpdate +from syft.service.user.user_roles import ServiceRole + +# relative +from .user_test import ds_client as ds_client_fixture + +ds_client = ds_client_fixture # workaround some ruff quirks @sy.syft_function( @@ -39,6 +47,32 @@ def test_repr_markdown_not_throwing_error(guest_client: DomainClient) -> None: assert result[0]._repr_markdown_() +def test_new_admin_can_list_user_code( + worker: Worker, + ds_client: DomainClient, + faker: Faker, +) -> None: + root_client = worker.root_client + + project = sy.Project(name="", members=[ds_client]) + project.create_code_request(mock_syft_func, ds_client) + + email = faker.email() + pw = uuid.uuid4().hex + root_client.register( + name=faker.name(), email=email, password=pw, password_verify=pw + ) + + admin = root_client.login(email=email, password=pw) + + root_client.api.services.user.update( + admin.me.id, UserUpdate(role=ServiceRole.ADMIN) + ) + + assert len(root_client.code.get_all()) == len(admin.code.get_all()) + assert {c.id for c in root_client.code} == {c.id for c in admin.code} + + def test_user_code(worker) -> None: root_domain_client = worker.root_client root_domain_client.register( From 88edda1901b655b6b27773bf8bcb23112dade88a Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 21 Jun 2024 14:55:36 
+0800 Subject: [PATCH 135/309] Fix new admin not able to list user code --- packages/syft/src/syft/service/code/user_code_service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 9302dd37404..d0ea943a9aa 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -265,7 +265,8 @@ def request_code_execution( @service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError: """Get a Dataset""" - result = self.stash.get_all(context.credentials) + has_permission = context.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) + result = self.stash.get_all(context.credentials, has_permission=has_permission) if result.is_ok(): return result.ok() return SyftError(message=result.err()) From f67af966fabf21ea658d13cd8a40c061b66539f3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 21 Jun 2024 18:11:34 +0800 Subject: [PATCH 136/309] Fix at the store partition level --- .../syft/service/code/user_code_service.py | 4 +-- .../syft/src/syft/store/kv_document_store.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index d0ea943a9aa..501a10c853d 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -264,9 +264,7 @@ def request_code_execution( @service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError: - """Get a Dataset""" - has_permission = context.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - result = self.stash.get_all(context.credentials, has_permission=has_permission) + result = self.stash.get_all(context.credentials) if result.is_ok(): return result.ok() return SyftError(message=result.err()) diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index b594be92775..7f86db0cda9 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -101,6 +101,20 @@ def init_store(self) -> Result[Ok, Err]: if store_status.is_err(): return store_status + store = self.store_config.store_type( + node_uid=self.node_uid, + root_verify_key=self.root_verify_key, + store_config=self.store_config, + ) + + # relative + from ..service.user.user import User + from ..service.user.user_stash import UserStash + + self.__user_stash = ( + UserStash(store=store) if self.settings.object_type is not User else None + ) + try: self.data = self.store_config.backing_store( "data", self.settings, self.store_config @@ -277,6 +291,21 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: ): return True + if permission.credentials and self.__user_stash is not None: + # relative + from ..service.user.user_roles import ServiceRole + + res = self.__user_stash.get_by_verify_key( + credentials=permission.credentials, + verify_key=permission.credentials, + ) + if ( + res.is_ok() + and (user := res.ok()) is not None + and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) + ): + return True + if ( permission.uid in self.permissions and 
permission.permission_string in self.permissions[permission.uid] From 969f0c3d609bba1dbad6b02f2ecba5a42497b2ad Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 21 Jun 2024 21:14:34 +0800 Subject: [PATCH 137/309] Add the admin permission check as a callback passed from DocumentStore so that it works for all partition types, including sqlite and mongo Co-authored-by: Aziz Berkay Yesilyurt --- .../syft/src/syft/store/document_store.py | 30 +++++++++++++++++ .../syft/src/syft/store/kv_document_store.py | 32 +++---------------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index fea96e6d456..63a78ccca62 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -309,6 +309,7 @@ def __init__( root_verify_key: SyftVerifyKey | None, settings: PartitionSettings, store_config: StoreConfig, + has_admin_permissions: Callable[[SyftVerifyKey], bool], ) -> None: if root_verify_key is None: root_verify_key = SyftSigningKey.generate().verify_key @@ -316,6 +317,7 @@ def __init__( self.root_verify_key = root_verify_key self.settings = settings self.store_config = store_config + self.has_admin_permissions = has_admin_permissions res = self.init_store() if res.is_err(): raise RuntimeError( @@ -578,6 +580,33 @@ def __init__( self.node_uid = node_uid self.root_verify_key = root_verify_key + def __has_admin_permissions( + self, settings: PartitionSettings + ) -> Callable[[SyftVerifyKey], bool]: + # relative + from ..service.user.user import User + from ..service.user.user_roles import ServiceRole + from ..service.user.user_stash import UserStash + + if settings.object_type is User: + return lambda x: False + + user_stash = UserStash(store=self) + + def has_admin_permissions(credentials: SyftVerifyKey) -> bool: + res = user_stash.get_by_verify_key( + credentials=credentials, + verify_key=credentials, + ) + + return ( + res.is_ok() + and (user := res.ok()) is not None + and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) + ) + + return has_admin_permissions + def partition(self, settings: PartitionSettings) -> StorePartition: if settings.name not in self.partitions: self.partitions[settings.name] = self.partition_type( @@ -585,6 +614,7 @@ def partition(self, settings: PartitionSettings) -> StorePartition: root_verify_key=self.root_verify_key, settings=settings, store_config=self.store_config, + has_admin_permissions=self.__has_admin_permissions(settings), ) return self.partitions[settings.name] diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 7f86db0cda9..2f880bd63ae 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -101,20 +101,6 @@ def init_store(self) -> Result[Ok, Err]: if store_status.is_err(): return store_status - store = self.store_config.store_type( - node_uid=self.node_uid, - root_verify_key=self.root_verify_key, - store_config=self.store_config, - ) - - # relative - from ..service.user.user import User - from ..service.user.user_stash import UserStash - - self.__user_stash = ( - UserStash(store=store) if self.settings.object_type is not User else None - ) - try: self.data = self.store_config.backing_store( "data", self.settings, self.store_config @@ -291,20 +277,10 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: ): return True - if permission.credentials and 
self.__user_stash is not None: - # relative - from ..service.user.user_roles import ServiceRole - - res = self.__user_stash.get_by_verify_key( - credentials=permission.credentials, - verify_key=permission.credentials, - ) - if ( - res.is_ok() - and (user := res.ok()) is not None - and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - ): - return True + if permission.credentials and self.has_admin_permissions( + permission.credentials + ): + return True if ( permission.uid in self.permissions From aaebc29ef24228b07ca3dda1ed66034b715b1f16 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Fri, 21 Jun 2024 22:24:07 +0800 Subject: [PATCH 138/309] Nitpicks --- packages/syft/src/syft/store/document_store.py | 5 ++++- packages/syft/src/syft/store/kv_document_store.py | 1 - packages/syft/tests/syft/users/user_code_test.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 63a78ccca62..5c05bb7dcd5 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -588,8 +588,11 @@ def __has_admin_permissions( from ..service.user.user_roles import ServiceRole from ..service.user.user_stash import UserStash + # leave out UserStash to avoid recursion + # TODO: pass the callback from BaseStash instead of DocumentStore + # so that this works with UserStash after the sqlite thread fix is merged if settings.object_type is User: - return lambda x: False + return lambda credentials: False user_stash = UserStash(store=self) diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 2f880bd63ae..11c2ae3d137 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -270,7 +270,6 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: if not isinstance(permission.permission, ActionPermission): raise Exception(f"ObjectPermission type: {permission.permission} not valid") - # TODO: fix for other admins if ( permission.credentials and self.root_verify_key.verify == permission.credentials.verify diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 9bc26995034..13aa6fadc28 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -23,7 +23,7 @@ # relative from .user_test import ds_client as ds_client_fixture -ds_client = ds_client_fixture # workaround some ruff quirks +ds_client = ds_client_fixture # work around some ruff quirks @sy.syft_function( From 1ebd1893d3f17ab1656176ba0a0cfc1c42becff9 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Sat, 22 Jun 2024 00:05:54 +0800 Subject: [PATCH 139/309] Make StorePartition.has_admin_permissions backward-compatible --- packages/syft/src/syft/store/document_store.py | 2 +- packages/syft/src/syft/store/kv_document_store.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 5c05bb7dcd5..9ad56408d99 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -309,7 +309,7 @@ def __init__( root_verify_key: SyftVerifyKey | None, settings: PartitionSettings, store_config: StoreConfig, - has_admin_permissions: Callable[[SyftVerifyKey], bool], + has_admin_permissions: 
Callable[[SyftVerifyKey], bool] | None = None, ) -> None: if root_verify_key is None: root_verify_key = SyftSigningKey.generate().verify_key diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 11c2ae3d137..1c60c194b47 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -276,8 +276,10 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: ): return True - if permission.credentials and self.has_admin_permissions( + if ( permission.credentials + and self.has_admin_permissions is not None + and self.has_admin_permissions(permission.credentials) ): return True From 64861a0a298d4c5906f96a20f67a255e8b928dc3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Sat, 22 Jun 2024 00:16:20 +0800 Subject: [PATCH 140/309] Add the admin permission fix for mongo store --- packages/syft/src/syft/store/mongo_document_store.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index 234dd2c723b..64a3b8d3b26 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -479,13 +479,19 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: if permissions is None: return False - # TODO: fix for other admins if ( permission.credentials and self.root_verify_key.verify == permission.credentials.verify ): return True + if ( + permission.credentials + and self.has_admin_permissions is not None + and self.has_admin_permissions(permission.credentials) + ): + return True + if permission.permission_string in permissions["permissions"]: return True From c9b5c3e3986f2a67ac35e634b15f07cbc6725643 Mon Sep 17 00:00:00 2001 From: dk Date: Mon, 24 Jun 2024 12:05:54 +0700 Subject: [PATCH 141/309] [syft/action_obj] - `_save_to_blob_store` returns a SyftWarning if the object is small and not saved to the blob store - passing flags telling action service to not clear the cache data of these small objects --- packages/syft/src/syft/client/domain_client.py | 13 +++++++++++-- .../src/syft/service/action/action_object.py | 18 ++++++++++++++---- .../src/syft/service/action/action_service.py | 13 ++++++++++++- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index ca9ee53ce22..16a0170b556 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -25,6 +25,7 @@ from ..service.dataset.dataset import CreateDataset from ..service.response import SyftError from ..service.response import SyftSuccess +from ..service.response import SyftWarning from ..service.sync.diff_state import ResolvedSyncState from ..service.sync.sync_state import SyncState from ..service.user.roles import Roles @@ -128,7 +129,7 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: ) as pbar: for asset in dataset.asset_list: try: - contains_empty = asset.contains_empty() + contains_empty: bool = asset.contains_empty() twin = TwinObject( private_obj=ActionObject.from_obj(asset.data), mock_obj=ActionObject.from_obj(asset.mock), @@ -142,8 +143,16 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: tqdm.write(f"Failed to create twin for {asset.name}. 
{e}") return SyftError(message=f"Failed to create twin. {e}") + if isinstance(res, SyftWarning): + print(res.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False response = self.api.services.action.set( - twin, ignore_detached_objs=contains_empty + twin, + ignore_detached_objs=contains_empty, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, ) if isinstance(response, SyftError): tqdm.write(f"Failed to upload asset: {asset.name}") diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 2f04d4a8295..7749410db28 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -36,7 +36,10 @@ from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize from ...service.blob_storage.util import can_upload_to_blob_storage +from ...service.blob_storage.util import min_size_for_blob_storage_upload from ...service.response import SyftError +from ...service.response import SyftSuccess +from ...service.response import SyftWarning from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel from ...types.datetime import DateTime @@ -844,7 +847,9 @@ def _set_reprs(self, data: any) -> None: ) self.syft_action_data_str_ = truncate_str(str(data)) - def _save_to_blob_storage(self, allow_empty: bool = False) -> SyftError | None: + def _save_to_blob_storage( + self, allow_empty: bool = False + ) -> SyftError | SyftSuccess | SyftWarning: data = self.syft_action_data if isinstance(data, SyftError): return data @@ -867,15 +872,20 @@ def _save_to_blob_storage(self, allow_empty: bool = False) -> SyftError | None: return result if not TraceResultRegistry.current_thread_is_tracing(): self._clear_cache() - return None + return SyftSuccess( + message=f"Saved action object {self.id} to the blob store" + ) except Exception as e: print( f"Failed to save action object {self.id} to the blob store. Error: {e}" ) self.syft_action_data_cache = data - - return None + return SyftWarning( + message=f"The action object {self.id} was not saved to " + f"the blob store but to memory cache since it is " + f"smaller than {min_size_for_blob_storage_upload(api.metadata)} Mb." + ) def _clear_cache(self) -> None: self.syft_action_data_cache = self.as_empty_data() diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index c680c7003f7..6ecb9d97b6c 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -84,6 +84,8 @@ def set( action_object: ActionObject | TwinObject, add_storage_permission: bool = True, ignore_detached_objs: bool = False, + skip_clear_cache: bool = False, + skip_save_to_blob_store: bool = False, ) -> ActionObject | SyftError: res = self._set( context, @@ -91,6 +93,8 @@ def set( has_result_read_permission=True, add_storage_permission=add_storage_permission, ignore_detached_objs=ignore_detached_objs, + skip_clear_cache=skip_clear_cache, + skip_save_to_blob_store=skip_save_to_blob_store, ) if res.is_err(): return SyftError(message=res.value) @@ -102,6 +106,9 @@ def is_detached_obj( action_object: ActionObject | TwinObject, ignore_detached_obj: bool = False, ) -> bool: + """ + A detached object is an object that is not yet saved to the blob storage. 
+ """ if ( isinstance(action_object, TwinObject) and ( @@ -125,8 +132,12 @@ def _set( add_storage_permission: bool = True, ignore_detached_objs: bool = False, skip_clear_cache: bool = False, + skip_save_to_blob_store: bool = False, ) -> Result[ActionObject, str]: - if self.is_detached_obj(action_object, ignore_detached_objs): + if ( + self.is_detached_obj(action_object, ignore_detached_objs) + and not skip_save_to_blob_store + ): return Err( "you uploaded an ActionObject that is not yet in the blob storage" ) From 033365ed323437cb0bbfce6ff1e63caa751ec438 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 13:47:47 +0700 Subject: [PATCH 142/309] [syft/action_obj] set `skip_save_to_blob_stores` and `skip_clear_cache` in `ActionObject.send` --- .../syft/src/syft/service/action/action_object.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 7749410db28..d0366fd0f35 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1237,8 +1237,17 @@ def _send( api = self._get_api() if isinstance(api, SyftError): return api + + if isinstance(blob_storage_res, SyftWarning): + print(blob_storage_res.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False res = api.services.action.set( - self, add_storage_permission=add_storage_permission + self, + add_storage_permission=add_storage_permission, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, ) if isinstance(res, ActionObject): self.syft_created_at = res.syft_created_at From 52751ed6b44c9b169d565b6b1021e05a390aef1f Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 24 Jun 2024 14:49:08 +0800 Subject: [PATCH 143/309] Fix call to _syft_worker Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/queue/zmq_queue.py | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index ef9381fec72..e4368ee9882 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -133,12 +133,7 @@ def reset_expiry(self) -> None: def _syft_worker( self, stash: WorkerStash, credentials: SyftVerifyKey ) -> Result[SyftWorker | None, str]: - return stash.get_by_uid(credentials == credentials, uid=self.syft_worker_id) - - def _to_be_deleted(self, stash: WorkerStash, credentials: SyftVerifyKey) -> bool: - return self._syft_worker(stash, credentials).map_or( - False, lambda x: x is not None and x._to_be_deleted - ) + return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) @serializable() @@ -431,11 +426,12 @@ def purge_workers(self) -> None: """ # work on a copy of the iterator for worker in self.waiting: - if worker.has_expired() or ( - to_be_deleted := worker._to_be_deleted( - self.worker_stash, self.auth_context.credentials - ) - ): + res = worker._syft_worker(self.worker_stash, self.auth_context.credentials) + if res.is_err() or (syft_worker := res.ok()) is None: + logger.info("Failed to retrieve SyftWorker {worker.syft_worker_id}") + continue + + if worker.has_expired() or syft_worker._to_be_deleted: logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", worker.identity, @@ -443,7 +439,15 @@ def 
purge_workers(self) -> None: worker.get_expiry(), Timeout.now(), ) - self.delete_worker(worker, to_be_deleted) + self.delete_worker(worker, syft_worker._to_be_deleted) + + # relative + from ...service.worker.worker_service import WorkerService + + worker_service = cast( + WorkerService, self.auth_context.node.get_service(WorkerService) + ) + worker_service._delete(self.auth_context, syft_worker) def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState @@ -654,17 +658,6 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: worker.syft_worker_id, ConsumerState.DETACHED ) - # relative - from ...service.worker.worker_service import WorkerService - - worker_service = cast( - WorkerService, self.auth_context.node.get_service(WorkerService) - ) - worker_service._delete( - self.auth_context, - worker._syft_worker(self.worker_stash, self.auth_context.credentials), - ) - @property def alive(self) -> bool: return not self.socket.closed From 4b2ec09583fc59c110bbe6d71716f713d9e9fef9 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 14:07:56 +0700 Subject: [PATCH 144/309] [syft/action_obj] set skip_save_to_blob_stores and `skip_clear_cache` in `ActionObject.execute` --- .../syft/src/syft/service/action/action_service.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 6ecb9d97b6c..79a453f4452 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -23,6 +23,7 @@ from ..policy.policy import retrieve_from_db from ..response import SyftError from ..response import SyftSuccess +from ..response import SyftWarning from ..service import AbstractService from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE @@ -794,8 +795,17 @@ def execute( context.extra_kwargs = { "has_result_read_permission": has_result_read_permission } - - set_result = self._set(context, result_action_object) + if isinstance(blob_store_result, SyftWarning): + print(blob_store_result.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False + set_result = self._set( + context, + result_action_object, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, + ) if set_result.is_err(): return Err( f"Failed executing action {action}, set result is an error: {set_result.err()}" From 92c4843d905ede2f8ebc20dc890e1bf222777366 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 14:15:52 +0700 Subject: [PATCH 145/309] [syft/twin_obj] pass `skip_save_to_blob_stores` and `skip_clear_cache` to `ActionService.set` for `TwinObject.send` --- packages/syft/src/syft/types/twin_object.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index f94d75744d6..82eae3590b8 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -17,6 +17,8 @@ from ..service.action.action_object import TwinMode from ..service.action.action_types import action_types from ..service.response import SyftError +from ..service.response import SyftSuccess +from ..service.response import SyftWarning from ..types.syft_object import SYFT_OBJECT_VERSION_2 from .syft_object import 
SyftObject from .uid import UID @@ -82,7 +84,9 @@ def mock(self) -> ActionObject: mock.id = twin_id return mock - def _save_to_blob_storage(self, allow_empty: bool = False) -> SyftError | None: + def _save_to_blob_storage( + self, allow_empty: bool = False + ) -> SyftError | SyftSuccess | SyftWarning: # Set node location and verify key self.private_obj._set_obj_location_( self.syft_node_location, @@ -99,8 +103,16 @@ def _save_to_blob_storage(self, allow_empty: bool = False) -> SyftError | None: def send(self, client: SyftClient, add_storage_permission: bool = True) -> Any: self._set_obj_location_(client.id, client.verify_key) - self._save_to_blob_storage() + blob_store_result = self._save_to_blob_storage() + if isinstance(blob_store_result, SyftWarning): + print(blob_store_result.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False res = client.api.services.action.set( - self, add_storage_permission=add_storage_permission + self, + add_storage_permission=add_storage_permission, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, ) return res From 3827313bc289a0fc325db785f26a2565426d1abc Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 24 Jun 2024 15:15:38 +0800 Subject: [PATCH 146/309] Don't kill job in case of graceful worker deletion Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/worker/worker_service.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 49ceaeea401..689d6f8b94e 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -147,9 +147,7 @@ def logs( return logs if raw else logs.decode(errors="ignore") def _delete( - self, - context: AuthedServiceContext, - worker: SyftWorker, + self, context: AuthedServiceContext, worker: SyftWorker, force: bool = False ) -> SyftSuccess | SyftError: uid = worker.id worker_pool_name = worker.worker_pool_name @@ -158,7 +156,7 @@ def _delete( from ...service.job.job_service import JobService from .worker_pool_service import SyftWorkerPoolService - if worker.job_id is not None: + if force and worker.job_id is not None: job_service = cast(JobService, context.node.get_service(JobService)) job_service.kill(context=context, id=worker.job_id) @@ -201,7 +199,7 @@ def _delete( if isinstance(docker_container, SyftError): return docker_container - stopped = _stop_worker_container(worker, docker_container, force=True) + stopped = _stop_worker_container(worker, docker_container, force=force) if stopped is not None: return stopped else: @@ -246,9 +244,10 @@ def delete( worker._to_be_deleted = True if not force: + # relative return SyftSuccess(message=f"Worker {uid} has been marked for deletion.") - return self._delete(context, worker) + return self._delete(context, worker, force=True) def _get_worker( self, context: AuthedServiceContext, uid: UID From f4c9c39d146c7edab6e6f248ddc9d2b501595150 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 15:04:17 +0700 Subject: [PATCH 147/309] [syft/twin_obj] pass skip_save_to_blob_stores and `skip_clear_cache` to `ActionService.set` for `ActionService.set_result_to_store` --- .../syft/src/syft/service/action/action_service.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_service.py 
b/packages/syft/src/syft/service/action/action_service.py index 79a453f4452..26019d632f7 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -516,6 +516,11 @@ def set_result_to_store( blob_store_result = result_action_object._save_to_blob_storage() if isinstance(blob_store_result, SyftError): return Err(blob_store_result.message) + if isinstance(blob_store_result, SyftWarning): + print(blob_store_result.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False # IMPORTANT: DO THIS ONLY AFTER ._save_to_blob_storage if isinstance(result_action_object, TwinObject): @@ -528,7 +533,11 @@ def set_result_to_store( # Since this just meta data about the result, they always have access to it. set_result = self._set( - context, result_action_object, has_result_read_permission=True + context, + result_action_object, + has_result_read_permission=True, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, ) if set_result.is_err(): From bbecc0743adcc47feb410eee0848d8970778ba9e Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 15:16:24 +0700 Subject: [PATCH 148/309] [syft/twin_obj] pass skip_save_to_blob_stores and `skip_clear_cache` in the rest of the code base - unit tests passed --- .../syft/src/syft/service/action/action_object.py | 11 ++++++++++- .../syft/src/syft/service/action/action_service.py | 12 +++++++++++- packages/syft/src/syft/service/dataset/dataset.py | 8 +++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index d0366fd0f35..036dc99fea6 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -523,7 +523,16 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: r = arg._save_to_blob_storage() if isinstance(r, SyftError): print(r.message) - arg = api.services.action.set(arg) + if isinstance(r, SyftWarning): + print(r.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False + arg = api.services.action.set( + arg, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, + ) return arg arg_list = [process_arg(arg) for arg in args] if args else [] diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 26019d632f7..f7ca2d78efa 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -70,8 +70,18 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: blob_store_result = np_obj._save_to_blob_storage() if isinstance(blob_store_result, SyftError): return blob_store_result + if isinstance(blob_store_result, SyftWarning): + print(blob_store_result.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False - np_pointer = self._set(context, np_obj) + np_pointer = self._set( + context, + np_obj, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, + ) return np_pointer @service_method( diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 
9dde84429c4..a345aea94bd 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -709,7 +709,11 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: res = twin._save_to_blob_storage(allow_empty=contains_empty) if isinstance(res, SyftError): raise ValueError(res.message) - + if isinstance(res, SyftWarning): + print(res.message) + skip_save_to_blob_store, skip_clear_cache = True, True + else: + skip_save_to_blob_store, skip_clear_cache = False, False # TODO, upload to blob storage here if context.node is None: raise ValueError( @@ -719,6 +723,8 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: result = action_service._set( context=context.to_node_context(), action_object=twin, + skip_save_to_blob_store=skip_save_to_blob_store, + skip_clear_cache=skip_clear_cache, ) if result.is_err(): raise RuntimeError(f"Failed to create and store twin. Error: {result}") From 1e2f9602d9fc1c83a838e852dc1021a9d43e4d56 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 24 Jun 2024 16:32:21 +0700 Subject: [PATCH 149/309] [syft] fix lint. Change using print "not saving action objs to blob store" to logger.debug --- packages/syft/src/syft/client/domain_client.py | 2 +- packages/syft/src/syft/service/action/action_object.py | 9 ++++----- packages/syft/src/syft/service/action/action_service.py | 7 ++++--- packages/syft/src/syft/service/dataset/dataset.py | 3 ++- packages/syft/src/syft/types/twin_object.py | 3 ++- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 16a0170b556..70543bd3649 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -144,7 +144,7 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: return SyftError(message=f"Failed to create twin. 
{e}") if isinstance(res, SyftWarning): - print(res.message) + logger.debug(res.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 036dc99fea6..58312bb822b 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING # third party +from loguru import logger from pydantic import ConfigDict from pydantic import Field from pydantic import field_validator @@ -36,7 +37,6 @@ from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize from ...service.blob_storage.util import can_upload_to_blob_storage -from ...service.blob_storage.util import min_size_for_blob_storage_upload from ...service.response import SyftError from ...service.response import SyftSuccess from ...service.response import SyftWarning @@ -524,7 +524,7 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: if isinstance(r, SyftError): print(r.message) if isinstance(r, SyftWarning): - print(r.message) + logger.debug(r.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False @@ -892,8 +892,7 @@ def _save_to_blob_storage( self.syft_action_data_cache = data return SyftWarning( message=f"The action object {self.id} was not saved to " - f"the blob store but to memory cache since it is " - f"smaller than {min_size_for_blob_storage_upload(api.metadata)} Mb." + f"the blob store but to memory cache since it is small." ) def _clear_cache(self) -> None: @@ -1248,7 +1247,7 @@ def _send( return api if isinstance(blob_storage_res, SyftWarning): - print(blob_storage_res.message) + logger.debug(blob_storage_res.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index f7ca2d78efa..419a6f1f42d 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -3,6 +3,7 @@ from typing import Any # third party +from loguru import logger import numpy as np from result import Err from result import Ok @@ -71,7 +72,7 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: if isinstance(blob_store_result, SyftError): return blob_store_result if isinstance(blob_store_result, SyftWarning): - print(blob_store_result.message) + logger.debug(blob_store_result.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False @@ -527,7 +528,7 @@ def set_result_to_store( if isinstance(blob_store_result, SyftError): return Err(blob_store_result.message) if isinstance(blob_store_result, SyftWarning): - print(blob_store_result.message) + logger.debug(blob_store_result.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False @@ -815,7 +816,7 @@ def execute( "has_result_read_permission": has_result_read_permission } if isinstance(blob_store_result, SyftWarning): - print(blob_store_result.message) + logger.debug(blob_store_result.message) skip_save_to_blob_store, skip_clear_cache = True, True else: 
skip_save_to_blob_store, skip_clear_cache = False, False diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index a345aea94bd..d583b5d28d7 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -8,6 +8,7 @@ # third party from IPython.display import display import itables +from loguru import logger import pandas as pd from pydantic import ConfigDict from pydantic import field_validator @@ -710,7 +711,7 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: if isinstance(res, SyftError): raise ValueError(res.message) if isinstance(res, SyftWarning): - print(res.message) + logger.debug(res.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index 82eae3590b8..ba9b94b30ab 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -6,6 +6,7 @@ from typing import ClassVar # third party +from loguru import logger from pydantic import field_validator from pydantic import model_validator from typing_extensions import Self @@ -105,7 +106,7 @@ def send(self, client: SyftClient, add_storage_permission: bool = True) -> Any: self._set_obj_location_(client.id, client.verify_key) blob_store_result = self._save_to_blob_storage() if isinstance(blob_store_result, SyftWarning): - print(blob_store_result.message) + logger.debug(blob_store_result.message) skip_save_to_blob_store, skip_clear_cache = True, True else: skip_save_to_blob_store, skip_clear_cache = False, False From c335d5374cf10b178a52bb8cf545e64c882a01dc Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 24 Jun 2024 18:57:19 +0530 Subject: [PATCH 150/309] add a test for reverse tunnel - cleanup rathole start script - use slim version of python in rathole dockerfile --- .../rathole/rathole-statefulset.yaml | 4 +- packages/grid/helm/syft/values.yaml | 1 + packages/grid/helm/values.dev.yaml | 1 + packages/grid/rathole/rathole.dockerfile | 32 +---------- packages/grid/rathole/start.sh | 54 +++++++++++++------ tests/integration/network/gateway_test.py | 53 ++++++++++++++++++ 6 files changed, 97 insertions(+), 48 deletions(-) diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index 0f07516e352..c441991ef47 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -38,8 +38,8 @@ spec: env: - name: SERVICE_NAME value: "rathole" - - name: APP_LOG_LEVEL - value: {{ .Values.rathole.appLogLevel | quote }} + - name: LOG_LEVEL + value: {{ .Values.rathole.logLevel | quote }} - name: MODE value: {{ .Values.rathole.mode | quote }} - name: DEV_MODE diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 8cd78c68e89..32f52b3bd25 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -240,6 +240,7 @@ rathole: # Extra environment vars env: null enabled: true + logLevel: info port: 2333 mode: client diff --git a/packages/grid/helm/values.dev.yaml b/packages/grid/helm/values.dev.yaml index 97ccbe4dfc7..713948f4a9b 100644 --- a/packages/grid/helm/values.dev.yaml +++ b/packages/grid/helm/values.dev.yaml @@ -49,6 +49,7 @@ proxy: 
rathole: enabled: "true" + logLevel: "trace" # attestation: # enabled: true # resourcesPreset: null diff --git a/packages/grid/rathole/rathole.dockerfile b/packages/grid/rathole/rathole.dockerfile index 42d147527c7..1dee47a9411 100644 --- a/packages/grid/rathole/rathole.dockerfile +++ b/packages/grid/rathole/rathole.dockerfile @@ -10,9 +10,10 @@ RUN git clone -b v${RATHOLE_VERSION} https://github.com/rapiz1/rathole WORKDIR /rathole RUN cargo build --locked --release --features ${FEATURES:-default} -FROM python:${PYTHON_VERSION}-bookworm +FROM python:${PYTHON_VERSION}-slim-bookworm ARG RATHOLE_VERSION ENV MODE="client" +ENV LOG_LEVEL="info" RUN apt update && apt install -y netcat-openbsd vim rsync COPY --from=build /rathole/target/release/rathole /app/rathole @@ -23,32 +24,3 @@ EXPOSE 2333/udp EXPOSE 2333 CMD ["sh", "-c", "/app/start.sh"] - - -# build and run a fake domain to simulate a normal http container service -# docker build -f domain.dockerfile . -t domain -# docker run --name domain1 -it -d -p 8080:8000 domain - - - -# check the web server is running on 8080 -# curl localhost:8080 - -# build and run the rathole container -# docker build -f rathole.dockerfile . -t rathole - -# run the rathole server -# docker run --add-host host.docker.internal:host-gateway --name rathole-server -it -p 8001:8001 -p 8002:8002 -p 2333:2333 -e MODE=server rathole - -# check nothing is on port 8001 yet -# curl localhost:8001 - -# run the rathole client -# docker run --add-host host.docker.internal:host-gateway --name rathole-client -it -e MODE=client rathole - -# try port 8001 now -# curl localhost:8001 - -# add another client and edit the server.toml and client.toml for port 8002 - - diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index 0e708908836..b1af50597fc 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -1,30 +1,52 @@ #!/usr/bin/env bash + MODE=${MODE:-server} +RUST_LOG=${LOG_LEVEL:-trace} -cp -L -r -f /conf/* conf/ +# Copy configuration files +copy_config() { + cp -L -r -f /conf/* conf/ +} -if [[ $MODE == "server" ]]; then - RUST_LOG=trace /app/rathole conf/server.toml & -elif [[ $MODE = "client" ]]; then +# Start the server +start_server() { + RUST_LOG=$RUST_LOG /app/rathole conf/server.toml & +} + +# Start the client +start_client() { while true; do - RUST_LOG=trace /app/rathole conf/client.toml + RUST_LOG=$RUST_LOG /app/rathole conf/client.toml status=$? if [ $status -eq 0 ]; then - break + break else - echo "Failed to load client.toml, retrying in 5 seconds..." - sleep 10 + echo "Failed to load client.toml, retrying in 5 seconds..." + sleep 10 fi done & +} + +# Reload configuration every 10 seconds +reload_config() { + echo "Starting configuration reload loop..." + while true; do + copy_config + sleep 10 + done +} + +# Make an initial copy of the configuration +copy_config + +if [[ $MODE == "server" ]]; then + start_server +elif [[ $MODE == "client" ]]; then + start_client else echo "RATHOLE MODE is set to an invalid value. Exiting." 
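The rewritten start.sh above reduces to three pieces: copy the mounted config, retry the client until its config loads, and re-copy the config on a fixed interval. A rough Python equivalent of that control flow, assuming the same /app/rathole binary and conf/ paths used by the script:

import shutil
import subprocess
import time

def copy_config(src: str = "/conf", dst: str = "conf") -> None:
    # refresh the local copy of the mounted ConfigMap
    shutil.copytree(src, dst, dirs_exist_ok=True)

def start_client() -> None:
    # keep retrying until client.toml is valid and the tunnel comes up
    while subprocess.call(["/app/rathole", "conf/client.toml"]) != 0:
        print("Failed to load client.toml, retrying in 10 seconds...")
        time.sleep(10)

def reload_config(interval: int = 10) -> None:
    # periodic re-copy so config changes are picked up without a restart
    while True:
        copy_config()
        time.sleep(interval)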
+ exit 1 fi -# reload config every 10 seconds -while true -do - # Execute your script here - cp -L -r -f /conf/* conf/ - # Sleep for 10 seconds - sleep 10 -done \ No newline at end of file +# Start the configuration reload in the background to keep the configuration up to date +reload_config diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index 9c42a9e9687..1e09c2a775e 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -912,3 +912,56 @@ def test_peer_health_check(set_env_var, gateway_port: int, domain_1_port: int) - # Remove existing peers assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_reverse_tunnel_connection(domain_1_port: int, gateway_port: int): + # login to the domain and gateway + + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + res = gateway_client.settings.allow_association_request_auto_approval(enable=False) + + # Try removing existing peers just to make sure + _remove_existing_peers(domain_client) + _remove_existing_peers(gateway_client) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client, reverse_tunnel=True) + + assert isinstance(result, Request) + assert isinstance(result.changes[0], AssociationRequestChange) + + assert len(domain_client.peers) == 1 + + # Domain's peer is a gateway and vice-versa + domain_peer = domain_client.peers[0] + assert domain_peer.node_type == NodeType.GATEWAY + assert domain_peer.node_routes[0].rathole_token is None + assert len(gateway_client.peers) == 0 + + gateway_client_root = gateway_client.login( + email="info@openmined.org", password="changethis" + ) + res = gateway_client_root.api.services.request.get_all()[-1].approve() + assert not isinstance(res, SyftError) + + time.sleep(90) + + gateway_peers = gateway_client.api.services.network.get_all_peers() + assert len(gateway_peers) == 1 + assert len(gateway_peers[0].node_routes) == 1 + assert gateway_peers[0].node_routes[0].rathole_token is not None + + proxy_domain_client = gateway_client.peers[0] + + assert isinstance(proxy_domain_client, DomainClient) + assert isinstance(domain_peer, NodePeer) + assert gateway_client.name == domain_peer.name + assert domain_client.name == proxy_domain_client.name + + assert not isinstance(proxy_domain_client.datasets.get_all(), SyftError) From 39fe4a6cedd4ed3ba9439663b45a26614aa5aa47 Mon Sep 17 00:00:00 2001 From: dk Date: Tue, 25 Jun 2024 09:57:22 +0700 Subject: [PATCH 151/309] fix logger --- packages/syft/src/syft/service/dataset/dataset.py | 3 ++- packages/syft/src/syft/types/twin_object.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index d583b5d28d7..19352d26377 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -2,13 +2,13 @@ from collections.abc import Callable from datetime import datetime from enum import Enum +import logging import textwrap from typing import Any # third party from IPython.display import display import itables -from loguru import logger import pandas as pd from pydantic import ConfigDict 
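The "fix logger" commits in this range drop loguru in favour of the stdlib convention of one logger per module. A short usage sketch of that convention (the handler setup below is illustrative, not part of the patch):

import logging

logger = logging.getLogger(__name__)  # named after the importing module

def demo() -> None:
    logger.debug("twin object kept in memory cache")

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    demo()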
from pydantic import field_validator @@ -50,6 +50,7 @@ from ..response import SyftWarning NamePartitionKey = PartitionKey(key="name", type_=str) +logger = logging.getLogger(__name__) @serializable() diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index ba9b94b30ab..b5f7c90e42c 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -2,11 +2,11 @@ from __future__ import annotations # stdlib +import logging from typing import Any from typing import ClassVar # third party -from loguru import logger from pydantic import field_validator from pydantic import model_validator from typing_extensions import Self @@ -24,6 +24,8 @@ from .syft_object import SyftObject from .uid import UID +logger = logging.getLogger(__name__) + def to_action_object(obj: Any) -> ActionObject: if isinstance(obj, ActionObject): From 18ad21b1278ac776c94c9138aba2f46ba75fdf3c Mon Sep 17 00:00:00 2001 From: dk Date: Tue, 25 Jun 2024 10:05:19 +0700 Subject: [PATCH 152/309] fix logger in action_service.py --- packages/syft/src/syft/service/action/action_service.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 419a6f1f42d..50504c34301 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -1,9 +1,9 @@ # stdlib import importlib +import logging from typing import Any # third party -from loguru import logger import numpy as np from result import Err from result import Ok @@ -49,6 +49,8 @@ from .pandas import PandasDataFrameObject # noqa: F401 from .pandas import PandasSeriesObject # noqa: F401 +logger = logging.getLogger(__name__) + @serializable() class ActionService(AbstractService): From 8e88b48592757e2c230f488e5ad10e64f1363a13 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 25 Jun 2024 22:41:33 +0530 Subject: [PATCH 153/309] modulizer rathole config builder to a seperate class - have a more generalize class for reverse tunnel service - define methods to remove config from rathole client and server toml and configmaps --- .../service/network/association_request.py | 3 +- .../syft/service/network/network_service.py | 49 +++----- ...e_service.py => rathole_config_builder.py} | 105 ++++++++++++++++-- .../src/syft/service/network/rathole_toml.py | 13 +++ .../service/network/reverse_tunnel_service.py | 40 +++++++ 5 files changed, 165 insertions(+), 45 deletions(-) rename packages/syft/src/syft/service/network/{rathole_service.py => rathole_config_builder.py} (66%) create mode 100644 packages/syft/src/syft/service/network/reverse_tunnel_service.py diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 043bdbc101e..262692e283c 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -66,7 +66,8 @@ def _run( and self.remote_peer.latest_added_route == rathole_route ) - # If the remote peer is added via rathole, we don't need to ping the peer + # If the remote peer is added via rathole, we skip ping to peer + # and add the peer to the rathole server if add_rathole_route: network_service.rathole_service.add_host_to_server(self.remote_peer) else: diff --git a/packages/syft/src/syft/service/network/network_service.py 
b/packages/syft/src/syft/service/network/network_service.py index bbbdc5a002c..9419e479348 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -56,7 +56,7 @@ from .association_request import AssociationRequestChange from .node_peer import NodePeer from .node_peer import NodePeerUpdate -from .rathole_service import RatholeService +from .reverse_tunnel_service import ReverseTunnelService from .routes import HTTPNodeRoute from .routes import NodeRoute from .routes import NodeRouteType @@ -69,7 +69,7 @@ REVERSE_TUNNEL_RATHOLE_ENABLED = "REVERSE_TUNNEL_RATHOLE_ENABLED" -def get_rathole_enabled() -> bool: +def reverse_tunnel_enabled() -> bool: return str_to_bool(get_env(REVERSE_TUNNEL_RATHOLE_ENABLED, "false")) @@ -164,8 +164,8 @@ class NetworkService(AbstractService): def __init__(self, store: DocumentStore) -> None: self.store = store self.stash = NetworkStash(store=store) - if get_rathole_enabled(): - self.rathole_service = RatholeService() + if reverse_tunnel_enabled(): + self.rtunnel_service = ReverseTunnelService() @service_method( path="network.exchange_credentials_with", @@ -188,7 +188,9 @@ def exchange_credentials_with( # Step 1: Validate the Route self_node_peer = self_node_route.validate_with_context(context=context) - if reverse_tunnel: + if reverse_tunnel and not reverse_tunnel_enabled(): + return SyftError(message="Reverse tunneling is not enabled on this node.") + elif reverse_tunnel: _rathole_route = self_node_peer.node_routes[-1] _rathole_route.rathole_token = generate_token() _rathole_route.host_or_ip = f"{self_node_peer.name}.syft.local" @@ -257,9 +259,10 @@ def exchange_credentials_with( return SyftError(message="Failed to update route information.") # Step 5: Save rathole config to enable reverse tunneling - if reverse_tunnel and get_rathole_enabled(): - self._add_reverse_tunneling_config_for_peer( - self_node_peer=self_node_peer, remote_node_route=remote_node_route + if reverse_tunnel and reverse_tunnel_enabled(): + self.rtunnel_service.set_client_config( + self_node_peer=self_node_peer, + remote_node_route=remote_node_route, ) return ( @@ -268,32 +271,6 @@ def exchange_credentials_with( else remote_res ) - def _add_reverse_tunneling_config_for_peer( - self, - self_node_peer: NodePeer, - remote_node_route: NodeRoute, - ) -> None: - rathole_route = self_node_peer.get_rathole_route() - if not rathole_route: - raise Exception( - "Failed to exchange routes via . 
" - + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" - ) - - remote_url = GridURL( - host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port - ) - rathole_remote_addr = remote_url.as_container_host() - - remote_addr = rathole_remote_addr.url_no_protocol - - self.rathole_service.add_host_to_client( - peer_name=self_node_peer.name, - peer_id=str(self_node_peer.id), - rathole_token=rathole_route.rathole_token, - remote_addr=remote_addr, - ) - @service_method(path="network.add_peer", name="add_peer", roles=GUEST_ROLE_LEVEL) def add_peer( self, @@ -509,12 +486,12 @@ def update_peer( node_side_type = cast(NodeType, context.node.node_type) if node_side_type.value == NodeType.GATEWAY.value: rathole_route = peer.get_rathole_route() - self.rathole_service.add_host_to_server(peer) if rathole_route else None + self.rtunnel_service.set_server_config(peer) if rathole_route else None else: self_node_peer: NodePeer = context.node.settings.to(NodePeer) rathole_route = self_node_peer.get_rathole_route() ( - self._add_reverse_tunneling_config_for_peer( + self.rtunnel_service.set_client_config( self_node_peer=self_node_peer, remote_node_route=peer.pick_highest_priority_route(), ) diff --git a/packages/syft/src/syft/service/network/rathole_service.py b/packages/syft/src/syft/service/network/rathole_config_builder.py similarity index 66% rename from packages/syft/src/syft/service/network/rathole_service.py rename to packages/syft/src/syft/service/network/rathole_config_builder.py index afdf48503d7..a847b6fcbde 100644 --- a/packages/syft/src/syft/service/network/rathole_service.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -21,7 +21,7 @@ PROXY_CONFIG_MAP = "proxy-config" -class RatholeService: +class RatholeConfigBuilder: def __init__(self) -> None: self.k8rs_client = get_kr8s_client() @@ -39,7 +39,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: if not rathole_route: raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") - random_port = self.get_random_port() + random_port = self._get_random_port() peer_id = cast(UID, peer.id) @@ -78,9 +78,43 @@ def add_host_to_server(self, peer: NodePeer) -> None: KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) # Add the peer info to the proxy config map - self.add_dynamic_addr_to_rathole(config) + self._add_dynamic_addr_to_rathole(config) - def get_random_port(self) -> int: + def remove_host_from_server(self, peer_id: str, server_name: str) -> None: + """Remove a host from the rathole server toml file. + + Args: + peer_id (str): The id of the peer to be removed. + server_name (str): The name of the peer to be removed. 
+ + Returns: + None + """ + + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + if rathole_config_map is None: + raise Exception("Rathole config map not found.") + + client_filename = RatholeServerToml.filename + + toml_str = rathole_config_map.data[client_filename] + + rathole_toml = RatholeServerToml(toml_str=toml_str) + + rathole_toml.remove_config(peer_id) + + data = {client_filename: rathole_toml.toml_str} + + # Update the rathole config map + KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) + + # Remove the peer info from the proxy config map + self._remove_dynamic_addr_from_rathole(server_name) + + def _get_random_port(self) -> int: """Get a random port number.""" return secrets.randbits(15) @@ -120,7 +154,32 @@ def add_host_to_client( # Update the rathole config map KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) - def add_dynamic_addr_to_rathole( + def remove_host_from_client(self, peer_id: str) -> None: + """Remove a host from the rathole client toml file.""" + + rathole_config_map = KubeUtils.get_configmap( + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + ) + + if rathole_config_map is None: + raise Exception("Rathole config map not found.") + + client_filename = RatholeClientToml.filename + + toml_str = rathole_config_map.data[client_filename] + + rathole_toml = RatholeClientToml(toml_str=toml_str) + + rathole_toml.remove_config(peer_id) + + rathole_toml.clear_remote_addr() + + data = {client_filename: rathole_toml.toml_str} + + # Update the rathole config map + KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) + + def _add_dynamic_addr_to_rathole( self, config: RatholeConfig, entrypoint: str = "web" ) -> None: """Add a port to the rathole proxy config map.""" @@ -166,9 +225,39 @@ def add_dynamic_addr_to_rathole( patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, ) - self.expose_port_on_rathole_service(config.server_name, config.local_addr_port) + self._expose_port_on_rathole_service(config.server_name, config.local_addr_port) + + def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: + """Remove a port from the rathole proxy config map.""" + + rathole_proxy_config_map = KubeUtils.get_configmap( + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + ) + + if rathole_proxy_config_map is None: + raise Exception("Rathole proxy config map not found.") + + rathole_proxy = rathole_proxy_config_map.data["rathole-dynamic.yml"] + + if not rathole_proxy: + return + + rathole_proxy = yaml.safe_load(rathole_proxy) + + if server_name in rathole_proxy["http"]["routers"]: + del rathole_proxy["http"]["routers"][server_name] + + if server_name in rathole_proxy["http"]["services"]: + del rathole_proxy["http"]["services"][server_name] + + KubeUtils.update_configmap( + config_map=rathole_proxy_config_map, + patch={"data": {"rathole-dynamic.yml": yaml.safe_dump(rathole_proxy)}}, + ) + + self._remove_port_on_rathole_service(server_name) - def expose_port_on_rathole_service(self, port_name: str, port: int) -> None: + def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: """Expose a port on the rathole service.""" rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") @@ -199,7 +288,7 @@ def expose_port_on_rathole_service(self, port_name: str, port: int) -> None: rathole_service.patch(config) - def remove_port_on_rathole_service(self, port_name: str) -> None: + def 
_remove_port_on_rathole_service(self, port_name: str) -> None: """Remove a port from the rathole service.""" rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") diff --git a/packages/syft/src/syft/service/network/rathole_toml.py b/packages/syft/src/syft/service/network/rathole_toml.py index e5fe17b59e9..8ded821279e 100644 --- a/packages/syft/src/syft/service/network/rathole_toml.py +++ b/packages/syft/src/syft/service/network/rathole_toml.py @@ -53,6 +53,19 @@ def set_remote_addr(self, remote_host: str) -> None: self.save(toml) + def clear_remote_addr(self) -> None: + """Clear the remote address from the client toml file.""" + + toml = self.read() + + # Clear the remote address + if "client" not in toml: + return + + toml["client"]["remote_addr"] = "" + + self.save(toml) + def add_config(self, config: RatholeConfig) -> None: """Add a new config to the toml file.""" diff --git a/packages/syft/src/syft/service/network/reverse_tunnel_service.py b/packages/syft/src/syft/service/network/reverse_tunnel_service.py new file mode 100644 index 00000000000..99783649a44 --- /dev/null +++ b/packages/syft/src/syft/service/network/reverse_tunnel_service.py @@ -0,0 +1,40 @@ +# relative +from ...types.grid_url import GridURL +from .node_peer import NodePeer +from .rathole_config_builder import RatholeConfigBuilder +from .routes import NodeRoute + + +class ReverseTunnelService: + def __init__(self) -> None: + self.builder = RatholeConfigBuilder() + + def set_client_config( + self, + self_node_peer: NodePeer, + remote_node_route: NodeRoute, + ) -> None: + rathole_route = self_node_peer.get_rathole_route() + if not rathole_route: + raise Exception( + "Failed to exchange routes via . " + + f"Peer: {self_node_peer} has no rathole route: {rathole_route}" + ) + + remote_url = GridURL( + host_or_ip=remote_node_route.host_or_ip, port=remote_node_route.port + ) + rathole_remote_addr = remote_url.as_container_host() + + remote_addr = rathole_remote_addr.url_no_protocol + + self.builder.add_host_to_client( + peer_name=self_node_peer.name, + peer_id=str(self_node_peer.id), + rathole_token=rathole_route.rathole_token, + remote_addr=remote_addr, + ) + + def set_server_config(self, remote_peer: NodePeer) -> None: + rathole_route = remote_peer.get_rathole_route() + self.builder.add_host_to_server(remote_peer) if rathole_route else None From 2be7e689d45da98f3169274d100029862d961720 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 26 Jun 2024 00:34:26 +0530 Subject: [PATCH 154/309] integrate configmap deletion --- .../service/network/association_request.py | 5 ++--- .../syft/service/network/network_service.py | 18 ++++++++++++++++++ .../service/network/reverse_tunnel_service.py | 8 ++++++++ tests/integration/network/gateway_test.py | 4 ++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index 262692e283c..bdbe8dc6cbc 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -66,10 +66,9 @@ def _run( and self.remote_peer.latest_added_route == rathole_route ) - # If the remote peer is added via rathole, we skip ping to peer - # and add the peer to the rathole server + # If the remote peer is added via reverse tunnel, we skip ping to peer if add_rathole_route: - network_service.rathole_service.add_host_to_server(self.remote_peer) + 
network_service.rtunnel_service.set_server_config(self.remote_peer) else: # Pinging the remote peer to verify the connection try: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 9419e479348..b442f3b83b4 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -512,6 +512,24 @@ def delete_peer_by_id( self, context: AuthedServiceContext, uid: UID ) -> SyftSuccess | SyftError: """Delete Node Peer""" + retrieve_result = self.stash.get_by_uid(context.credentials, uid) + if err := retrieve_result.is_err(): + return SyftError( + message=f"Failed to retrieve peer with UID {uid}: {retrieve_result.err()}." + ) + peer_to_delete = cast(NodePeer, retrieve_result.ok()) + + node_side_type = cast(NodeType, context.node.node_type) + if node_side_type.value == NodeType.GATEWAY.value: + rathole_route = peer_to_delete.get_rathole_route() + ( + self.rtunnel_service.clear_server_config(peer_to_delete) + if rathole_route + else None + ) + + # TODO: Handle the case when peer is deleted from domain node + result = self.stash.delete_by_uid(context.credentials, uid) if err := result.is_err(): return SyftError(message=f"Failed to delete peer with UID {uid}: {err}.") diff --git a/packages/syft/src/syft/service/network/reverse_tunnel_service.py b/packages/syft/src/syft/service/network/reverse_tunnel_service.py index 99783649a44..bb80c56f401 100644 --- a/packages/syft/src/syft/service/network/reverse_tunnel_service.py +++ b/packages/syft/src/syft/service/network/reverse_tunnel_service.py @@ -38,3 +38,11 @@ def set_client_config( def set_server_config(self, remote_peer: NodePeer) -> None: rathole_route = remote_peer.get_rathole_route() self.builder.add_host_to_server(remote_peer) if rathole_route else None + + def clear_client_config(self, self_node_peer: NodePeer) -> None: + self.builder.remove_host_from_client(str(self_node_peer.id)) + + def clear_server_config(self, remote_peer: NodePeer) -> None: + self.builder.remove_host_from_server( + str(remote_peer.id), server_name=remote_peer.name + ) diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index 1e09c2a775e..e8f36be2ff2 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -965,3 +965,7 @@ def test_reverse_tunnel_connection(domain_1_port: int, gateway_port: int): assert domain_client.name == proxy_domain_client.name assert not isinstance(proxy_domain_client.datasets.get_all(), SyftError) + + # Try removing existing peers just to make sure + _remove_existing_peers(gateway_client) + _remove_existing_peers(domain_client) From 5f3c7694002f8b2b67f05a15a3f183a71067fcb5 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 26 Jun 2024 08:25:38 +0700 Subject: [PATCH 155/309] - process code on the client side before submitting - add test for when the submit code contains `global` --- .../syft/src/syft/service/code/user_code.py | 18 ++++++++++++++++++ .../syft/tests/syft/users/user_code_test.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b71f5aa4cc6..16fb099df4c 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1188,6 +1188,12 @@ def syft_function( def decorator(f: Any) -> SubmitUserCode: try: code = 
dedent(inspect.getsource(f)) + + res = process_code_client(code) + if isinstance(res, SyftError): + display(res) + return res + if name is not None: fname = name code = replace_func_name(code, fname) @@ -1233,6 +1239,18 @@ def decorator(f: Any) -> SubmitUserCode: return decorator +def process_code_client( + raw_code: str, +): + tree = ast.parse(raw_code) + # check there are no globals + v = GlobalsVisitor() + try: + v.visit(tree) + except Exception as e: + return SyftError(message=f"Failed to process code. {e}") + + def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 333d246e37f..68fbf7922ea 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -366,3 +366,20 @@ def valid_name_2(): valid_name_2.func_name = "get_all" with pytest.raises(ValidationError): client.code.submit(valid_name_2) + + +def test_submit_code_with_global_var(guest_client: DomainClient) -> None: + @sy.syft_function( + input_policy=sy.ExactMatch(), output_policy=sy.SingleExecutionExactOutput() + ) + def mock_syft_func_with_global(): + global x + + def example_function(): + return 1 + x + + return example_function() + + res = guest_client.code.submit(mock_syft_func_with_global) + assert isinstance(res, SyftError) + assert "No Globals allowed!" in res.message From 412efc8e9cc155a7b7bf27f011e6a94ef82ff1c1 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 26 Jun 2024 11:16:39 +0700 Subject: [PATCH 156/309] [syft/user_code] add a `global` keyword check locally for client before submitting code --- packages/syft/src/syft/client/api.py | 5 +++-- packages/syft/src/syft/orchestra.py | 5 ++++- packages/syft/src/syft/service/code/user_code.py | 15 ++++++++------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 9c8b244b129..8ebfac03c0b 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -23,6 +23,7 @@ from pydantic import TypeAdapter from result import OkErr from result import Result +from typeguard import TypeCheckError from typeguard import check_type # relative @@ -1385,7 +1386,7 @@ def validate_callable_args_and_kwargs( break # only need one to match else: check_type(arg, t) # raises Exception - except TypeError: + except TypeCheckError: t_arg = type(arg) if ( autoreload_enabled() @@ -1396,7 +1397,7 @@ def validate_callable_args_and_kwargs( pass else: _type_str = getattr(t, "__name__", str(t)) - msg = f"Arg: {arg} must be {_type_str} not {type(arg).__name__}" + msg = f"Arg is `{arg}`. 
\nIt must be of type `{_type_str}`, not `{type(arg).__name__}`" if msg: return SyftError(message=msg) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 1a08f594aa2..08672657762 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -8,6 +8,7 @@ from enum import Enum import getpass import inspect +import logging import os import sys from typing import Any @@ -24,6 +25,8 @@ from .service.response import SyftError from .util.util import get_random_available_port +logger = logging.getLogger(__name__) + DEFAULT_PORT = 8080 DEFAULT_URL = "http://localhost" @@ -174,7 +177,7 @@ def deploy_to_python( } if dev_mode: - print("Staging Protocol Changes...") + logger.debug("Staging Protocol Changes...") stage_protocol_changes() kwargs = { diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 0c461a8ae84..fac45223721 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1192,14 +1192,14 @@ def syft_function( else: output_policy_type = type(output_policy) - def decorator(f: Any) -> SubmitUserCode: + def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) - res = process_code_client(code) - if isinstance(res, SyftError): - display(res) - return res + global_check = _check_global(code) + if isinstance(global_check, SyftError): + display(global_check) + return global_check if name is not None: fname = name @@ -1246,9 +1246,9 @@ def decorator(f: Any) -> SubmitUserCode: return decorator -def process_code_client( +def _check_global( raw_code: str, -): +) -> None | SyftError: tree = ast.parse(raw_code) # check there are no globals v = GlobalsVisitor() @@ -1256,6 +1256,7 @@ def process_code_client( v.visit(tree) except Exception as e: return SyftError(message=f"Failed to process code. 
{e}") + return None def generate_unique_func_name(context: TransformContext) -> TransformContext: From d4074b00be6611b1f36da4a7afbfc3280368e8e8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 26 Jun 2024 12:20:32 +0530 Subject: [PATCH 157/309] fix protocol version for HttpConnection --- .../src/syft/protocol/protocol_version.json | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 9ff9826a945..14a8214bc61 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -274,9 +274,14 @@ } }, "HTTPConnection": { + "2": { + "version": 2, + "hash": "68409295f8916ceb22a8cf4abf89f5e4bcff0d75dc37e16ede37250ada28df59", + "action": "remove" + }, "3": { "version": 3, - "hash": "54b452bb4ab76691ac1e704b62e7bcec740850fea00805145259b37973ecd0f4", + "hash": "cac31ba98bdcc42c0717555a0918d0c8aef0d2235f892a2d86dceff09930fb88", "action": "add" } }, @@ -354,6 +359,54 @@ "hash": "ba9ebb04cc3e8b3ae3302fd42a67e47261a0a330bae5f189d8f4819cf2804711", "action": "add" } + }, + "PythonConnection": { + "2": { + "version": 2, + "hash": "eb479c671fc112b2acbedb88bc5624dfdc9592856c04c22c66410f6c863e1708", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", + "action": "add" + } + }, + "HTTPNodeRoute": { + "2": { + "version": 2, + "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", + "action": "add" + } + }, + "PythonNodeRoute": { + "2": { + "version": 2, + "hash": "3eca5767ae4a8fbe67744509e58c6d9fb78f38fa0a0f7fcf5960ab4250acc1f0", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", + "action": "add" + } + }, + "SeaweedFSBlobDeposit": { + "3": { + "version": 3, + "hash": "05e61e6328b085b738e5d41c0781d87852d44d218894cb3008f5be46e337f6d8", + "action": "remove" + }, + "4": { + "version": 4, + "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", + "action": "add" + } } } } From a250fb1e2171898a6eaaf38e629017983faa1bea Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 26 Jun 2024 10:52:37 +0200 Subject: [PATCH 158/309] fix ActionObject caching Before this change, reload_cache was not updating syft_action_data_cache, because syft_created_at was not being set. 
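As the message above notes, the cache refresh is gated on the creation timestamp, so from_obj has to stamp syft_created_at for reload_cache to ever write fetched data back. A toy illustration of that interaction, not the real ActionObject API:

from datetime import datetime

class CachedActionData:
    def __init__(self) -> None:
        self.created_at: datetime | None = None
        self.cache: object | None = None

    def reload_cache(self, fetch) -> None:
        # mirrors the reported bug: without a creation timestamp the fetched
        # data never makes it back into the cache
        if self.created_at is not None:
            self.cache = fetch()

obj = CachedActionData()
obj.reload_cache(lambda: "data")   # cache stays None - the bug
obj.created_at = datetime.now()    # what from_obj now sets
obj.reload_cache(lambda: "data")   # cache is refreshed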
--- packages/syft/src/syft/service/action/action_object.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index b8e22138d26..1d294a181be 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1388,6 +1388,7 @@ def from_obj( action_object.syft_blob_storage_entry_id = syft_blob_storage_entry_id action_object.syft_action_data_node_id = data_node_id action_object.syft_resolved = syft_resolved + action_object.syft_created_at = DateTime.now() if id is not None: action_object.id = id From f1745379f1b1e463fe722c9e86eb78a9313c351e Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 26 Jun 2024 15:57:24 +0530 Subject: [PATCH 159/309] fix url path in _make_get --- packages/syft/src/syft/client/client.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index d770ff73280..7d7f5be95d1 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -207,6 +207,9 @@ def session(self) -> Session: def _make_get( self, path: str, params: dict | None = None, stream: bool = False ) -> bytes | Iterable: + if params is None: + return self._make_get_no_params(path, stream=stream) + url = self.url if self.rathole_token: @@ -216,9 +219,6 @@ def _make_get( url = url.with_path(path) - if params is None: - return self._make_get_no_params(path) - url = self.url.with_path(path) response = self.session.get( str(url), headers=self.headers, @@ -239,8 +239,15 @@ def _make_get( @cached(cache=TTLCache(maxsize=128, ttl=300)) def _make_get_no_params(self, path: str, stream: bool = False) -> bytes | Iterable: - print(path) - url = self.url.with_path(path) + url = self.url + + if self.rathole_token: + self.headers = {} if self.headers is None else self.headers + url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) + self.headers["Host"] = self.url.host_or_ip + + url = url.with_path(path) + response = self.session.get( str(url), headers=self.headers, From 2f5554cbf07d5acf498ccdeb83b269433b36c08b Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 26 Jun 2024 16:08:58 +0530 Subject: [PATCH 160/309] deprecate HttpConnectionV2 --- packages/syft/src/syft/client/client.py | 12 ----- .../src/syft/protocol/protocol_version.json | 48 +++++++++++++++++++ 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 7d7f5be95d1..9540a50fb7a 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -51,7 +51,6 @@ from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..types.grid_url import GridURL -from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.uid import UID from ..util.telemetry import instrument @@ -131,17 +130,6 @@ class Routes(Enum): STREAM = f"{API_PATH}/stream" -@serializable(attrs=["proxy_target_uid", "url"]) -class HTTPConnectionV2(NodeConnection): - __canonical_name__ = "HTTPConnection" - __version__ = SYFT_OBJECT_VERSION_2 - - url: GridURL - proxy_target_uid: UID | None = None - routes: type[Routes] = Routes - session_cache: Session | None = None - - @serializable(attrs=["proxy_target_uid", "url", "rathole_token"]) class 
HTTPConnection(NodeConnection): __canonical_name__ = "HTTPConnection" diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 766da3a3326..c08ea39b7b2 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -378,6 +378,54 @@ "hash": "1b9bd1d3d096abab5617c2ff597b4c80751f686d16482a2cff4efd8741b84d53", "action": "add" } + }, + "PythonConnection": { + "2": { + "version": 2, + "hash": "eb479c671fc112b2acbedb88bc5624dfdc9592856c04c22c66410f6c863e1708", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", + "action": "add" + } + }, + "HTTPNodeRoute": { + "2": { + "version": 2, + "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", + "action": "add" + } + }, + "PythonNodeRoute": { + "2": { + "version": 2, + "hash": "3eca5767ae4a8fbe67744509e58c6d9fb78f38fa0a0f7fcf5960ab4250acc1f0", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", + "action": "add" + } + }, + "SeaweedFSBlobDeposit": { + "3": { + "version": 3, + "hash": "05e61e6328b085b738e5d41c0781d87852d44d218894cb3008f5be46e337f6d8", + "action": "remove" + }, + "4": { + "version": 4, + "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", + "action": "add" + } } } } From 5d5ac71403440c8465328cf5547618c7f1eacd3a Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 26 Jun 2024 16:21:57 +0530 Subject: [PATCH 161/309] ignore security on hardcoded binding for rathole config --- .../syft/src/syft/service/network/rathole_config_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index a847b6fcbde..90f499ec237 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -19,6 +19,7 @@ RATHOLE_TOML_CONFIG_MAP = "rathole-config" RATHOLE_PROXY_CONFIG_MAP = "proxy-config-dynamic" PROXY_CONFIG_MAP = "proxy-config" +DEFAULT_LOCAL_ADDR_HOST = "0.0.0.0" # nosec class RatholeConfigBuilder: @@ -46,7 +47,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: config = RatholeConfig( uuid=peer_id.to_string(), secret_token=rathole_route.rathole_token, - local_addr_host="0.0.0.0", + local_addr_host=DEFAULT_LOCAL_ADDR_HOST, local_addr_port=random_port, server_name=peer.name, ) From be17b3bd0e5378165da53efe653c2d2dbb5857cb Mon Sep 17 00:00:00 2001 From: dk Date: Thu, 27 Jun 2024 11:22:00 +0700 Subject: [PATCH 162/309] [syft/user_code] checking global in `syft_function` and `SubmitUserCode.local_call` --- .../syft/src/syft/service/code/user_code.py | 90 ++++++++++++++++--- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index fac45223721..ce82a26121f 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -980,11 +980,12 @@ def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: source = 
dedent(inspect.getsource(self.local_function)) - tree = ast.parse(source) - # check there are no globals - v = GlobalsVisitor() - v.visit(tree) + v: GlobalsVisitor | SyftWarning = _check_global(raw_code=source) + if isinstance(v, SyftWarning): # the code contains "global" keyword + return SyftError( + message=f"Error when running function locally: {v.message}" + ) # filtered_args = [] filtered_kwargs = {} @@ -1196,10 +1197,19 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) + # check that there are no globals global_check = _check_global(code) - if isinstance(global_check, SyftError): + if isinstance(global_check, SyftWarning): display(global_check) - return global_check + # err = SyftError(message=global_check.message) + # display(err) + # return err + + lint_issues = _lint_code(code) + lint_warning_msg = "" + for issue in lint_issues: + lint_warning_msg += f"{issue}\n\t" + display(SyftWarning(message=lint_warning_msg)) if name is not None: fname = name @@ -1246,17 +1256,73 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator -def _check_global( - raw_code: str, -) -> None | SyftError: +def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: tree = ast.parse(raw_code) # check there are no globals v = GlobalsVisitor() try: v.visit(tree) - except Exception as e: - return SyftError(message=f"Failed to process code. {e}") - return None + except Exception: + return SyftWarning( + message="Your code contains (a) global variable(s), which is not allowed" + ) + return v + + +# Define a linter function +def _lint_code(code: str) -> list: + # Parse the code into an AST + tree = ast.parse(code) + + # Initialize a list to collect linting issues + issues = [] + + # Define a visitor class to walk the AST + class CodeVisitor(ast.NodeVisitor): + def __init__(self) -> None: + self.globals: set = set() + self.defined_names: set = set() + self.current_scope_defined_names: set = set() + + def visit_Global(self, node: Any) -> None: + # Collect global variable names + for name in node.names: + self.globals.add(name) + self.generic_visit(node) + + def visit_FunctionDef(self, node: Any) -> None: + # Collect defined function names and handle function scope + self.defined_names.add(node.name) + self.current_scope_defined_names = set() # New scope + self.generic_visit(node) + self.current_scope_defined_names.clear() # Clear scope after visiting + + def visit_Assign(self, node: Any) -> None: + # Collect assigned variable names + for target in node.targets: + if isinstance(target, ast.Name): + self.current_scope_defined_names.add(target.id) + self.generic_visit(node) + + def visit_Name(self, node: Any) -> None: + # Check if variables are used before being defined + if isinstance(node.ctx, ast.Load): + if ( + node.id not in self.current_scope_defined_names + and node.id not in self.defined_names + and node.id not in self.globals + ): + issues.append( + f"Variable '{node.id}' used at line {node.lineno} before being defined." 
+ ) + self.generic_visit(node) + + # Create a visitor instance and visit the AST + visitor = CodeVisitor() + visitor.visit(tree) + + # Return the collected issues + return issues def generate_unique_func_name(context: TransformContext) -> TransformContext: From fbcfe599b288c4149950124d87b2a5579779c907 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 27 Jun 2024 14:09:30 +0530 Subject: [PATCH 163/309] mark reverse tunnel with network marker --- tests/integration/network/gateway_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index e8f36be2ff2..07a5e360f31 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -112,6 +112,10 @@ def test_domain_connect_to_gateway( _remove_existing_peers(domain_client) _remove_existing_peers(gateway_client) + # Disable automatic acceptance of association requests + res = gateway_client.settings.allow_association_request_auto_approval(enable=False) + assert isinstance(res, SyftSuccess) + # connecting the domain to the gateway result = domain_client.connect_to_gateway(gateway_client) assert isinstance(result, Request) @@ -914,6 +918,7 @@ def test_peer_health_check(set_env_var, gateway_port: int, domain_1_port: int) - assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) +@pytest.mark.network def test_reverse_tunnel_connection(domain_1_port: int, gateway_port: int): # login to the domain and gateway From 77fb1c20d9804f9aaf37ab2e935fd2eac08df732 Mon Sep 17 00:00:00 2001 From: dk Date: Thu, 27 Jun 2024 16:01:52 +0700 Subject: [PATCH 164/309] [syft/action_service] merging flags to save action objects to the blob store / clear cache --- .../syft/src/syft/client/domain_client.py | 5 ++-- .../src/syft/service/action/action_object.py | 10 +++---- .../src/syft/service/action/action_service.py | 26 +++++++------------ .../syft/src/syft/service/dataset/dataset.py | 5 ++-- packages/syft/src/syft/types/twin_object.py | 5 ++-- 5 files changed, 20 insertions(+), 31 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 7214d8bf966..eb677c6b2cd 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -148,14 +148,13 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False response = self.api.services.action.set( twin, ignore_detached_objs=contains_empty, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) if isinstance(response, SyftError): tqdm.write(f"Failed to upload asset: {asset.name}") diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index bcdf3190fbc..f737c87fcb0 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -527,13 +527,12 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: print(r.message) if isinstance(r, SyftWarning): logger.debug(r.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, 
False + skip_save_to_blob_store = False arg = api.services.action.set( arg, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) return arg @@ -1251,14 +1250,13 @@ def _send( if isinstance(blob_storage_res, SyftWarning): logger.debug(blob_storage_res.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False res = api.services.action.set( self, add_storage_permission=add_storage_permission, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) if isinstance(res, ActionObject): self.syft_created_at = res.syft_created_at diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 50504c34301..da273e2c12b 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -75,15 +75,14 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: return blob_store_result if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False np_pointer = self._set( context, np_obj, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) return np_pointer @@ -98,7 +97,6 @@ def set( action_object: ActionObject | TwinObject, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_clear_cache: bool = False, skip_save_to_blob_store: bool = False, ) -> ActionObject | SyftError: res = self._set( @@ -107,7 +105,6 @@ def set( has_result_read_permission=True, add_storage_permission=add_storage_permission, ignore_detached_objs=ignore_detached_objs, - skip_clear_cache=skip_clear_cache, skip_save_to_blob_store=skip_save_to_blob_store, ) if res.is_err(): @@ -145,7 +142,6 @@ def _set( has_result_read_permission: bool = False, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_clear_cache: bool = False, skip_save_to_blob_store: bool = False, ) -> Result[ActionObject, str]: if ( @@ -153,19 +149,19 @@ def _set( and not skip_save_to_blob_store ): return Err( - "you uploaded an ActionObject that is not yet in the blob storage" + "You uploaded an ActionObject that is not yet in the blob storage" ) """Save an object to the action store""" # 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable if isinstance(action_object, ActionObject): action_object.syft_created_at = DateTime.now() - if not skip_clear_cache: + if not skip_save_to_blob_store: action_object._clear_cache() - else: + else: # TwinObject action_object.private_obj.syft_created_at = DateTime.now() # type: ignore[unreachable] action_object.mock_obj.syft_created_at = DateTime.now() - if not skip_clear_cache: + if not skip_save_to_blob_store: action_object.private_obj._clear_cache() action_object.mock_obj._clear_cache() @@ -531,9 +527,9 @@ def set_result_to_store( return Err(blob_store_result.message) if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False # IMPORTANT: DO THIS ONLY 
AFTER ._save_to_blob_storage if isinstance(result_action_object, TwinObject): @@ -550,7 +546,6 @@ def set_result_to_store( result_action_object, has_result_read_permission=True, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) if set_result.is_err(): @@ -819,14 +814,13 @@ def execute( } if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False set_result = self._set( context, result_action_object, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) if set_result.is_err(): return Err( diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 19352d26377..2bd7dc33b90 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -713,9 +713,9 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: raise ValueError(res.message) if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False # TODO, upload to blob storage here if context.node is None: raise ValueError( @@ -726,7 +726,6 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: context=context.to_node_context(), action_object=twin, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) if result.is_err(): raise RuntimeError(f"Failed to create and store twin. Error: {result}") diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index b5f7c90e42c..eae86e9cb5b 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -109,13 +109,12 @@ def send(self, client: SyftClient, add_storage_permission: bool = True) -> Any: blob_store_result = self._save_to_blob_storage() if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store, skip_clear_cache = True, True + skip_save_to_blob_store = True else: - skip_save_to_blob_store, skip_clear_cache = False, False + skip_save_to_blob_store = False res = client.api.services.action.set( self, add_storage_permission=add_storage_permission, skip_save_to_blob_store=skip_save_to_blob_store, - skip_clear_cache=skip_clear_cache, ) return res From c9be64ecd921f204d91c921823b629dc23cc56c6 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 27 Jun 2024 14:58:14 +0530 Subject: [PATCH 165/309] remove use of route id for route deletion --- .../syft/service/network/network_service.py | 40 +++++-------------- .../src/syft/service/network/node_peer.py | 15 +------ tests/integration/network/gateway_test.py | 2 +- 3 files changed, 12 insertions(+), 45 deletions(-) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 7deb8775686..4a39a1edf98 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -625,6 +625,8 @@ def add_route( message=f"The route already exists between '{context.node.name}' and " f"peer '{remote_node_peer.name}'." 
) + + remote_node_peer.update_route(route=route) # update the peer in the store with the updated routes peer_update = NodePeerUpdate( id=remote_node_peer.id, node_routes=remote_node_peer.node_routes @@ -646,8 +648,7 @@ def delete_route_on_peer( self, context: AuthedServiceContext, peer: NodePeer, - route: NodeRoute | None = None, - route_id: UID | None = None, + route: NodeRoute, ) -> SyftSuccess | SyftError | SyftInfo: """ Delete the route on the remote peer. @@ -656,7 +657,6 @@ def delete_route_on_peer( context (AuthedServiceContext): The authentication context for the service. peer (NodePeer): The peer for which the route will be deleted. route (NodeRoute): The route to be deleted. - route_id (UID): The UID of the route to be deleted. Returns: SyftSuccess: If the route is successfully deleted. @@ -664,17 +664,6 @@ def delete_route_on_peer( SyftInfo: If there is only one route left for the peer and the admin chose not to remove it """ - if route is None and route_id is None: - return SyftError( - message="Either `route` or `route_id` arg must be provided" - ) - - if route and route_id and route.id != route_id: - return SyftError( - message=f"Both `route` and `route_id` are provided, but " - f"route's id ({route.id}) and route_id ({route_id}) do not match" - ) - # creates a client on the remote node based on the credentials # of the current node's client remote_client = peer.client_with_context(context=context) @@ -688,7 +677,6 @@ def delete_route_on_peer( result = remote_client.api.services.network.delete_route( peer_verify_key=context.credentials, route=route, - route_id=route_id, called_by_peer=True, ) return result @@ -701,7 +689,6 @@ def delete_route( context: AuthedServiceContext, peer_verify_key: SyftVerifyKey, route: NodeRoute | None = None, - route_id: UID | None = None, called_by_peer: bool = False, ) -> SyftSuccess | SyftError | SyftInfo: """ @@ -713,7 +700,6 @@ def delete_route( context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the remote node peer. route (NodeRoute): The route to be deleted. - route_id (UID): The UID of the route to be deleted. called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: @@ -755,20 +741,12 @@ def delete_route( f"'{remote_node_peer.node_routes[0].id}' was not deleted." ) - if route: - result = remote_node_peer.delete_route(route=route) - return_message = ( - f"Route '{str(route)}' with id '{route.id}' to peer " - f"{remote_node_peer.node_type.value} '{remote_node_peer.name}' " - f"was deleted for {str(context.node.node_type)} '{context.node.name}'." - ) - if route_id: - result = remote_node_peer.delete_route(route_id=route_id) - return_message = ( - f"Route with id '{route_id}' to peer " - f"{remote_node_peer.node_type.value} '{remote_node_peer.name}' " - f"was deleted for {str(context.node.node_type)} '{context.node.name}'." - ) + result = remote_node_peer.delete_route(route=route) + return_message = ( + f"Route '{str(route)}' to peer " + f"{remote_node_peer.node_type.value} '{remote_node_peer.name}' " + f"was deleted for {str(context.node.node_type)} '{context.node.name}'." 
+ ) if isinstance(result, SyftError): return result diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 357a8d17967..e3f14481f36 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -291,28 +291,17 @@ def get_rathole_route(self) -> HTTPNodeRoute | None: return route return None - def delete_route( - self, route: NodeRouteType | None = None, route_id: UID | None = None - ) -> SyftError | None: + def delete_route(self, route: NodeRouteType) -> SyftError | None: """ Deletes a route from the peer's route list. Takes O(n) where is n is the number of routes in self.node_routes. Args: route (NodeRouteType): The route to be deleted; - route_id (UID): The id of the route to be deleted; Returns: - SyftError: If deleting failed + SyftError: If failing to delete node route """ - if route_id: - try: - self.node_routes = [r for r in self.node_routes if r.id != route_id] - except Exception as e: - return SyftError( - message=f"Error deleting route with id {route_id}. Exception: {e}" - ) - if route: try: self.node_routes = [r for r in self.node_routes if r != route] diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index 07a5e360f31..ef8b090a8a6 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -618,7 +618,7 @@ def test_delete_route_on_peer( # gateway delete the routes for the domain res = gateway_client.api.services.network.delete_route_on_peer( - peer=domain_peer, route_id=new_route.id + peer=domain_peer, route=new_route ) assert isinstance(res, SyftSuccess) gateway_peer = domain_client.peers[0] From 719a9a5b7456d3de3c4bac034b27fe3f9cc1ba17 Mon Sep 17 00:00:00 2001 From: dk Date: Thu, 27 Jun 2024 17:32:24 +0700 Subject: [PATCH 166/309] [syft/action_obj] refactor `ActionObject._save_to_blob_storage` --- .../src/syft/service/action/action_object.py | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index f737c87fcb0..9f0c7baaf5d 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -793,17 +793,25 @@ def reload_cache(self) -> SyftError | None: return None - def _save_to_blob_storage_(self, data: Any) -> SyftError | None: + def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: # relative from ...types.blob_storage import BlobFile from ...types.blob_storage import CreateBlobStorageEntry + api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + if api is None: + raise ValueError( + f"api is None. You must login to {self.syft_node_location}" + ) + if not can_upload_to_blob_storage(data, api.metadata): + return SyftWarning( + message=f"The action object {self.id} was not saved to " + f"the blob store but to memory cache since it is small." + ) + if not isinstance(data, ActionDataEmpty): if isinstance(data, BlobFile): if not data.uploaded: - api = APIRegistry.api_for( - self.syft_node_location, self.syft_client_verify_key - ) data._upload_to_blobstorage_from_api(api) else: serialized = serialize(data, to_bytes=True) @@ -843,21 +851,10 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: "skipping writing action object to store, passed data was empty." 
) - self.syft_action_data_cache = data + # self.syft_action_data_cache = data return None - def _set_reprs(self, data: any) -> None: - if inspect.isclass(data): - self.syft_action_data_repr_ = truncate_str(repr_cls(data)) - else: - self.syft_action_data_repr_ = truncate_str( - data._repr_markdown_() - if hasattr(data, "_repr_markdown_") - else data.__repr__() - ) - self.syft_action_data_str_ = truncate_str(str(data)) - def _save_to_blob_storage( self, allow_empty: bool = False ) -> SyftError | SyftSuccess | SyftWarning: @@ -869,23 +866,14 @@ def _save_to_blob_storage( message=f"cannot store empty object {self.id} to the blob storage" ) try: - api = APIRegistry.api_for( - node_uid=self.syft_node_location, - user_verify_key=self.syft_client_verify_key, + result = self._save_to_blob_storage_(data) + if isinstance(result, SyftError | SyftWarning): + return result + if not TraceResultRegistry.current_thread_is_tracing(): + self._clear_cache() + return SyftSuccess( + message=f"Saved action object {self.id} to the blob store" ) - if api is None: - raise ValueError( - f"api is None. You must login to {self.syft_node_location}" - ) - if can_upload_to_blob_storage(data, api.metadata): - result = self._save_to_blob_storage_(data) - if isinstance(result, SyftError): - return result - if not TraceResultRegistry.current_thread_is_tracing(): - self._clear_cache() - return SyftSuccess( - message=f"Saved action object {self.id} to the blob store" - ) except Exception as e: print( f"Failed to save action object {self.id} to the blob store. Error: {e}" @@ -900,6 +888,17 @@ def _save_to_blob_storage( def _clear_cache(self) -> None: self.syft_action_data_cache = self.as_empty_data() + def _set_reprs(self, data: any) -> None: + if inspect.isclass(data): + self.syft_action_data_repr_ = truncate_str(repr_cls(data)) + else: + self.syft_action_data_repr_ = truncate_str( + data._repr_markdown_() + if hasattr(data, "_repr_markdown_") + else data.__repr__() + ) + self.syft_action_data_str_ = truncate_str(str(data)) + @property def is_pointer(self) -> bool: return self.syft_node_uid is not None From 7609728a7080a400c7d57b2373e10b139477d7fa Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 27 Jun 2024 16:19:09 +0530 Subject: [PATCH 167/309] refactor reverse tunnel config logic to a single method --- .../service/network/association_request.py | 5 ++- .../syft/service/network/network_service.py | 40 +++++++++++++------ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index bdbe8dc6cbc..a910302005e 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -68,7 +68,10 @@ def _run( # If the remote peer is added via reverse tunnel, we skip ping to peer if add_rathole_route: - network_service.rtunnel_service.set_server_config(self.remote_peer) + network_service.set_reverse_tunnel_config( + context=context, + remote_node_peer=self.remote_peer, + ) else: # Pinging the remote peer to verify the connection try: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4a39a1edf98..4aa48481df2 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -261,9 +261,10 @@ def exchange_credentials_with( # Step 5: Save rathole config to enable 
reverse tunneling if reverse_tunnel and reverse_tunnel_enabled(): - self.rtunnel_service.set_client_config( + self.set_reverse_tunnel_config( + context=context, self_node_peer=self_node_peer, - remote_node_route=remote_node_route, + remote_node_peer=remote_node_peer, ) return ( @@ -484,26 +485,41 @@ def update_peer( peer = result.ok() - node_side_type = cast(NodeType, context.node.node_type) - if node_side_type.value == NodeType.GATEWAY.value: - rathole_route = peer.get_rathole_route() - self.rtunnel_service.set_server_config(peer) if rathole_route else None + self.set_reverse_tunnel_config(context=context, remote_node_peer=peer) + return SyftSuccess( + message=f"Peer '{result.ok().name}' information successfully updated." + ) + + def set_reverse_tunnel_config( + self, + context: AuthedServiceContext, + remote_node_peer: NodePeer, + self_node_peer: NodePeer | None = None, + ) -> None: + node_type = cast(NodeType, context.node.node_type) + if node_type.value == NodeType.GATEWAY.value: + rathole_route = remote_node_peer.get_rathole_route() + ( + self.rtunnel_service.set_server_config(remote_node_peer) + if rathole_route + else None + ) else: - self_node_peer: NodePeer = context.node.settings.to(NodePeer) + self_node_peer = ( + context.node.settings.to(NodePeer) + if self_node_peer is None + else self_node_peer + ) rathole_route = self_node_peer.get_rathole_route() ( self.rtunnel_service.set_client_config( self_node_peer=self_node_peer, - remote_node_route=peer.pick_highest_priority_route(), + remote_node_route=remote_node_peer.pick_highest_priority_route(), ) if rathole_route else None ) - return SyftSuccess( - message=f"Peer '{result.ok().name}' information successfully updated." - ) - @service_method( path="network.delete_peer_by_id", name="delete_peer_by_id", From 72e9edca4a88ab12416a55c052334b034b0d6f35 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 27 Jun 2024 14:40:33 +0200 Subject: [PATCH 168/309] rewrite assets and show in repr --- .../syft/src/syft/service/code/user_code.py | 59 +++++++++++-------- .../syft/src/syft/service/dataset/dataset.py | 3 + .../syft/tests/syft/users/user_code_test.py | 2 +- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ef65d8e8987..26557e12500 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -4,7 +4,6 @@ # stdlib import ast from collections.abc import Callable -from collections.abc import Generator from copy import deepcopy import datetime from enum import Enum @@ -779,31 +778,45 @@ def byte_code(self) -> PyCodeObject | None: return compile_byte_code(self.parsed_code) @property - def assets(self) -> list[Asset]: - # relative - from ...client.api import APIRegistry + def assets(self) -> dict[str, Asset] | SyftError: + if not self.input_policy: + return {} - api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) - if api is None: - return SyftError(message=f"You must login to {self.node_uid}") + api = self._get_api() + if isinstance(api, SyftError): + return api - inputs: Generator = (x for x in range(0)) # create an empty generator - if self.input_policy_init_kwargs is not None: - inputs = ( - uids - for node_identity, uids in self.input_policy_init_kwargs.items() - if node_identity.node_name == api.node_name - ) + # get all assets on the node + datasets = api.services.dataset.get_all() + if isinstance(datasets, SyftError): + return 
datasets + all_assets = { + asset.action_id: asset + for asset in itertools.chain.from_iterable(x.asset_list for x in datasets) + } + + # get a flat dict of all inputs + all_inputs = {} + inputs = self.input_policy.inputs or {} + for vals in inputs.values(): + all_inputs.update(vals) - all_assets = [] - for uid in itertools.chain.from_iterable(x.values() for x in inputs): - if isinstance(uid, UID): - assets = api.services.dataset.get_assets_by_action_id(uid) - if not isinstance(assets, list): - return assets + # map the action_id to the asset + used_assets = {} + for kwarg, action_id in all_inputs.items(): + used_assets[kwarg] = all_assets.get(action_id, None) + return used_assets - all_assets += assets - return all_assets + @property + def _asset_str(self) -> str | SyftError: + assets = self.assets + if isinstance(assets, SyftError): + return assets + asset_str_list = [ + f"{kwarg}={repr(asset)}" for kwarg, asset in self.assets.items() + ] + asset_str = "\n".join(asset_str_list) + return asset_str def get_sync_dependencies( self, context: AuthedServiceContext @@ -894,7 +907,7 @@ def _inner_repr(self, level: int = 0) -> str: {constants_str} {shared_with_line} code: - +{self._asset_str} {self.raw_code} """ if self.nested_codes != {}: diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 9dde84429c4..ac9d0114c08 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -178,6 +178,9 @@ def _repr_html_(self) -> Any: {mock_table_line} """ + def __repr__(self) -> str: + return f"Asset(name='{self.name}', node_uid='{self.node_uid}', action_id='{self.action_id}')" + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str = f"Asset: {self.name}\n" _repr_str += f"Pointer Id: {self.action_id}\n" diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 3e9cb975580..7d6432100e9 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -166,7 +166,7 @@ def func(asset): c for c in request.changes if (isinstance(c, UserCodeStatusChange)) ) - assert status_change.code.assets[0].model_dump( + assert status_change.code.assets["asset"].model_dump( mode="json" ) == asset_input.model_dump(mode="json") From f6b3348010bb1c2ac4ca608ae6b9f27e8e6e71f1 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 27 Jun 2024 15:41:51 +0200 Subject: [PATCH 169/309] data migration for store metadata --- .../0-prepare-migration-data.ipynb | 24 ++- .../1a-connect-and-migrate.ipynb | 2 +- .../1b-connect-and-migrate-via-api.ipynb | 169 ++++++++++++++-- .../2-post-migration-tests.ipynb | 2 +- .../syft/src/syft/client/domain_client.py | 25 ++- packages/syft/src/syft/node/node.py | 15 +- .../src/syft/protocol/protocol_version.json | 7 + .../syft/service/action/action_permissions.py | 14 ++ .../src/syft/service/action/action_store.py | 6 + .../service/migration/migration_service.py | 185 ++++++++++++++++-- .../migration/object_migration_state.py | 12 ++ .../src/syft/service/settings/settings.py | 8 +- .../syft/src/syft/store/document_store.py | 14 +- .../syft/src/syft/store/kv_document_store.py | 18 +- .../src/syft/store/mongo_document_store.py | 38 +++- .../tests/syft/stores/permissions_test.py | 26 +++ .../tests/syft/stores/store_mocks_test.py | 3 +- 17 files changed, 500 insertions(+), 68 deletions(-) create mode 100644 
packages/syft/tests/syft/stores/permissions_test.py diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 9c24c78e53f..94e9a9f5bee 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -42,7 +42,12 @@ "source": [ "# this notebook should only be used to run the latest deployed version of syft\n", "# the notebooks after this (1a/1b and 2), will test migrating from that latest version\n", - "assert latest_deployed_version == sy.__version__" + "print(\n", + " f\"latest deployed version: {latest_deployed_version}, installed version: {sy.__version__}\"\n", + ")\n", + "assert (\n", + " latest_deployed_version == sy.__version__\n", + "), f\"{latest_deployed_version} does not match installed version {sy.__version__}\"" ] }, { @@ -216,14 +221,25 @@ "id": "18", "metadata": {}, "outputs": [], + "source": [ + "if node.node_type.value == \"python\":\n", + " node.land()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:syft086] *", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-env-syft086-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -235,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb b/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb index aca3e671f5c..fed82bb72bd 100644 --- a/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb +++ b/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb @@ -80,7 +80,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb index 50f34b603f0..b0d08cbc21b 100644 --- a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb +++ b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb @@ -15,9 +15,19 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "1", "metadata": {}, + "outputs": [], + "source": [ + "print(f\"syft version: {sy.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, "source": [ "TODOS\n", "- [x] action objects\n", @@ -31,7 +41,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -48,7 +58,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -57,7 +67,7 @@ }, { "cell_type": "markdown", - "id": "4", + "id": "5", "metadata": {}, "source": [ "# Client side migrations" @@ -65,7 +75,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ "## document store objects" @@ -74,17 +84,27 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ - "migration_dict = client.services.migration.get_migration_objects()" + "migration_dict = 
client.services.migration.get_migration_objects(get_all=True)" ] }, { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "migration_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -96,7 +116,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -117,7 +137,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -127,7 +147,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -137,7 +157,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -146,7 +166,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "14", "metadata": {}, "source": [ "## Actions and ActionObjects" @@ -155,7 +175,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +185,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -183,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +213,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -203,17 +223,126 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "19", "metadata": {}, "outputs": [], "source": [ "assert isinstance(res, sy.SyftSuccess)" ] }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "## Store metadata\n", + "\n", + "- Permissions\n", + "- StoragePermissions" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "store_metadata = client.services.migration.get_all_store_metadata()\n", + "store_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in store_metadata.items():\n", + " if len(v.permissions):\n", + " print(\n", + " k, len(v.permissions), len(v.permissions) == len(migration_dict.get(k, []))\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "# Test update method with a temp node\n", + "# After update, all metadata should match between the nodes\n", + "\n", + "temp_node = sy.orchestra.launch(\n", + " name=\"temp_node\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + " reset=True,\n", + ")\n", + "\n", + "temp_client = temp_node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "temp_client.services.migration.update_store_metadata(store_metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "for cname, real_partition in node.python_node.document_store.partitions.items():\n", + " temp_partition = temp_node.python_node.document_store.partitions[cname]\n", + "\n", + " temp_perms = 
dict(temp_partition.permissions.items())\n", + " real_perms = dict(real_partition.permissions.items())\n", + "\n", + " # Only look at migrated items\n", + " temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", + " assert temp_perms == real_perms\n", + "\n", + " temp_storage = dict(temp_partition.storage_permissions.items())\n", + " real_storage = dict(real_partition.storage_permissions.items())\n", + " temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", + "\n", + " assert temp_storage == real_storage\n", + "\n", + "# Action store\n", + "real_partition = node.python_node.action_store\n", + "temp_partition = temp_node.python_node.action_store\n", + "temp_perms = dict(temp_partition.permissions.items())\n", + "real_perms = dict(real_partition.permissions.items())\n", + "\n", + "# Only look at migrated items\n", + "temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", + "assert temp_perms == real_perms\n", + "\n", + "temp_storage = dict(temp_partition.storage_permissions.items())\n", + "real_storage = dict(real_partition.storage_permissions.items())\n", + "temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", + "\n", + "assert temp_storage == real_storage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", "metadata": {}, "outputs": [], "source": [] @@ -235,7 +364,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb index ef89ccbeb16..f0d2c3546bb 100644 --- a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb +++ b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb @@ -185,7 +185,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 489455f0663..4ae9c6010e5 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -5,18 +5,19 @@ from pathlib import Path import re from string import Template -from typing import TYPE_CHECKING, Any, Dict, List +from typing import Any +from typing import TYPE_CHECKING from typing import cast # third party from loguru import logger import markdown -from syft.serde import deserialize, serialize -from syft.types.syft_object import Context from tqdm import tqdm # relative from ..abstract_node import NodeSideType +from ..serde import deserialize +from ..serde import serialize from ..serde.serializable import serializable from ..service.action.action_object import ActionObject from ..service.code_history.code_history import CodeHistoriesDict @@ -31,6 +32,7 @@ from ..service.user.roles import Roles from ..service.user.user import UserView from ..types.blob_storage import BlobFile +from ..types.syft_object import Context from ..types.uid import UID from ..util.misc_objs import HTMLObject from ..util.util import get_mb_size @@ -386,19 +388,22 @@ def code_status(self) -> APIModule | None: @property def output(self) -> APIModule | None: return self._get_service_by_name_if_exists("output") - - def save_migration_objects_to_file(self, filename: str, get_all: bool = False) -> Dict[Any, Any] | SyftError: - migration_dict = 
self.api.services.migration.get_migration_objects(get_all=get_all) + + def save_migration_objects_to_file( + self, filename: str, get_all: bool = False + ) -> dict[Any, Any] | SyftError: + migration_dict = self.api.services.migration.get_migration_objects( + get_all=get_all + ) if isinstance(migration_dict, SyftError): return migration_dict ser_bytes = serialize(migration_dict, to_bytes=True) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(ser_bytes) return migration_dict - - + def migrate_objects_from_file(self, filename: str) -> SyftSuccess | SyftError: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: ser_bytes = f.read() migration_dict = deserialize(ser_bytes, from_bytes=True) context = Context() diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index f06a5abb325..0b09614929c 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -120,7 +120,8 @@ from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig -from ..types.syft_object import SYFT_OBJECT_VERSION_2, Context +from ..types.syft_object import Context +from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject from ..types.uid import UID from ..util.experimental_flags import flags @@ -1593,8 +1594,12 @@ def create_initial_settings(self, admin_email: str) -> NodeSettings | None: node_settings = settings_exists[0] if node_settings.__version__ != NodeSettings.__version__: context = Context() - node_settings = node_settings.migrate_to(NodeSettings.__version__, context) - res = settings_stash.delete_by_uid(self.signing_key.verify_key, node_settings.id) + node_settings = node_settings.migrate_to( + NodeSettings.__version__, context + ) + res = settings_stash.delete_by_uid( + self.signing_key.verify_key, node_settings.id + ) if res.is_err(): raise Exception(res.value) res = settings_stash.set(self.signing_key.verify_key, node_settings) @@ -1785,7 +1790,9 @@ def create_default_worker_pool(node: Node) -> SyftError | None: default_worker_pool.worker_list ) if worker_to_add_ > 0: - add_worker_method = node.get_service_method(SyftWorkerPoolService.add_workers) + add_worker_method = node.get_service_method( + SyftWorkerPoolService.add_workers + ) result = add_worker_method( context=context, number=worker_to_add_, diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 51c9e362290..69847276a9c 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -200,6 +200,13 @@ "hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b", "action": "remove" } + }, + "StoreMetadata": { + "1": { + "version": 1, + "hash": "bb9edb077f0214c5867d5349aa99eb584d133bd5f2cc5c824986c9174c0dbbc9", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 2fda8bee2ef..6131dcf5d08 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -41,6 +41,20 @@ def __init__( self.credentials = credentials self.permission = permission + @classmethod + def from_permission_string( + cls, uid: UID, permission_string: str + ) -> 
"ActionObjectPermission": + if permission_string.startswith("ALL_"): + permission = ActionPermission[permission_string] + verify_key = None + else: + verify_key_str, perm_str = permission_string.split("_", 1) + permission = ActionPermission[perm_str] + verify_key = SyftVerifyKey.from_string(verify_key_str) + + return cls(uid=uid, permission=permission, credentials=verify_key) + @property def permission_string(self) -> str: if self.permission in COMPOUND_ACTION_PERMISSION: diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index fc5ae3c1958..bd9336c5a71 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -274,6 +274,9 @@ def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]: return Ok(self.permissions[uid]) return Err(f"No permissions found for uid: {uid}") + def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: + return Ok(dict(self.permissions.items())) + def add_storage_permission(self, permission: StoragePermission) -> None: permissions = self.storage_permissions[permission.uid] permissions.add(permission.node_uid) @@ -301,6 +304,9 @@ def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: return Ok(self.storage_permissions[uid]) return Err(f"No storage permissions found for uid: {uid}") + def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: + return Ok(dict(self.storage_permissions.items())) + def _all( self, credentials: SyftVerifyKey, diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 07e94564196..f3678d34d3b 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -4,8 +4,11 @@ # stdlib -# third party +# stdlib from collections import defaultdict +from typing import cast + +# third party from result import Err from result import Ok from result import Result @@ -13,15 +16,20 @@ # relative from ...serde.serializable import serializable from ...store.document_store import DocumentStore +from ...store.document_store import StorePartition from ...types.syft_object import SyftObject from ..action.action_object import Action from ..action.action_object import ActionObject +from ..action.action_permissions import ActionObjectPermission +from ..action.action_permissions import StoragePermission +from ..action.action_store import KeyValueActionStore from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL +from .object_migration_state import StoreMetadata from .object_migration_state import SyftMigrationStateStash from .object_migration_state import SyftObjectMigrationState @@ -115,6 +123,145 @@ def _find_klasses_pending_for_migration( return klasses_to_be_migrated + @service_method( + path="migration.get_all_store_metadata", + name="get_all_store_metadata", + roles=ADMIN_ROLE_LEVEL, + ) + def get_all_store_metadata( + self, + context: AuthedServiceContext, + document_store_object_types: list[type[SyftObject]] | None = None, + include_action_store: bool = True, + ) -> dict[str, StoreMetadata] | SyftError: + res = self._get_all_store_metadata( + context, + document_store_object_types=document_store_object_types, + 
include_action_store=include_action_store, + ) + if res.is_err(): + return SyftError(message=res.value) + else: + return res.ok() + + def _get_partition_from_type( + self, + context: AuthedServiceContext, + object_type: type[SyftObject], + ) -> Result[KeyValueActionStore | StorePartition, str]: + object_partition: KeyValueActionStore | StorePartition | None = None + if issubclass(object_type, ActionObject): + object_partition = cast(KeyValueActionStore, context.node.action_store) + else: + canonical_name = object_type.__canonical_name__ + object_partition = self.store.partitions.get(canonical_name) + + if object_partition is None: + return Err(f"Object partition not found for {object_type}") # type: ignore + + return Ok(object_partition) + + def _get_store_metadata( + self, + context: AuthedServiceContext, + object_type: type[SyftObject], + ) -> Result[StoreMetadata, str]: + object_partition = self._get_partition_from_type(context, object_type) + if object_partition.is_err(): + return object_partition + object_partition = object_partition.ok() + + permissions = object_partition.get_all_permissions() + + if permissions.is_err(): + return permissions + permissions = permissions.ok() + + storage_permissions = object_partition.get_all_storage_permissions() + if storage_permissions.is_err(): + return storage_permissions + storage_permissions = storage_permissions.ok() + + return Ok( + StoreMetadata( + object_type=object_type, + permissions=permissions, + storage_permissions=storage_permissions, + ) + ) + + def _get_all_store_metadata( + self, + context: AuthedServiceContext, + document_store_object_types: list[type[SyftObject]] | None = None, + include_action_store: bool = True, + ) -> Result[dict[str, list[str]], str]: + if document_store_object_types is None: + document_store_object_types = self.store.get_partition_object_types() + + store_metadata = {} + for klass in document_store_object_types: + result = self._get_store_metadata(context, klass) + if result.is_err(): + return result + store_metadata[klass] = result.ok() + + if include_action_store: + result = self._get_store_metadata(context, ActionObject) + if result.is_err(): + return result + store_metadata[ActionObject] = result.ok() + + return Ok(store_metadata) + + @service_method( + path="migration.update_store_metadata", + name="update_store_metadata", + roles=ADMIN_ROLE_LEVEL, + ) + def update_store_metadata( + self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] + ) -> SyftSuccess | SyftError: + res = self._update_store_metadata(context, store_metadata) + if res.is_err(): + return SyftError(message=res.value) + else: + return SyftSuccess(message=res.ok()) + + def _update_store_metadata_for_klass( + self, context: AuthedServiceContext, metadata: StoreMetadata + ) -> Result[str, str]: + object_partition = self._get_partition_from_type(context, metadata.object_type) + if object_partition.is_err(): + return object_partition + object_partition = object_partition.ok() + + permissions = [ + ActionObjectPermission.from_permission_string(uid, perm_str) + for uid, perm_strs in metadata.permissions.items() + for perm_str in perm_strs + ] + + storage_permissions = [ + StoragePermission(uid, node_uid) + for uid, node_uids in metadata.storage_permissions.items() + for node_uid in node_uids + ] + + object_partition.add_permissions(permissions) + object_partition.add_storage_permissions(storage_permissions) + + return Ok("success") + + def _update_store_metadata( + self, context: AuthedServiceContext, store_metadata: 
dict[type, StoreMetadata] + ) -> Result[str, str]: + for metadata in store_metadata.values(): + result = self._update_store_metadata_for_klass(context, metadata) + if result.is_err(): + return result + return Ok("success") + @service_method( path="migration.get_migration_objects", name="get_migration_objects", @@ -137,12 +284,9 @@ def _get_migration_objects( context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, get_all: bool = False, - ) -> Result[dict, str]: + ) -> Result[dict[type[SyftObject], list[SyftObject]], str]: if document_store_object_types is None: - document_store_object_types = [ - partition.settings.object_type - for partition in self.store.partitions.values() - ] + document_store_object_types = self.store.get_partition_object_types() if get_all: klasses_to_migrate = document_store_object_types @@ -171,9 +315,14 @@ def _get_migration_objects( objects = objects_result.ok() for object in objects: actual_klass = type(object) - use_klass = klass if actual_klass.__canonical_name__ == klass.__canonical_name__ else actual_klass + use_klass = ( + klass + if actual_klass.__canonical_name__ == klass.__canonical_name__ + else actual_klass + ) result[use_klass].append(object) - return Ok(result) + + return Ok(dict(result)) @service_method( path="migration.update_migrated_objects", @@ -196,21 +345,25 @@ def _update_migrated_objects( klass = type(migrated_object) mro = klass.__mro__ class_index = 0 + object_partition = None while len(mro) > class_index: canonical_name = mro[class_index].__canonical_name__ object_partition = self.store.partitions.get(canonical_name) if object_partition is not None: break class_index += 1 - + if object_partition is None: + return Err(f"Object partition not found for {klass}") + # canonical_name = mro[class_index].__canonical_name__ # object_partition = self.store.partitions.get(canonical_name) - + # print(klass, canonical_name, object_partition) qk = object_partition.settings.store_key.with_obj(migrated_object.id) # print(migrated_object) + # stdlib import sys - + result = object_partition._update( context.credentials, qk=qk, @@ -219,7 +372,7 @@ def _update_migrated_objects( overwrite=True, allow_missing_keys=True, ) - + if result.is_err(): print("ERR:", result.value, file=sys.stderr) print("ERR:", klass, file=sys.stderr) @@ -326,7 +479,9 @@ def migrate_data( name="get_migration_actionobjects", roles=ADMIN_ROLE_LEVEL, ) - def get_migration_actionobjects(self, context: AuthedServiceContext): + def get_migration_actionobjects( + self, context: AuthedServiceContext + ) -> dict | SyftError: res = self._get_migration_actionobjects(context) if res.is_ok(): return res.ok() @@ -343,7 +498,9 @@ def _get_migration_actionobjects( action_object_pending_migration = self._find_klasses_pending_for_migration( context=context, object_types=action_object_types ) - result_dict = {x: [] for x in action_object_pending_migration} + result_dict: dict[type[SyftObject], SyftObject] = { + x: [] for x in action_object_pending_migration + } action_store = context.node.action_store action_store_objects_result = action_store._all( context.credentials, has_permission=True diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 77ea6f60f51..1f213fa2b5d 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -10,9 +10,11 @@ from ...store.document_store 
import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings +from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry +from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission @@ -80,3 +82,13 @@ def get_by_name( ) -> Result[SyftObjectMigrationState, str]: qks = KlassNamePartitionKey.with_obj(canonical_name) return self.query_one(credentials=credentials, qks=qks) + + +@serializable() +class StoreMetadata(SyftObject): + __canonical_name__ = "StoreMetadata" + __version__ = SYFT_OBJECT_VERSION_1 + + object_type: type + permissions: dict[UID, set[str]] + storage_permissions: dict[UID, set[UID]] diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index 6fa54fc8f67..b2fdee62d74 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -2,8 +2,6 @@ from collections.abc import Callable from typing import Any -from syft.util.util import get_env - # relative from ...abstract_node import NodeSideType from ...abstract_node import NodeType @@ -24,6 +22,7 @@ from ...util.misc_objs import HTMLObject from ...util.misc_objs import MarkdownDescription from ...util.schema import DEFAULT_WELCOME_MSG +from ...util.util import get_env @serializable() @@ -193,7 +192,10 @@ class NodeSettingsV2(SyftObject): def upgrade_node_settings() -> list[Callable]: return [ make_set_default("association_request_auto_approval", False), - make_set_default("default_worker_pool", get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME)) + make_set_default( + "default_worker_pool", + get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME), + ), ] diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 43e3a7d7e29..c0780b003e0 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -208,7 +208,7 @@ def all(self) -> tuple[QueryKey, ...] 
| list[QueryKey]: def from_obj(partition_keys: PartitionKeys, obj: SyftObject) -> QueryKeys: qks = [] for partition_key in partition_keys.all: - pk_key = partition_key.key # name of the attribute + pk_key = partition_key.key # name of the attribute pk_type = partition_key.type_ pk_value = getattr(obj, pk_key) # object has a method for getting these types @@ -489,6 +489,7 @@ def _update( obj: SyftObject, has_permission: bool = False, overwrite: bool = False, + allow_missing_keys: bool = False, ) -> Result[SyftObject, str]: raise NotImplementedError @@ -525,6 +526,9 @@ def remove_permission(self, permission: ActionObjectPermission) -> None: def has_permission(self, permission: ActionObjectPermission) -> bool: raise NotImplementedError + def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: + raise NotImplementedError + def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]: raise NotImplementedError @@ -543,6 +547,9 @@ def has_storage_permission(self, permission: StoragePermission | UID) -> bool: def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: raise NotImplementedError + def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: + raise NotImplementedError + def _migrate_data( self, to_klass: SyftObject, @@ -588,6 +595,11 @@ def partition(self, settings: PartitionSettings) -> StorePartition: ) return self.partitions[settings.name] + def get_partition_object_types(self) -> list[type]: + return [ + partition.settings.object_type for partition in self.partitions.values() + ] + @instrument class BaseStash: diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 7c921f26eb6..dec9d5d5db3 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -308,6 +308,9 @@ def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], Err]: return Ok(self.permissions[uid]) return Err(f"No permissions found for uid: {uid}") + def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: + return Ok(dict(self.permissions.items())) + def add_storage_permission(self, permission: StoragePermission) -> None: permissions = self.storage_permissions[permission.uid] permissions.add(permission.node_uid) @@ -330,6 +333,14 @@ def has_storage_permission(self, permission: StoragePermission | UID) -> bool: return permission.node_uid in self.storage_permissions[permission.uid] return False + def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], Err]: + if uid in self.storage_permissions: + return Ok(self.storage_permissions[uid]) + return Err(f"No storage permissions found for uid: {uid}") + + def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: + return Ok(dict(self.storage_permissions.items())) + def _all( self, credentials: SyftVerifyKey, @@ -343,11 +354,6 @@ def _all( result = sorted(result, key=lambda x: getattr(x, order_by.key, "")) return Ok(result) - def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], Err]: - if uid in self.storage_permissions: - return Ok(self.storage_permissions[uid]) - return Err(f"No storage permissions found for uid: {uid}") - def _remove_keys( self, store_key: QueryKey, @@ -417,7 +423,7 @@ def _update( obj: SyftObject, has_permission: bool = False, overwrite: bool = False, - allow_missing_keys=False, + allow_missing_keys: bool = False, ) -> Result[SyftObject, str]: try: if qk.value not in self.data: diff --git 
a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index 65640052133..c17f79b408b 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -336,7 +336,7 @@ def _update( obj: SyftObject, has_permission: bool = False, overwrite: bool = False, - allow_missing_keys=False, + allow_missing_keys: bool = False, ) -> Result[SyftObject, str]: collection_status = self.collection if collection_status.is_err(): @@ -502,7 +502,7 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: return False - def _get_permissions_for_uid(self, uid: UID) -> Result[Set[str], Err]: # noqa: UP006 + def _get_permissions_for_uid(self, uid: UID) -> Result[Set[str], str]: # noqa: UP006 collection_permissions_status = self.permissions if collection_permissions_status.is_err(): return collection_permissions_status @@ -515,6 +515,20 @@ def _get_permissions_for_uid(self, uid: UID) -> Result[Set[str], Err]: # noqa: return Ok(set(permissions["permissions"])) + def get_all_permissions(self) -> Result[dict[UID, Set[str]], str]: # noqa: UP006 + # Returns a dictionary of all permissions {object_uid: {*permissions}} + collection_permissions_status = self.permissions + if collection_permissions_status.is_err(): + return collection_permissions_status + collection_permissions: MongoCollection = collection_permissions_status.ok() + + permissions = collection_permissions.find({}) + permissions_dict = {} + for permission in permissions: + permissions_dict[permission["_id"]] = permission["permissions"] + + return Ok(permissions_dict) + def add_permission(self, permission: ActionObjectPermission) -> Result[None, Err]: collection_permissions_status = self.permissions if collection_permissions_status.is_err(): @@ -648,7 +662,7 @@ def remove_storage_permission( f"the node_uid {storage_permission.node_uid} does not exist in the storage permission!" 
) - def _get_storage_permissions_for_uid(self, uid: UID) -> Result[Set[UID], Err]: # noqa: UP006 + def _get_storage_permissions_for_uid(self, uid: UID) -> Result[Set[UID], str]: # noqa: UP006 storage_permissions_or_err = self.storage_permissions if storage_permissions_or_err.is_err(): return storage_permissions_or_err @@ -665,6 +679,24 @@ def _get_storage_permissions_for_uid(self, uid: UID) -> Result[Set[UID], Err]: return Ok(set(storage_permissions["node_uids"])) + def get_all_storage_permissions(self) -> Result[dict[UID, Set[UID]], str]: # noqa: UP006 + # Returns a dictionary of all storage permissions {object_uid: {*node_uids}} + storage_permissions_or_err = self.storage_permissions + if storage_permissions_or_err.is_err(): + return storage_permissions_or_err + storage_permissions_collection: MongoCollection = ( + storage_permissions_or_err.ok() + ) + + storage_permissions = storage_permissions_collection.find({}) + storage_permissions_dict = {} + for storage_permission in storage_permissions: + storage_permissions_dict[storage_permission["_id"]] = storage_permission[ + "node_uids" + ] + + return Ok(storage_permissions_dict) + def take_ownership( self, uid: UID, credentials: SyftVerifyKey ) -> Result[SyftSuccess, str]: diff --git a/packages/syft/tests/syft/stores/permissions_test.py b/packages/syft/tests/syft/stores/permissions_test.py new file mode 100644 index 00000000000..cd5ccd0e9e6 --- /dev/null +++ b/packages/syft/tests/syft/stores/permissions_test.py @@ -0,0 +1,26 @@ +# stdlib +import secrets + +# syft absolute +from syft.service.action.action_permissions import ActionObjectPermission +from syft.service.action.action_permissions import ActionPermission +from syft.service.action.action_permissions import COMPOUND_ACTION_PERMISSION +from syft.service.action.action_permissions import SyftVerifyKey +from syft.service.action.action_permissions import UID + + +def test_permission_string_round_trip(): + for permission in ActionPermission: + uid = UID() + if permission in COMPOUND_ACTION_PERMISSION: + verify_key = None + else: + verify_key = SyftVerifyKey.from_string(secrets.token_hex(32)) + + original_obj = ActionObjectPermission(uid, permission, verify_key) + perm_string = original_obj.permission_string + recreated_obj = ActionObjectPermission.from_permission_string(uid, perm_string) + + assert original_obj.permission == recreated_obj.permission + assert original_obj.uid == recreated_obj.uid + assert original_obj.credentials == recreated_obj.credentials diff --git a/packages/syft/tests/syft/stores/store_mocks_test.py b/packages/syft/tests/syft/stores/store_mocks_test.py index 74c3a6d2f6d..9ab6a71cb57 100644 --- a/packages/syft/tests/syft/stores/store_mocks_test.py +++ b/packages/syft/tests/syft/stores/store_mocks_test.py @@ -7,7 +7,8 @@ from syft.store.document_store import PartitionSettings from syft.store.document_store import StoreConfig from syft.store.kv_document_store import KeyValueBackingStore -from syft.types.syft_object import SYFT_OBJECT_VERSION_2, SyftObject +from syft.types.syft_object import SYFT_OBJECT_VERSION_2 +from syft.types.syft_object import SyftObject from syft.types.uid import UID From 0848f05c4db5b904eb107732c71f196b77a49d14 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Thu, 27 Jun 2024 17:18:40 -0400 Subject: [PATCH 170/309] add missing migrations --- packages/syft/src/syft/service/code/user_code.py | 12 ++++++++++++ .../syft/service/code_history/code_history.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) diff --git 
a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b8df87fc619..d2a27abc861 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1933,3 +1933,15 @@ def migrate_usercode_v5_to_v4() -> list[Callable]: drop("origin_node_side_type"), drop("l0_deny_reason"), ] + + +@migrate(UserCodeV5, UserCode) +def migrate_usercode_v5_to_v6() -> list[Callable]: + return [ + drop("enclave_metadata"), + ] + + +@migrate(UserCode, UserCodeV5) +def migrate_usercode_v6_to_v5() -> list[Callable]: + return [make_set_default("enclave_metadata", None)] diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index e013ef22c34..d1dde068581 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable import json from typing import Any @@ -7,10 +8,13 @@ from ...client.enclave_client import EnclaveMetadata from ...serde.serializable import serializable from ...service.user.user_roles import ServiceRole +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.syft_object import SyftVerifyKey +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID from ...util.notebook_ui.components.tabulator_template import ( build_tabulator_table_with_data, @@ -171,3 +175,15 @@ def _repr_html_(self) -> str | None: "icon": None, } return build_tabulator_table_with_data(rows, metadata) + + +@migrate(CodeHistoryV2, CodeHistory) +def code_history_v2_to_v3() -> list[Callable]: + return [drop("enclave_metadata")] + + +@migrate(CodeHistory, CodeHistoryV2) +def code_history_v3_to_v2() -> list[Callable]: + return [ + make_set_default("enclave_metadata", None), + ] From c72606008fb7c43b86635df7d2dc3a62b8a8d97f Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Fri, 28 Jun 2024 04:15:56 +0530 Subject: [PATCH 171/309] Add hot reloading capability using uvicorn app factory pattern --- packages/syft/src/syft/node/server.py | 184 +++++++++++++------------- 1 file changed, 95 insertions(+), 89 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 43b8359a1f9..0b46727f0f6 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -1,9 +1,11 @@ # stdlib -import asyncio +import base64 from collections.abc import Callable from enum import Enum +import json import multiprocessing import os +from pathlib import Path import platform import signal import subprocess # nosec @@ -14,7 +16,6 @@ from fastapi import FastAPI import requests from starlette.middleware.cors import CORSMiddleware -import uvicorn # relative from ..abstract_node import NodeSideType @@ -56,11 +57,61 @@ def make_app(name: str, router: APIRouter) -> FastAPI: return app -worker_classes = { - NodeType.DOMAIN: Domain, - NodeType.GATEWAY: Gateway, - NodeType.ENCLAVE: Enclave, -} +def app_factory() -> FastAPI: + try: + kwargs_encoded = os.environ["APP_FACTORY_KWARGS"] + kwargs_json = base64.b64decode(kwargs_encoded) + kwargs = json.loads(kwargs_json) + name = kwargs["name"] + node_type = kwargs["node_type"] + 
node_side_type = kwargs["node_side_type"] + processes = kwargs["processes"] + reset = kwargs["reset"] + dev_mode = kwargs["dev_mode"] + enable_warnings = kwargs["enable_warnings"] + in_memory_workers = kwargs["in_memory_workers"] + queue_port = kwargs["queue_port"] + create_producer = kwargs["create_producer"] + n_consumers = kwargs["n_consumers"] + association_request_auto_approval = kwargs["association_request_auto_approval"] + background_tasks = kwargs["background_tasks"] + except KeyError as e: + raise KeyError(f"Missing required environment variable: {e}") + + worker_classes = { + NodeType.DOMAIN: Domain, + NodeType.GATEWAY: Gateway, + NodeType.ENCLAVE: Enclave, + } + if node_type not in worker_classes: + raise NotImplementedError(f"node_type: {node_type} is not supported") + worker_class = worker_classes[node_type] + kwargs = { + "name": name, + "processes": processes, + "local_db": True, + "node_type": node_type, + "node_side_type": node_side_type, + "enable_warnings": enable_warnings, + "migrate": True, + "in_memory_workers": in_memory_workers, + "queue_port": queue_port, + "create_producer": create_producer, + "n_consumers": n_consumers, + "association_request_auto_approval": association_request_auto_approval, + "background_tasks": background_tasks, + } + if dev_mode: + print( + f"\nWARNING: private key is based on node name: {name} in dev_mode. " + "Don't run this in production." + ) + kwargs["reset"] = reset + + worker = worker_class.named(**kwargs) if dev_mode else worker_class(**kwargs) + router = make_routes(worker=worker) + app = make_app(worker.name, router=router) + return app def run_uvicorn( @@ -80,89 +131,44 @@ def run_uvicorn( n_consumers: int, background_tasks: bool, ) -> None: - async def _run_uvicorn( - name: str, - node_type: NodeType, - host: str, - port: int, - reset: bool, - dev_mode: bool, - node_side_type: Enum, - ) -> None: - if node_type not in worker_classes: - raise NotImplementedError(f"node_type: {node_type} is not supported") - worker_class = worker_classes[node_type] - if dev_mode: - print( - f"\nWARNING: private key is based on node name: {name} in dev_mode. " - "Don't run this in production." 
- ) - - worker = worker_class.named( - name=name, - processes=processes, - reset=reset, - local_db=True, - node_type=node_type, - node_side_type=node_side_type, - enable_warnings=enable_warnings, - migrate=True, - in_memory_workers=in_memory_workers, - queue_port=queue_port, - create_producer=create_producer, - n_consumers=n_consumers, - association_request_auto_approval=association_request_auto_approval, - background_tasks=background_tasks, - ) - else: - worker = worker_class( - name=name, - processes=processes, - local_db=True, - node_type=node_type, - node_side_type=node_side_type, - enable_warnings=enable_warnings, - migrate=True, - in_memory_workers=in_memory_workers, - queue_port=queue_port, - create_producer=create_producer, - n_consumers=n_consumers, - association_request_auto_approval=association_request_auto_approval, - background_tasks=background_tasks, - ) - router = make_routes(worker=worker) - app = make_app(worker.name, router=router) - - if reset: - try: - python_pids = find_python_processes_on_port(port) - for pid in python_pids: - print(f"Stopping process on port: {port}") - kill_process(pid) - time.sleep(1) - except Exception: # nosec - print(f"Failed to kill python process on port: {port}") - - config = uvicorn.Config(app, host=host, port=port, reload=dev_mode) - server = uvicorn.Server(config) - - await server.serve() - asyncio.get_running_loop().stop() - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete( - _run_uvicorn( - name, - node_type, - host, - port, - reset, - dev_mode, - node_side_type, - ) + if reset: + try: + python_pids = find_python_processes_on_port(port) + for pid in python_pids: + print(f"Stopping process on port: {port}") + kill_process(pid) + time.sleep(1) + except Exception: # nosec + print(f"Failed to kill python process on port: {port}") + + kwargs = { + "name": name, + "node_type": node_type, + "node_side_type": node_side_type, + "processes": processes, + "reset": reset, + "dev_mode": dev_mode, + "enable_warnings": enable_warnings, + "in_memory_workers": in_memory_workers, + "queue_port": queue_port, + "create_producer": create_producer, + "n_consumers": n_consumers, + "association_request_auto_approval": association_request_auto_approval, + "background_tasks": background_tasks, + } + kwargs_json = json.dumps(kwargs) + kwargs_encoded = base64.b64encode(kwargs_json.encode()).decode() + uvicorn_cmd = ( + f"APP_FACTORY_KWARGS={kwargs_encoded}" + " uvicorn syft.node.server:app_factory" + " --factory" + f" --host {host}" + f" --port {port}" ) - loop.close() + if dev_mode: + uvicorn_cmd += f" --reload --reload-dir {Path(__file__).parent.parent}" + print(f"{uvicorn_cmd=}") + os.system(uvicorn_cmd) def serve_node( From 2272447b80c8cb7dac92c57308b78ba3442834be Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 28 Jun 2024 08:41:56 +0700 Subject: [PATCH 172/309] [syft/action_obj] stop saving data to cache if saving to the blob store --- packages/syft/src/syft/service/action/action_object.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 9f0c7baaf5d..51d208126b9 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -851,8 +851,6 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: "skipping writing action object to store, passed data was empty." 
) - # self.syft_action_data_cache = data - return None def _save_to_blob_storage( From 9e8aaad09da0d2eab0567455a464c6594c17362a Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 28 Jun 2024 11:13:04 +0700 Subject: [PATCH 173/309] [syft/user_code] remove unused variables linter - add a catch for syntax error when parsing the submitted code --- .../syft/src/syft/service/code/user_code.py | 76 ++----------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 31560cb9fb1..ad51750ed19 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -96,6 +96,7 @@ from ..policy.policy import partition_by_node from ..policy.policy_service import PolicyService from ..response import SyftError +from ..response import SyftException from ..response import SyftInfo from ..response import SyftNotReady from ..response import SyftSuccess @@ -1267,15 +1268,6 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: global_check = _check_global(code) if isinstance(global_check, SyftWarning): display(global_check) - # err = SyftError(message=global_check.message) - # display(err) - # return err - - lint_issues = _lint_code(code) - lint_warning_msg = "" - for issue in lint_issues: - lint_warning_msg += f"{issue}\n\t" - display(SyftWarning(message=lint_warning_msg)) if name is not None: fname = name @@ -1335,62 +1327,6 @@ def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: return v -# Define a linter function -def _lint_code(code: str) -> list: - # Parse the code into an AST - tree = ast.parse(code) - - # Initialize a list to collect linting issues - issues = [] - - # Define a visitor class to walk the AST - class CodeVisitor(ast.NodeVisitor): - def __init__(self) -> None: - self.globals: set = set() - self.defined_names: set = set() - self.current_scope_defined_names: set = set() - - def visit_Global(self, node: Any) -> None: - # Collect global variable names - for name in node.names: - self.globals.add(name) - self.generic_visit(node) - - def visit_FunctionDef(self, node: Any) -> None: - # Collect defined function names and handle function scope - self.defined_names.add(node.name) - self.current_scope_defined_names = set() # New scope - self.generic_visit(node) - self.current_scope_defined_names.clear() # Clear scope after visiting - - def visit_Assign(self, node: Any) -> None: - # Collect assigned variable names - for target in node.targets: - if isinstance(target, ast.Name): - self.current_scope_defined_names.add(target.id) - self.generic_visit(node) - - def visit_Name(self, node: Any) -> None: - # Check if variables are used before being defined - if isinstance(node.ctx, ast.Load): - if ( - node.id not in self.current_scope_defined_names - and node.id not in self.defined_names - and node.id not in self.globals - ): - issues.append( - f"Variable '{node.id}' used at line {node.lineno} before being defined." 
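This linter is being removed; invalid submissions now fail fast when the code is first parsed (see the try/except added to process_code further down). A small illustration of that behaviour with plain ast, using an illustrative snippet of broken user code:

    import ast

    broken = "def func(x):\n    return x +"   # incomplete expression
    try:
        ast.parse(broken)
    except SyntaxError as e:
        # process_code wraps this in a SyftException with a similar message
        print(f"Syntax error in code: {e}")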
- ) - self.generic_visit(node) - - # Create a visitor instance and visit the AST - visitor = CodeVisitor() - visitor.visit(tree) - - # Return the collected issues - return issues - - def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] @@ -1413,11 +1349,15 @@ def process_code( policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - tree = ast.parse(raw_code) + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + raise SyftException(f"Syntax error in code: {e}") # check there are no globals - v = GlobalsVisitor() - v.visit(tree) + v = _check_global(raw_code=tree) + if isinstance(v, SyftWarning): + raise SyftException(message=f"{v.message}") f = tree.body[0] f.decorator_list = [] From e6f31f0e3a9d35391e8a7dff6c583e4826101b17 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 11:25:23 +0530 Subject: [PATCH 174/309] rename rathole to reverse tunnel in syft application - rename rathole path and token to rtunnel --- packages/grid/backend/grid/core/config.py | 4 +-- packages/grid/devspace.yaml | 12 ++++---- .../backend/backend-statefulset.yaml | 6 ++-- .../syft/templates/proxy/proxy-configmap.yaml | 8 ++--- .../templates/rathole/rathole-configmap.yaml | 4 +-- .../rathole/rathole-statefulset.yaml | 30 ++++++++----------- packages/grid/helm/syft/values.yaml | 6 ++-- packages/grid/helm/values.dev.yaml | 5 ++-- packages/syft/src/syft/client/client.py | 18 +++++------ .../service/network/association_request.py | 12 ++++---- .../syft/service/network/network_service.py | 30 +++++++++---------- .../src/syft/service/network/node_peer.py | 4 +-- .../syft/src/syft/service/network/rathole.py | 2 +- .../service/network/rathole_config_builder.py | 8 ++--- .../service/network/reverse_tunnel_service.py | 6 ++-- .../syft/src/syft/service/network/routes.py | 4 +-- packages/syft/src/syft/util/util.py | 2 +- 17 files changed, 78 insertions(+), 83 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 33d65719fe8..b6f5ddf9067 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,8 +155,8 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) - REVERSE_TUNNEL_RATHOLE_ENABLED: bool = str_to_bool( - os.getenv("REVERSE_TUNNEL_RATHOLE_ENABLED", "false") + REVERSE_TUNNEL_ENABLED: bool = str_to_bool( + os.getenv("REVERSE_TUNNEL_ENABLED", "false") ) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index d1c6facde50..7e37e705214 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -86,7 +86,7 @@ deployments: version: dev-${DEVSPACE_TIMESTAMP} node: type: domain # required for the gateway profile - rathole: + reverse_tunnel: mode: client dev: @@ -125,12 +125,12 @@ dev: - path: ../syft:/root/app/syft ssh: localPort: 3480 - rathole: + reverse_tunnel: labelSelector: app.kubernetes.io/name: syft - app.kubernetes.io/component: rathole + app.kubernetes.io/component: reverse_tunnel ports: - - port: "2333" # rathole + - port: "2333" # reverse_tunnel profiles: - name: dev-low @@ -158,7 +158,7 @@ profiles: # Patch mode to server - op: replace - path: deployments.syft.helm.values.rathole.mode + path: deployments.syft.helm.values.reverse_tunnel.mode value: server # Port Re-Mapping 
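On the backend, the renamed flag is read from the environment, so the tunnel code paths stay off unless the chart sets REVERSE_TUNNEL_ENABLED. A rough sketch of the check; the real helper goes through get_env and str_to_bool:

    import os

    def reverse_tunnel_enabled() -> bool:
        # approximate stand-in for str_to_bool(get_env("REVERSE_TUNNEL_ENABLED", "false"))
        return os.getenv("REVERSE_TUNNEL_ENABLED", "false").strip().lower() in ("true", "1")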
@@ -179,7 +179,7 @@ profiles: # Mongo - op: replace - path: dev.rathole.ports[0].port + path: dev.reverse_tunnel.ports[0].port value: 2334:2333 - name: gcp diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index e08db908fea..e5dd285985f 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -91,10 +91,10 @@ spec: - name: ASSOCIATION_REQUEST_AUTO_APPROVAL value: {{ .Values.node.associationRequestAutoApproval | quote }} {{- end }} - {{- if .Values.rathole.enabled }} + {{- if .Values.reverse_tunnel.enabled }} - name: RATHOLE_PORT - value: {{ .Values.rathole.port | quote }} - - name: REVERSE_TUNNEL_RATHOLE_ENABLED + value: {{ .Values.reverse_tunnel.port | quote }} + - name: REVERSE_TUNNEL_ENABLED value: "true" {{- end }} # MongoDB diff --git a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 705c153442a..1bcdff49876 100644 --- a/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -28,22 +28,22 @@ data: - url: "http://rathole:2333" routers: rathole: - rule: "PathPrefix(`/`) && Headers(`Upgrade`, `websocket`) && !PathPrefix(`/rathole`)" + rule: "PathPrefix(`/`) && Headers(`Upgrade`, `websocket`) && !PathPrefix(`/rtunnel`)" entryPoints: - "web" service: "rathole" frontend: - rule: "PathPrefix(`/`) && !PathPrefix(`/rathole`)" + rule: "PathPrefix(`/`) && !PathPrefix(`/rtunnel`)" entryPoints: - "web" service: "frontend" backend: - rule: "(PathPrefix(`/api`) || PathPrefix(`/docs`) || PathPrefix(`/redoc`)) && !PathPrefix(`/rathole`)" + rule: "(PathPrefix(`/api`) || PathPrefix(`/docs`) || PathPrefix(`/redoc`)) && !PathPrefix(`/rtunnel`)" entryPoints: - "web" service: "backend" blob-storage: - rule: "PathPrefix(`/blob`) && !PathPrefix(`/rathole`)" + rule: "PathPrefix(`/blob`) && !PathPrefix(`/rtunnel`)" entryPoints: - "web" service: "seaweedfs" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index cd5453f1ea2..3bebf40cffc 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -7,7 +7,7 @@ metadata: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: rathole data: - {{- if eq .Values.rathole.mode "server" }} + {{- if eq .Values.reverse_tunnel.mode "server" }} server.toml: | [server] bind_addr = "0.0.0.0:2333" @@ -22,7 +22,7 @@ data: bind_addr = "0.0.0.0:8001" {{- end }} - {{- if eq .Values.rathole.mode "client" }} + {{- if eq .Values.reverse_tunnel.mode "client" }} client.toml: | [client] remote_addr = "0.0.0.0:2333" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index c441991ef47..dfa068218d6 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -20,36 +20,30 @@ spec: labels: {{- include "common.labels" . 
| nindent 8 }} app.kubernetes.io/component: rathole - {{- if .Values.rathole.podLabels }} - {{- toYaml .Values.rathole.podLabels | nindent 8 }} + {{- if .Values.reverse_tunnel.podLabels }} + {{- toYaml .Values.reverse_tunnel.podLabels | nindent 8 }} {{- end }} - {{- if .Values.rathole.podAnnotations }} - annotations: {{- toYaml .Values.rathole.podAnnotations | nindent 8 }} + {{- if .Values.reverse_tunnel.podAnnotations }} + annotations: {{- toYaml .Values.reverse_tunnel.podAnnotations | nindent 8 }} {{- end }} spec: - {{- if .Values.rathole.nodeSelector }} - nodeSelector: {{- .Values.rathole.nodeSelector | toYaml | nindent 8 }} + {{- if .Values.reverse_tunnel.nodeSelector }} + nodeSelector: {{- .Values.reverse_tunnel.nodeSelector | toYaml | nindent 8 }} {{- end }} containers: - name: rathole image: {{ .Values.global.registry }}/openmined/grid-rathole:{{ .Values.global.version }} imagePullPolicy: Always - resources: {{ include "common.resources.set" (dict "resources" .Values.rathole.resources "preset" .Values.rathole.resourcesPreset) | nindent 12 }} + resources: {{ include "common.resources.set" (dict "resources" .Values.reverse_tunnel.resources "preset" .Values.reverse_tunnel.resourcesPreset) | nindent 12 }} env: - - name: SERVICE_NAME - value: "rathole" - name: LOG_LEVEL - value: {{ .Values.rathole.logLevel | quote }} + value: {{ .Values.reverse_tunnel.logLevel | quote }} - name: MODE - value: {{ .Values.rathole.mode | quote }} - - name: DEV_MODE - value: {{ .Values.rathole.devMode | quote }} - - name: APP_PORT - value: {{ .Values.rathole.appPort | quote }} + value: {{ .Values.reverse_tunnel.mode | quote }} - name: RATHOLE_PORT - value: {{ .Values.rathole.ratholePort | quote }} - {{- if .Values.rathole.env }} - {{- toYaml .Values.rathole.env | nindent 12 }} + value: {{ .Values.reverse_tunnel.port | quote }} + {{- if .Values.reverse_tunnel.env }} + {{- toYaml .Values.reverse_tunnel.env | nindent 12 }} {{- end }} ports: - name: rathole-port diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index e6c5cb51be2..8e225b94599 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -238,11 +238,11 @@ ingress: # ================================================================================= -rathole: +reverse_tunnel: # Extra environment vars env: null - enabled: true - logLevel: info + enabled: false + logLevel: "info" port: 2333 mode: client diff --git a/packages/grid/helm/values.dev.yaml b/packages/grid/helm/values.dev.yaml index 713948f4a9b..e9c32543c0b 100644 --- a/packages/grid/helm/values.dev.yaml +++ b/packages/grid/helm/values.dev.yaml @@ -47,9 +47,10 @@ proxy: resourcesPreset: null resources: null -rathole: - enabled: "true" +reverse_tunnel: + enabled: true logLevel: "trace" + # attestation: # enabled: true # resourcesPreset: null diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 9540a50fb7a..f153fe88f41 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -117,7 +117,7 @@ def forward_message_to_proxy( API_PATH = "/api/v2" DEFAULT_PYGRID_PORT = 80 DEFAULT_PYGRID_ADDRESS = f"http://localhost:{DEFAULT_PYGRID_PORT}" -INTERNAL_PROXY_TO_RATHOLE = "http://proxy:80/rathole/" +INTERNAL_PROXY_TO_RATHOLE = "http://proxy:80/rtunnel/" class Routes(Enum): @@ -130,7 +130,7 @@ class Routes(Enum): STREAM = f"{API_PATH}/stream" -@serializable(attrs=["proxy_target_uid", "url", "rathole_token"]) 
+@serializable(attrs=["proxy_target_uid", "url", "rtunnel_token"]) class HTTPConnection(NodeConnection): __canonical_name__ = "HTTPConnection" __version__ = SYFT_OBJECT_VERSION_3 @@ -140,7 +140,7 @@ class HTTPConnection(NodeConnection): routes: type[Routes] = Routes session_cache: Session | None = None headers: dict[str, str] | None = None - rathole_token: str | None = None + rtunnel_token: str | None = None @field_validator("url", mode="before") @classmethod @@ -158,7 +158,7 @@ def with_proxy(self, proxy_target_uid: UID) -> Self: return HTTPConnection( url=self.url, proxy_target_uid=proxy_target_uid, - rathole_token=self.rathole_token, + rtunnel_token=self.rtunnel_token, ) def stream_via(self, proxy_uid: UID, url_path: str) -> GridURL: @@ -200,7 +200,7 @@ def _make_get( url = self.url - if self.rathole_token: + if self.rtunnel_token: self.headers = {} if self.headers is None else self.headers url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) self.headers["Host"] = self.url.host_or_ip @@ -229,7 +229,7 @@ def _make_get( def _make_get_no_params(self, path: str, stream: bool = False) -> bytes | Iterable: url = self.url - if self.rathole_token: + if self.rtunnel_token: self.headers = {} if self.headers is None else self.headers url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) self.headers["Host"] = self.url.host_or_ip @@ -261,7 +261,7 @@ def _make_put( ) -> Response: url = self.url - if self.rathole_token: + if self.rtunnel_token: url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) self.headers = {} if self.headers is None else self.headers self.headers["Host"] = self.url.host_or_ip @@ -293,7 +293,7 @@ def _make_post( ) -> bytes: url = self.url - if self.rathole_token: + if self.rtunnel_token: url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) self.headers = {} if self.headers is None else self.headers self.headers["Host"] = self.url.host_or_ip @@ -408,7 +408,7 @@ def register(self, new_user: UserCreate) -> SyftSigningKey: def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError: msg_bytes: bytes = _serialize(obj=signed_call, to_bytes=True) - if self.rathole_token: + if self.rtunnel_token: api_url = GridURL.from_url(INTERNAL_PROXY_TO_RATHOLE) api_url = api_url.with_path(self.routes.ROUTE_API_CALL.value) self.headers = {} if self.headers is None else self.headers diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index a910302005e..a13dd5085d3 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -59,15 +59,15 @@ def _run( ) network_stash = network_service.stash - # Check if remote peer to be added is via rathole - rathole_route = self.remote_peer.get_rathole_route() - add_rathole_route = ( - rathole_route is not None - and self.remote_peer.latest_added_route == rathole_route + # Check if remote peer to be added is via reverse tunnel + rtunnel_route = self.remote_peer.get_rtunnel_route() + add_rtunnel_route = ( + rtunnel_route is not None + and self.remote_peer.latest_added_route == rtunnel_route ) # If the remote peer is added via reverse tunnel, we skip ping to peer - if add_rathole_route: + if add_rtunnel_route: network_service.set_reverse_tunnel_config( context=context, remote_node_peer=self.remote_peer, diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 4aa48481df2..35de9bfaf91 100644 --- 
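Every request helper in HTTPConnection now applies the same pattern: when a tunnel token is present, the call goes to the in-cluster proxy and the original host is kept in the Host header so Traefik can pick the right tunnel. A distilled, self-contained sketch of that routing decision (tunnelled_get and the hosts shown are illustrative, not part of the client API):

    import requests

    INTERNAL_PROXY_TO_RATHOLE = "http://proxy:80/rtunnel"

    def tunnelled_get(original_host: str, path: str, rtunnel_token: str | None) -> requests.Response:
        if rtunnel_token:
            url = INTERNAL_PROXY_TO_RATHOLE + path
            headers = {"Host": original_host}      # e.g. "domain1.syft.local"
        else:
            url = f"http://{original_host}{path}"
            headers = {}
        return requests.get(url, headers=headers)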
a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -67,11 +67,11 @@ NodeTypePartitionKey = PartitionKey(key="node_type", type_=NodeType) OrderByNamePartitionKey = PartitionKey(key="name", type_=str) -REVERSE_TUNNEL_RATHOLE_ENABLED = "REVERSE_TUNNEL_RATHOLE_ENABLED" +REVERSE_TUNNEL_ENABLED = "REVERSE_TUNNEL_ENABLED" def reverse_tunnel_enabled() -> bool: - return str_to_bool(get_env(REVERSE_TUNNEL_RATHOLE_ENABLED, "false")) + return str_to_bool(get_env(REVERSE_TUNNEL_ENABLED, "false")) @serializable() @@ -192,10 +192,10 @@ def exchange_credentials_with( if reverse_tunnel and not reverse_tunnel_enabled(): return SyftError(message="Reverse tunneling is not enabled on this node.") elif reverse_tunnel: - _rathole_route = self_node_peer.node_routes[-1] - _rathole_route.rathole_token = generate_token() - _rathole_route.host_or_ip = f"{self_node_peer.name}.syft.local" - self_node_peer.node_routes[-1] = _rathole_route + _rtunnel_route = self_node_peer.node_routes[-1] + _rtunnel_route.rtunnel_token = generate_token() + _rtunnel_route.host_or_ip = f"{self_node_peer.name}.syft.local" + self_node_peer.node_routes[-1] = _rtunnel_route if isinstance(self_node_peer, SyftError): return self_node_peer @@ -259,7 +259,7 @@ def exchange_credentials_with( ) return SyftError(message="Failed to update route information.") - # Step 5: Save rathole config to enable reverse tunneling + # Step 5: Save config to enable reverse tunneling if reverse_tunnel and reverse_tunnel_enabled(): self.set_reverse_tunnel_config( context=context, @@ -498,10 +498,10 @@ def set_reverse_tunnel_config( ) -> None: node_type = cast(NodeType, context.node.node_type) if node_type.value == NodeType.GATEWAY.value: - rathole_route = remote_node_peer.get_rathole_route() + rtunnel_route = remote_node_peer.get_rtunnel_route() ( self.rtunnel_service.set_server_config(remote_node_peer) - if rathole_route + if rtunnel_route else None ) else: @@ -510,13 +510,13 @@ def set_reverse_tunnel_config( if self_node_peer is None else self_node_peer ) - rathole_route = self_node_peer.get_rathole_route() + rtunnel_route = self_node_peer.get_rtunnel_route() ( self.rtunnel_service.set_client_config( self_node_peer=self_node_peer, remote_node_route=remote_node_peer.pick_highest_priority_route(), ) - if rathole_route + if rtunnel_route else None ) @@ -538,10 +538,10 @@ def delete_peer_by_id( node_side_type = cast(NodeType, context.node.node_type) if node_side_type.value == NodeType.GATEWAY.value: - rathole_route = peer_to_delete.get_rathole_route() + rtunnel_route = peer_to_delete.get_rtunnel_route() ( self.rtunnel_service.clear_server_config(peer_to_delete) - if rathole_route + if rtunnel_route else None ) @@ -955,7 +955,7 @@ def from_grid_url(context: TransformContext) -> TransformContext: context.output["private"] = False context.output["proxy_target_uid"] = context.obj.proxy_target_uid context.output["priority"] = 1 - context.output["rathole_token"] = context.obj.rathole_token + context.output["rtunnel_token"] = context.obj.rtunnel_token return context @@ -995,7 +995,7 @@ def node_route_to_http_connection( return HTTPConnection( url=url, proxy_target_uid=obj.proxy_target_uid, - rathole_token=obj.rathole_token, + rtunnel_token=obj.rtunnel_token, ) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index e3f14481f36..23c1ffc7057 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ 
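Before the exchange, the domain rewrites its own newest route: it attaches a freshly generated secret and replaces the public address with a .syft.local virtual host that only the gateway's proxy resolves. A minimal sketch of that preparation step, assuming a peer object that exposes name and node_routes as above:

    from syft.util.util import generate_token   # returns a long hex secret

    def prepare_reverse_tunnel_route(peer):
        route = peer.node_routes[-1]                  # latest added route
        route.rtunnel_token = generate_token()        # shared secret for the tunnel
        route.host_or_ip = f"{peer.name}.syft.local"  # virtual host behind the proxy
        peer.node_routes[-1] = route
        return route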
b/packages/syft/src/syft/service/network/node_peer.py @@ -285,9 +285,9 @@ def guest_client(self) -> SyftClient: def proxy_from(self, client: SyftClient) -> SyftClient: return client.proxy_to(self) - def get_rathole_route(self) -> HTTPNodeRoute | None: + def get_rtunnel_route(self) -> HTTPNodeRoute | None: for route in self.node_routes: - if hasattr(route, "rathole_token") and route.rathole_token: + if hasattr(route, "rtunnel_token") and route.rtunnel_token: return route return None diff --git a/packages/syft/src/syft/service/network/rathole.py b/packages/syft/src/syft/service/network/rathole.py index d311bfa9c2b..4fd0be445b1 100644 --- a/packages/syft/src/syft/service/network/rathole.py +++ b/packages/syft/src/syft/service/network/rathole.py @@ -36,7 +36,7 @@ def from_peer(cls, peer: NodePeer) -> Self: return cls( uuid=peer.id, - secret_token=peer.rathole_token, + secret_token=peer.rtunnel_token, local_addr_host=high_priority_route.host_or_ip, local_addr_port=high_priority_route.port, server_name=peer.name, diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index 90f499ec237..f134468a120 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -36,7 +36,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: None """ - rathole_route = peer.get_rathole_route() + rathole_route = peer.get_rtunnel_route() if not rathole_route: raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") @@ -46,7 +46,7 @@ def add_host_to_server(self, peer: NodePeer) -> None: config = RatholeConfig( uuid=peer_id.to_string(), - secret_token=rathole_route.rathole_token, + secret_token=rathole_route.rtunnel_token, local_addr_host=DEFAULT_LOCAL_ADDR_HOST, local_addr_port=random_port, server_name=peer.name, @@ -120,13 +120,13 @@ def _get_random_port(self) -> int: return secrets.randbits(15) def add_host_to_client( - self, peer_name: str, peer_id: str, rathole_token: str, remote_addr: str + self, peer_name: str, peer_id: str, rtunnel_token: str, remote_addr: str ) -> None: """Add a host to the rathole client toml file.""" config = RatholeConfig( uuid=peer_id, - secret_token=rathole_token, + secret_token=rtunnel_token, local_addr_host="proxy", local_addr_port=80, server_name=peer_name, diff --git a/packages/syft/src/syft/service/network/reverse_tunnel_service.py b/packages/syft/src/syft/service/network/reverse_tunnel_service.py index bb80c56f401..36d5b14e151 100644 --- a/packages/syft/src/syft/service/network/reverse_tunnel_service.py +++ b/packages/syft/src/syft/service/network/reverse_tunnel_service.py @@ -14,7 +14,7 @@ def set_client_config( self_node_peer: NodePeer, remote_node_route: NodeRoute, ) -> None: - rathole_route = self_node_peer.get_rathole_route() + rathole_route = self_node_peer.get_rtunnel_route() if not rathole_route: raise Exception( "Failed to exchange routes via . 
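On the domain side, add_host_to_client boils down to building one RatholeConfig entry that points the tunnel at the local Traefik service. Roughly, with placeholder identifiers:

    from syft.service.network.rathole import RatholeConfig

    config = RatholeConfig(
        uuid="<peer-uid-as-string>",                 # placeholder
        secret_token="<the route's rtunnel_token>",  # placeholder
        local_addr_host="proxy",                     # in-cluster proxy service
        local_addr_port=80,
        server_name="domain1",                       # hypothetical node name
    )

The gateway side mirrors this in add_host_to_server, except the local port is a random one picked per peer and the entry is keyed by the peer's UID.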
" @@ -31,12 +31,12 @@ def set_client_config( self.builder.add_host_to_client( peer_name=self_node_peer.name, peer_id=str(self_node_peer.id), - rathole_token=rathole_route.rathole_token, + rtunnel_token=rathole_route.rtunnel_token, remote_addr=remote_addr, ) def set_server_config(self, remote_peer: NodePeer) -> None: - rathole_route = remote_peer.get_rathole_route() + rathole_route = remote_peer.get_rtunnel_route() self.builder.add_host_to_server(remote_peer) if rathole_route else None def clear_client_config(self, self_node_peer: NodePeer) -> None: diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index b1ff68f6b72..ca5ea04c999 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -96,7 +96,7 @@ class HTTPNodeRoute(SyftObject, NodeRoute): port: int = 80 proxy_target_uid: UID | None = None priority: int = 1 - rathole_token: str | None = None + rtunnel_token: str | None = None def __eq__(self, other: Any) -> bool: if not isinstance(other, HTTPNodeRoute): @@ -109,7 +109,7 @@ def __hash__(self) -> int: + hash(self.port) + hash(self.protocol) + hash(self.proxy_target_uid) - + hash(self.rathole_token) + + hash(self.rtunnel_token) ) def __str__(self) -> str: diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index eedc61173a0..d8098f55e1d 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -920,7 +920,7 @@ def get_dev_mode() -> bool: def generate_token() -> str: - return hashlib.sha256(secrets.token_bytes(16)).hexdigest() + return secrets.token_hex(64) def sanitize_html(html: str) -> str: From b393b7a80c92c1de6bcf6a676a61da102871affc Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 11:46:18 +0530 Subject: [PATCH 175/309] rename reverse_tunnel to rtunnel in values and values.dev yaml --- packages/grid/devspace.yaml | 10 ++++---- .../backend/backend-service-account.yaml | 4 ++-- .../backend/backend-statefulset.yaml | 4 ++-- .../templates/rathole/rathole-configmap.yaml | 8 ++----- .../rathole/rathole-statefulset.yaml | 24 +++++++++---------- packages/grid/helm/syft/values.yaml | 2 +- packages/grid/helm/values.dev.yaml | 2 +- packages/grid/rathole/start.sh | 13 ++++++++-- 8 files changed, 36 insertions(+), 31 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 7e37e705214..41e0ac00aa4 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -86,7 +86,7 @@ deployments: version: dev-${DEVSPACE_TIMESTAMP} node: type: domain # required for the gateway profile - reverse_tunnel: + rathole: mode: client dev: @@ -125,10 +125,10 @@ dev: - path: ../syft:/root/app/syft ssh: localPort: 3480 - reverse_tunnel: + rathole: labelSelector: app.kubernetes.io/name: syft - app.kubernetes.io/component: reverse_tunnel + app.kubernetes.io/component: rathole ports: - port: "2333" # reverse_tunnel @@ -158,7 +158,7 @@ profiles: # Patch mode to server - op: replace - path: deployments.syft.helm.values.reverse_tunnel.mode + path: deployments.syft.helm.values.rtunnel.mode value: server # Port Re-Mapping @@ -179,7 +179,7 @@ profiles: # Mongo - op: replace - path: dev.reverse_tunnel.ports[0].port + path: dev.rtunnel.ports[0].port value: 2334:2333 - name: gcp diff --git a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml index 76d70afee70..ee4634fc45f 100644 
--- a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml @@ -26,10 +26,10 @@ metadata: app.kubernetes.io/component: backend rules: - apiGroups: [""] - resources: ["pods", "configmaps", "secrets", "services"] + resources: ["pods", "configmaps", "secrets"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] - apiGroups: [""] - resources: ["pods/log"] + resources: ["pods/log", "services"] verbs: ["get", "list", "watch"] - apiGroups: ["batch"] resources: ["jobs"] diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index e5dd285985f..b86916d6bde 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -91,9 +91,9 @@ spec: - name: ASSOCIATION_REQUEST_AUTO_APPROVAL value: {{ .Values.node.associationRequestAutoApproval | quote }} {{- end }} - {{- if .Values.reverse_tunnel.enabled }} + {{- if .Values.rtunnel.enabled }} - name: RATHOLE_PORT - value: {{ .Values.reverse_tunnel.port | quote }} + value: {{ .Values.rtunnel.port | quote }} - name: REVERSE_TUNNEL_ENABLED value: "true" {{- end }} diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index 3bebf40cffc..ae506486ce4 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -7,7 +7,7 @@ metadata: {{- include "common.labels" . | nindent 4 }} app.kubernetes.io/component: rathole data: - {{- if eq .Values.reverse_tunnel.mode "server" }} + {{- if eq .Values.rtunnel.mode "server" }} server.toml: | [server] bind_addr = "0.0.0.0:2333" @@ -16,13 +16,9 @@ data: type = "websocket" [server.transport.websocket] tls = false - - [server.services.domain] - token = "domain-specific-rathole-secret" - bind_addr = "0.0.0.0:8001" {{- end }} - {{- if eq .Values.reverse_tunnel.mode "client" }} + {{- if eq .Values.rtunnel.mode "client" }} client.toml: | [client] remote_addr = "0.0.0.0:2333" diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index dfa068218d6..b1286f6a1f4 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -20,30 +20,30 @@ spec: labels: {{- include "common.labels" . 
| nindent 8 }} app.kubernetes.io/component: rathole - {{- if .Values.reverse_tunnel.podLabels }} - {{- toYaml .Values.reverse_tunnel.podLabels | nindent 8 }} + {{- if .Values.rtunnel.podLabels }} + {{- toYaml .Values.rtunnel.podLabels | nindent 8 }} {{- end }} - {{- if .Values.reverse_tunnel.podAnnotations }} - annotations: {{- toYaml .Values.reverse_tunnel.podAnnotations | nindent 8 }} + {{- if .Values.rtunnel.podAnnotations }} + annotations: {{- toYaml .Values.rtunnel.podAnnotations | nindent 8 }} {{- end }} spec: - {{- if .Values.reverse_tunnel.nodeSelector }} - nodeSelector: {{- .Values.reverse_tunnel.nodeSelector | toYaml | nindent 8 }} + {{- if .Values.rtunnel.nodeSelector }} + nodeSelector: {{- .Values.rtunnel.nodeSelector | toYaml | nindent 8 }} {{- end }} containers: - name: rathole image: {{ .Values.global.registry }}/openmined/grid-rathole:{{ .Values.global.version }} imagePullPolicy: Always - resources: {{ include "common.resources.set" (dict "resources" .Values.reverse_tunnel.resources "preset" .Values.reverse_tunnel.resourcesPreset) | nindent 12 }} + resources: {{ include "common.resources.set" (dict "resources" .Values.rtunnel.resources "preset" .Values.rtunnel.resourcesPreset) | nindent 12 }} env: - name: LOG_LEVEL - value: {{ .Values.reverse_tunnel.logLevel | quote }} + value: {{ .Values.rtunnel.logLevel | quote }} - name: MODE - value: {{ .Values.reverse_tunnel.mode | quote }} + value: {{ .Values.rtunnel.mode | quote }} - name: RATHOLE_PORT - value: {{ .Values.reverse_tunnel.port | quote }} - {{- if .Values.reverse_tunnel.env }} - {{- toYaml .Values.reverse_tunnel.env | nindent 12 }} + value: {{ .Values.rtunnel.port | quote }} + {{- if .Values.rtunnel.env }} + {{- toYaml .Values.rtunnel.env | nindent 12 }} {{- end }} ports: - name: rathole-port diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 8e225b94599..b5372d8e857 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -238,7 +238,7 @@ ingress: # ================================================================================= -reverse_tunnel: +rtunnel: # Extra environment vars env: null enabled: false diff --git a/packages/grid/helm/values.dev.yaml b/packages/grid/helm/values.dev.yaml index e9c32543c0b..493850cbb67 100644 --- a/packages/grid/helm/values.dev.yaml +++ b/packages/grid/helm/values.dev.yaml @@ -47,7 +47,7 @@ proxy: resourcesPreset: null resources: null -reverse_tunnel: +rtunnel: enabled: true logLevel: "trace" diff --git a/packages/grid/rathole/start.sh b/packages/grid/rathole/start.sh index b1af50597fc..87111ac8c9f 100755 --- a/packages/grid/rathole/start.sh +++ b/packages/grid/rathole/start.sh @@ -8,9 +8,18 @@ copy_config() { cp -L -r -f /conf/* conf/ } -# Start the server +# Start the server and reload until healthy start_server() { - RUST_LOG=$RUST_LOG /app/rathole conf/server.toml & + while true; do + RUST_LOG=$RUST_LOG /app/rathole conf/server.toml + status=$? + if [ $status -eq 0 ]; then + break + else + echo "Server failed to start, retrying in 5 seconds..." 
+ sleep 5 + fi + done & } # Start the client From 30f1132082ad14b8cf3c4e5618bcce9a2639a160 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 12:44:11 +0530 Subject: [PATCH 176/309] revert backend account to have permission to patch services - rename reference of rathole_token to rtunnel_token in gateway test --- packages/grid/devspace.yaml | 2 +- .../helm/syft/templates/backend/backend-service-account.yaml | 4 ++-- tests/integration/network/gateway_test.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 41e0ac00aa4..dc0d2dbd142 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -86,7 +86,7 @@ deployments: version: dev-${DEVSPACE_TIMESTAMP} node: type: domain # required for the gateway profile - rathole: + rtunnel: mode: client dev: diff --git a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml index ee4634fc45f..76d70afee70 100644 --- a/packages/grid/helm/syft/templates/backend/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml @@ -26,10 +26,10 @@ metadata: app.kubernetes.io/component: backend rules: - apiGroups: [""] - resources: ["pods", "configmaps", "secrets"] + resources: ["pods", "configmaps", "secrets", "services"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] - apiGroups: [""] - resources: ["pods/log", "services"] + resources: ["pods/log"] verbs: ["get", "list", "watch"] - apiGroups: ["batch"] resources: ["jobs"] diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index ef8b090a8a6..d64ab32fdca 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -946,7 +946,7 @@ def test_reverse_tunnel_connection(domain_1_port: int, gateway_port: int): # Domain's peer is a gateway and vice-versa domain_peer = domain_client.peers[0] assert domain_peer.node_type == NodeType.GATEWAY - assert domain_peer.node_routes[0].rathole_token is None + assert domain_peer.node_routes[0].rtunnel_token is None assert len(gateway_client.peers) == 0 gateway_client_root = gateway_client.login( @@ -960,7 +960,7 @@ def test_reverse_tunnel_connection(domain_1_port: int, gateway_port: int): gateway_peers = gateway_client.api.services.network.get_all_peers() assert len(gateway_peers) == 1 assert len(gateway_peers[0].node_routes) == 1 - assert gateway_peers[0].node_routes[0].rathole_token is not None + assert gateway_peers[0].node_routes[0].rtunnel_token is not None proxy_domain_client = gateway_client.peers[0] From ad31f5bbc924bc3a8b880767cb68dbc2907641a5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 13:35:20 +0530 Subject: [PATCH 177/309] force enable proxy in case of gateways --- packages/grid/devspace.yaml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index dc0d2dbd142..b961d17da26 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -88,6 +88,8 @@ deployments: type: domain # required for the gateway profile rtunnel: mode: client + proxy: + enabled: true # required for the gateway profile dev: mongo: @@ -161,6 +163,11 @@ profiles: path: deployments.syft.helm.values.rtunnel.mode value: server + # Enable proxy for gateway + - op: replace + path: deployments.syft.helm.values.proxy.enabled + 
value: true + # Port Re-Mapping # Mongo - op: replace @@ -177,7 +184,7 @@ profiles: path: dev.backend.containers.backend-container.ssh.localPort value: 3481 - # Mongo + # Reverse tunnel port - op: replace path: dev.rtunnel.ports[0].port value: 2334:2333 From a95a815098e7f792a4885df8d35bca1c565f14e8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 13:57:54 +0530 Subject: [PATCH 178/309] start rathole pod, config, and service if rtunnel flag is enabled --- .../grid/helm/syft/templates/rathole/rathole-configmap.yaml | 2 ++ .../grid/helm/syft/templates/rathole/rathole-service.yaml | 2 ++ .../grid/helm/syft/templates/rathole/rathole-statefulset.yaml | 4 +++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml index ae506486ce4..77f2bec4c4b 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-configmap.yaml @@ -1,3 +1,4 @@ +{{- if .Values.rtunnel.enabled }} apiVersion: v1 kind: ConfigMap metadata: @@ -28,3 +29,4 @@ data: [client.transport.websocket] tls = false {{- end }} +{{- end }} diff --git a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml index 01fa305ac77..087da2256e6 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-service.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-service.yaml @@ -1,3 +1,4 @@ +{{- if .Values.rtunnel.enabled }} apiVersion: v1 kind: Service metadata: @@ -15,3 +16,4 @@ spec: port: 2333 targetPort: 2333 protocol: TCP +{{- end }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml index b1286f6a1f4..86d39b51551 100644 --- a/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml +++ b/packages/grid/helm/syft/templates/rathole/rathole-statefulset.yaml @@ -1,3 +1,4 @@ +{{- if .Values.rtunnel.enabled }} apiVersion: apps/v1 kind: StatefulSet metadata: @@ -75,4 +76,5 @@ spec: - ReadWriteOnce resources: requests: - storage: 10Mi \ No newline at end of file + storage: 10Mi +{{- end }} \ No newline at end of file From feaa0d4a8db0c8ad01812e5128546063b42b47df Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 28 Jun 2024 14:07:03 +0530 Subject: [PATCH 179/309] add retry to test_delete_route_on_peer for flakyness --- tests/integration/network/gateway_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index d64ab32fdca..48cddf8ce4c 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -573,6 +573,7 @@ def test_add_route_on_peer(set_env_var, gateway_port: int, domain_1_port: int) - @pytest.mark.network +@pytest.mark.flaky(reruns=2, reruns_delay=2) def test_delete_route_on_peer( set_env_var, gateway_port: int, domain_1_port: int ) -> None: From 7705464ba246b725dfc463a012d1c5d6fd240786 Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Sat, 29 Jun 2024 22:17:09 +0530 Subject: [PATCH 180/309] Use pydantic_settings to manage FastAPI app settings --- packages/syft/src/syft/node/server.py | 149 +++++++++----------------- 1 file changed, 48 insertions(+), 101 deletions(-) diff --git a/packages/syft/src/syft/node/server.py 
b/packages/syft/src/syft/node/server.py index 0b46727f0f6..ca81e6e53dd 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -1,8 +1,5 @@ # stdlib -import base64 from collections.abc import Callable -from enum import Enum -import json import multiprocessing import os from pathlib import Path @@ -10,10 +7,13 @@ import signal import subprocess # nosec import time +from typing import Any # third party from fastapi import APIRouter from fastapi import FastAPI +from pydantic_settings import BaseSettings +from pydantic_settings import SettingsConfigDict import requests from starlette.middleware.cors import CORSMiddleware @@ -36,102 +36,64 @@ WAIT_TIME_SECONDS = 20 -def make_app(name: str, router: APIRouter) -> FastAPI: - app = FastAPI( - title=name, - ) - - api_router = APIRouter() - - api_router.include_router(router) - app.include_router(api_router, prefix="/api/v2") - - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) +class AppSettings(BaseSettings): + name: str + node_type: NodeType = NodeType.DOMAIN + node_side_type: NodeSideType = NodeSideType.HIGH_SIDE + processes: int = 1 + reset: bool = False + dev_mode: bool = False + enable_warnings: bool = False + in_memory_workers: bool = True + queue_port: int | None = None + create_producer: bool = False + n_consumers: int = 0 + association_request_auto_approval: bool = False + background_tasks: bool = False - return app + model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") def app_factory() -> FastAPI: - try: - kwargs_encoded = os.environ["APP_FACTORY_KWARGS"] - kwargs_json = base64.b64decode(kwargs_encoded) - kwargs = json.loads(kwargs_json) - name = kwargs["name"] - node_type = kwargs["node_type"] - node_side_type = kwargs["node_side_type"] - processes = kwargs["processes"] - reset = kwargs["reset"] - dev_mode = kwargs["dev_mode"] - enable_warnings = kwargs["enable_warnings"] - in_memory_workers = kwargs["in_memory_workers"] - queue_port = kwargs["queue_port"] - create_producer = kwargs["create_producer"] - n_consumers = kwargs["n_consumers"] - association_request_auto_approval = kwargs["association_request_auto_approval"] - background_tasks = kwargs["background_tasks"] - except KeyError as e: - raise KeyError(f"Missing required environment variable: {e}") + settings = AppSettings() worker_classes = { NodeType.DOMAIN: Domain, NodeType.GATEWAY: Gateway, NodeType.ENCLAVE: Enclave, } - if node_type not in worker_classes: - raise NotImplementedError(f"node_type: {node_type} is not supported") - worker_class = worker_classes[node_type] - kwargs = { - "name": name, - "processes": processes, - "local_db": True, - "node_type": node_type, - "node_side_type": node_side_type, - "enable_warnings": enable_warnings, - "migrate": True, - "in_memory_workers": in_memory_workers, - "queue_port": queue_port, - "create_producer": create_producer, - "n_consumers": n_consumers, - "association_request_auto_approval": association_request_auto_approval, - "background_tasks": background_tasks, - } - if dev_mode: + if settings.node_type not in worker_classes: + raise NotImplementedError(f"node_type: {settings.node_type} is not supported") + worker_class = worker_classes[settings.node_type] + + kwargs = settings.model_dump() + if settings.dev_mode: print( - f"\nWARNING: private key is based on node name: {name} in dev_mode. " + f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. 
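With pydantic-settings, every field of AppSettings can be supplied as a SYFT_-prefixed environment variable and parsed back into the declared type. A small sketch of the expected behaviour; the values are illustrative:

    import os
    from syft.node.server import AppSettings

    os.environ["SYFT_NAME"] = "test-domain-1"
    os.environ["SYFT_DEV_MODE"] = "True"
    os.environ["SYFT_QUEUE_PORT"] = "None"   # env_parse_none_str turns this back into None

    settings = AppSettings()
    assert settings.name == "test-domain-1"
    assert settings.dev_mode is True
    assert settings.queue_port is None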
" "Don't run this in production." ) - kwargs["reset"] = reset + worker = worker_class.named(**kwargs) + else: + del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode + worker = worker_class(**kwargs) - worker = worker_class.named(**kwargs) if dev_mode else worker_class(**kwargs) + app = FastAPI(title=settings.name) router = make_routes(worker=worker) - app = make_app(worker.name, router=router) + api_router = APIRouter() + api_router.include_router(router) + app.include_router(api_router, prefix="/api/v2") + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) return app -def run_uvicorn( - name: str, - node_type: Enum, - host: str, - port: int, - processes: int, - reset: bool, - dev_mode: bool, - node_side_type: str, - enable_warnings: bool, - in_memory_workers: bool, - queue_port: int | None, - create_producer: bool, - association_request_auto_approval: bool, - n_consumers: int, - background_tasks: bool, -) -> None: - if reset: +def run_uvicorn(host: str, port: int, **kwargs: Any) -> None: + if kwargs.get("reset"): try: python_pids = find_python_processes_on_port(port) for pid in python_pids: @@ -141,33 +103,18 @@ def run_uvicorn( except Exception: # nosec print(f"Failed to kill python process on port: {port}") - kwargs = { - "name": name, - "node_type": node_type, - "node_side_type": node_side_type, - "processes": processes, - "reset": reset, - "dev_mode": dev_mode, - "enable_warnings": enable_warnings, - "in_memory_workers": in_memory_workers, - "queue_port": queue_port, - "create_producer": create_producer, - "n_consumers": n_consumers, - "association_request_auto_approval": association_request_auto_approval, - "background_tasks": background_tasks, - } - kwargs_json = json.dumps(kwargs) - kwargs_encoded = base64.b64encode(kwargs_json.encode()).decode() + env_prefix = AppSettings.model_config.get("env_prefix", "") + env_variables = " ".join(f"{env_prefix}{k.upper()}={v}" for k, v in kwargs.items()) + uvicorn_cmd = ( - f"APP_FACTORY_KWARGS={kwargs_encoded}" + f"{env_variables}" " uvicorn syft.node.server:app_factory" " --factory" f" --host {host}" f" --port {port}" ) - if dev_mode: + if kwargs.get("dev_mode"): uvicorn_cmd += f" --reload --reload-dir {Path(__file__).parent.parent}" - print(f"{uvicorn_cmd=}") os.system(uvicorn_cmd) From 59f22204cee147b2197d32efcb19ebbdb613d54f Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:59:09 +0530 Subject: [PATCH 181/309] Use uvicorn.run instead of starting another shell process with os.system --- packages/syft/src/syft/node/server.py | 34 ++++++++++++++++++--------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index ca81e6e53dd..cf0bf4370d5 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -6,6 +6,7 @@ import platform import signal import subprocess # nosec +import sys import time from typing import Any @@ -16,6 +17,7 @@ from pydantic_settings import SettingsConfigDict import requests from starlette.middleware.cors import CORSMiddleware +import uvicorn # relative from ..abstract_node import NodeSideType @@ -103,19 +105,29 @@ def run_uvicorn(host: str, port: int, **kwargs: Any) -> None: except Exception: # nosec print(f"Failed to kill python process on port: {port}") + # Set up all kwargs as environment variables so that they can be accessed in 
the app_factory function. env_prefix = AppSettings.model_config.get("env_prefix", "") - env_variables = " ".join(f"{env_prefix}{k.upper()}={v}" for k, v in kwargs.items()) - - uvicorn_cmd = ( - f"{env_variables}" - " uvicorn syft.node.server:app_factory" - " --factory" - f" --host {host}" - f" --port {port}" + for key, value in kwargs.items(): + key_with_prefix = f"{env_prefix}{key.upper()}" + os.environ[key_with_prefix] = str(value) + + # The `serve_node` function calls `run_uvicorn` in a separate process using `multiprocessing.Process`. + # When the child process is created, it inherits the file descriptors from the parent process. + # If the parent process has a file descriptor open for sys.stdin, the child process will also have a file descriptor + # open for sys.stdin. This can cause an OSError in uvicorn when it tries to access sys.stdin in the child process. + # To prevent this, we set sys.stdin to None in the child process. This is safe because we don't actually need + # sys.stdin while running uvicorn programmatically. + sys.stdin = None # type: ignore + + # Finally, run the uvicorn server. + uvicorn.run( + "syft.node.server:app_factory", + host=host, + port=port, + factory=True, + reload=kwargs.get("dev_mode"), + reload_dirs=[Path(__file__).parent.parent] if kwargs.get("dev_mode") else None, ) - if kwargs.get("dev_mode"): - uvicorn_cmd += f" --reload --reload-dir {Path(__file__).parent.parent}" - os.system(uvicorn_cmd) def serve_node( From e95af15c8b8789fac5c4880655fd9b080f69c630 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Sun, 30 Jun 2024 14:56:43 +0530 Subject: [PATCH 182/309] fix incorrect path --- .../syft/src/syft/service/network/rathole_config_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index f134468a120..15d1e8fb7eb 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -211,7 +211,7 @@ def _add_dynamic_addr_to_rathole( proxy_rule = ( f"Host(`{config.server_name}.syft.local`) || " - f"HostHeader(`{config.server_name}.syft.local`) && PathPrefix(`/rathole`)" + f"HostHeader(`{config.server_name}.syft.local`) && PathPrefix(`/rtunnel`)" ) rathole_proxy["http"]["routers"][config.server_name] = { From c2ab1029579cb98fb763ac4fd9996242957bfac0 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Sun, 30 Jun 2024 17:24:45 +0530 Subject: [PATCH 183/309] fix charts --- packages/grid/devspace.yaml | 126 +++++++++--------- .../dev/base.yaml} | 29 ++-- .../grid/helm/examples/dev/domain.tunnel.yaml | 11 ++ packages/grid/helm/examples/dev/enclave.yaml | 8 ++ packages/grid/helm/examples/dev/gateway.yaml | 14 ++ packages/grid/helm/syft/values.yaml | 8 +- packages/grid/helm/values.dev.high.yaml | 48 ------- packages/grid/helm/values.dev.low.yaml | 48 ------- tox.ini | 21 ++- 9 files changed, 133 insertions(+), 180 deletions(-) rename packages/grid/helm/{values.dev.yaml => examples/dev/base.yaml} (74%) create mode 100644 packages/grid/helm/examples/dev/domain.tunnel.yaml create mode 100644 packages/grid/helm/examples/dev/enclave.yaml create mode 100644 packages/grid/helm/examples/dev/gateway.yaml delete mode 100644 packages/grid/helm/values.dev.high.yaml delete mode 100644 packages/grid/helm/values.dev.low.yaml diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index b961d17da26..8bbf3487daf 100644 --- 
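For a concrete peer name, the corrected rule targets the /rtunnel prefix that the proxy config and HTTP client now use. With a hypothetical server_name of "domain1", the builder's f-string renders as:

    server_name = "domain1"   # illustrative
    proxy_rule = (
        f"Host(`{server_name}.syft.local`) || "
        f"HostHeader(`{server_name}.syft.local`) && PathPrefix(`/rtunnel`)"
    )
    # -> Host(`domain1.syft.local`) || HostHeader(`domain1.syft.local`) && PathPrefix(`/rtunnel`)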
a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -60,14 +60,6 @@ images: context: ./seaweedfs tags: - dev-${DEVSPACE_TIMESTAMP} - rathole: - image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_RATHOLE}" - buildKit: - args: ["--platform", "linux/${PLATFORM}"] - dockerfile: ./rathole/rathole.dockerfile - context: ./rathole - tags: - - dev-${DEVSPACE_TIMESTAMP} # This is a list of `deployments` that DevSpace can create for this project deployments: @@ -76,20 +68,16 @@ deployments: releaseName: syft-dev chart: name: ./helm/syft - # anything that does not need devspace $env vars should go in values.dev.yaml - valuesFiles: - - ./helm/syft/values.yaml - - ./helm/values.dev.yaml + # values that need to be templated go here values: global: registry: ${CONTAINER_REGISTRY} version: dev-${DEVSPACE_TIMESTAMP} - node: - type: domain # required for the gateway profile - rtunnel: - mode: client - proxy: - enabled: true # required for the gateway profile + node: {} + # anything that does not need templating should go in helm/examples/dev/base.yaml + # or profile specific values files + valuesFiles: + - ./helm/examples/dev/base.yaml dev: mongo: @@ -127,69 +115,86 @@ dev: - path: ../syft:/root/app/syft ssh: localPort: 3480 - rathole: - labelSelector: - app.kubernetes.io/name: syft - app.kubernetes.io/component: rathole - ports: - - port: "2333" # reverse_tunnel profiles: - - name: dev-low + - name: domain-low + description: "Deploy a low-side domain" patches: - op: add path: deployments.syft.helm.values.node value: side: low - - name: dev-high + + - name: domain-tunnel + description: "Deploy a domain with tunneling enabled" patches: + # enable rathole image - op: add - path: deployments.syft.helm.values.node + path: images value: - side: high + rathole: + image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_RATHOLE}" + buildKit: + args: ["--platform", "linux/${PLATFORM}"] + dockerfile: ./rathole/rathole.dockerfile + context: ./rathole + tags: + - dev-${DEVSPACE_TIMESTAMP} + # use rathole client-specific chart values + - op: add + path: deployments.syft.helm.valuesFiles + value: ./helm/examples/dev/domain.tunnel.yaml - name: gateway + description: "Deploy a Gateway Node with tunnel enabled" patches: - - op: replace - path: deployments.syft.helm.values.node.type - value: "gateway" + # enable rathole image + - op: add + path: images + value: + rathole: + image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_RATHOLE}" + buildKit: + args: ["--platform", "linux/${PLATFORM}"] + dockerfile: ./rathole/rathole.dockerfile + context: ./rathole + tags: + - dev-${DEVSPACE_TIMESTAMP} + # enable rathole `devspace dev` config + - op: add + path: dev + value: + rathole: + labelSelector: + app.kubernetes.io/name: syft + app.kubernetes.io/component: rathole + ports: + - port: "2333" + # use gateway-specific chart values + - op: add + path: deployments.syft.helm.valuesFiles + value: ./helm/examples/dev/gateway.yaml + # remove unused images - op: remove path: images.seaweedfs - op: remove path: dev.seaweedfs - - # Patch mode to server - - op: replace - path: deployments.syft.helm.values.rtunnel.mode - value: server - - # Enable proxy for gateway - - op: replace - path: deployments.syft.helm.values.proxy.enabled - value: true - # Port Re-Mapping - # Mongo - op: replace path: dev.mongo.ports[0].port value: 27018:27017 - - # Backend - op: replace path: dev.backend.ports[0].port value: 5679:5678 - - # Backend Container SSH - op: replace path: dev.backend.containers.backend-container.ssh.localPort value: 3481 - - # Reverse tunnel port - op: 
replace path: dev.rtunnel.ports[0].port value: 2334:2333 - name: gcp + description: "Deploy a high-side domain on GCP" patches: - op: replace path: deployments.syft.helm.valuesFiles @@ -197,6 +202,7 @@ profiles: - ./helm/examples/gcp/gcp.high.yaml - name: gcp-low + description: "Deploy a low-side domain on GCP" patches: - op: replace path: deployments.syft.helm.valuesFiles @@ -204,6 +210,7 @@ profiles: - ./helm/examples/gcp/gcp.low.yaml - name: azure + description: "Deploy a high-side domain on AKS" patches: - op: replace path: deployments.syft.helm.valuesFiles @@ -211,11 +218,9 @@ profiles: - ./helm/examples/azure/azure.high.yaml - name: enclave + description: "Deploy an enclave node" patches: - - op: replace - path: deployments.syft.helm.values.node.type - value: "enclave" - + # enable image build for enclave-attestation - op: add path: images value: @@ -233,29 +238,20 @@ profiles: enclave-attestation: sync: - path: ./enclave/attestation/server:/app/server - + # use gateway-specific chart values - op: add - path: deployments.syft.helm.values - value: - attestation: - enabled: true - + path: deployments.syft.helm.valuesFiles + value: ./helm/examples/dev/enclave.yaml # Port Re-Mapping - # Mongo - op: replace path: dev.mongo.ports[0].port value: 27019:27017 - - # Backend - op: replace path: dev.backend.ports[0].port value: 5680:5678 - - # Backend Container SSH - op: replace path: dev.backend.containers.backend-container.ssh.localPort value: 3482 - - op: replace path: dev.seaweedfs.ports value: diff --git a/packages/grid/helm/values.dev.yaml b/packages/grid/helm/examples/dev/base.yaml similarity index 74% rename from packages/grid/helm/values.dev.yaml rename to packages/grid/helm/examples/dev/base.yaml index 493850cbb67..b81e4847cd8 100644 --- a/packages/grid/helm/values.dev.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -1,15 +1,9 @@ -# Helm chart values used for development and testing -# Can be used through `helm install -f values.dev.yaml` or devspace `valuesFiles` +# Base Helm chart values used for development and testing +# Can be used through `helm install -f packages/grid/helm/examples/dev/base.yaml` or devspace `valuesFiles` global: randomizedSecrets: false -registry: - resourcesPreset: null - resources: null - - storageSize: "5Gi" - node: rootEmail: info@openmined.org associationRequestAutoApproval: true @@ -44,14 +38,21 @@ frontend: resources: null proxy: + enabled: true + resourcesPreset: null resources: null -rtunnel: +registry: enabled: true - logLevel: "trace" -# attestation: -# enabled: true -# resourcesPreset: null -# resources: null + resourcesPreset: null + resources: null + + storageSize: "5Gi" + +rtunnel: + enabled: false + +attestation: + enabled: false diff --git a/packages/grid/helm/examples/dev/domain.tunnel.yaml b/packages/grid/helm/examples/dev/domain.tunnel.yaml new file mode 100644 index 00000000000..cec2e97cc6e --- /dev/null +++ b/packages/grid/helm/examples/dev/domain.tunnel.yaml @@ -0,0 +1,11 @@ +# Values for deploying a domain with a reverse tunnel server in client-mode +# Patched on top of patch `base.yaml` + +# Proxy is required for the tunnel to work +proxy: + enabled: true + +rtunnel: + enabled: true + mode: client + logLevel: debug diff --git a/packages/grid/helm/examples/dev/enclave.yaml b/packages/grid/helm/examples/dev/enclave.yaml new file mode 100644 index 00000000000..2951da06b05 --- /dev/null +++ b/packages/grid/helm/examples/dev/enclave.yaml @@ -0,0 +1,8 @@ +# Values for deploying an enclave +# Patched on top of patch `base.yaml` + 
+node: + type: enclave + +attestation: + enabled: true diff --git a/packages/grid/helm/examples/dev/gateway.yaml b/packages/grid/helm/examples/dev/gateway.yaml new file mode 100644 index 00000000000..e0916c98c21 --- /dev/null +++ b/packages/grid/helm/examples/dev/gateway.yaml @@ -0,0 +1,14 @@ +# Values for deploying a gateway with a reverse tunnel server +# Patched on top of patch `base.yaml` + +node: + type: gateway + +# Proxy is required for the tunnel to work +proxy: + enabled: true + +rtunnel: + enabled: true + mode: server + logLevel: debug diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index b5372d8e857..377bd763c54 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -134,6 +134,7 @@ proxy: registry: enabled: true + # Extra environment vars env: null @@ -239,14 +240,15 @@ ingress: rtunnel: - # Extra environment vars - env: null enabled: false - logLevel: "info" + logLevel: info port: 2333 mode: client + # Extra environment vars + env: null + # Pod labels & annotations podLabels: null podAnnotations: null diff --git a/packages/grid/helm/values.dev.high.yaml b/packages/grid/helm/values.dev.high.yaml deleted file mode 100644 index 9a0e266704a..00000000000 --- a/packages/grid/helm/values.dev.high.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# Helm chart values used for development and testing -# Can be used through `helm install -f values.dev.yaml` or devspace `valuesFiles` - -global: - randomizedSecrets: false - -registry: - resourcesPreset: null - resources: null - - storageSize: "5Gi" - -node: - rootEmail: info@openmined.org - side: high - - resourcesPreset: 2xlarge - resources: null - - defaultWorkerPool: - count: 1 - podLabels: null - podAnnotations: null - - secret: - defaultRootPassword: changethis - -mongo: - resourcesPreset: null - resources: null - - secret: - rootPassword: example - -seaweedfs: - resourcesPreset: null - resources: null - - secret: - s3RootPassword: admin - -frontend: - resourcesPreset: null - resources: null - -proxy: - resourcesPreset: null - resources: null diff --git a/packages/grid/helm/values.dev.low.yaml b/packages/grid/helm/values.dev.low.yaml deleted file mode 100644 index 7e5de1a68f2..00000000000 --- a/packages/grid/helm/values.dev.low.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# Helm chart values used for development and testing -# Can be used through `helm install -f values.dev.yaml` or devspace `valuesFiles` - -global: - randomizedSecrets: false - -registry: - resourcesPreset: null - resources: null - - storageSize: "5Gi" - -node: - rootEmail: info@openmined.org - side: low - - resourcesPreset: 2xlarge - resources: null - - defaultWorkerPool: - count: 1 - podLabels: null - podAnnotations: null - - secret: - defaultRootPassword: changethis - -mongo: - resourcesPreset: null - resources: null - - secret: - rootPassword: example - -seaweedfs: - resourcesPreset: null - resources: null - - secret: - s3RootPassword: admin - -frontend: - resourcesPreset: null - resources: null - -proxy: - resourcesPreset: null - resources: null diff --git a/tox.ini b/tox.ini index 6fdda86e25c..7955837a2ae 100644 --- a/tox.ini +++ b/tox.ini @@ -471,7 +471,7 @@ commands = # Creating test-domain-1 cluster on port 9082 bash -c '\ - export CLUSTER_NAME=${DOMAIN_CLUSTER_NAME} CLUSTER_HTTP_PORT=9082 && \ + export CLUSTER_NAME=${DOMAIN_CLUSTER_NAME} CLUSTER_HTTP_PORT=9082 DEVSPACE_PROFILE=domain-tunnel && \ tox -e dev.k8s.start && \ tox -e dev.k8s.deploy' @@ -874,6 +874,23 @@ commands = bash -c 'devspace cleanup 
images --kube-context k3d-${CLUSTER_NAME} --no-warn --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 || true' bash -c 'kubectl --context k3d-${CLUSTER_NAME} delete namespace syft --now=true || true' +[testenv:dev.k8s.render] +description = Dump devspace rendered chargs for debugging. Save in `packages/grid/out.render` +changedir = {toxinidir}/packages/grid +passenv = HOME, USER, DEVSPACE_PROFILE +setenv= + OUTPUT_DIR = {env:OUTPUT_DIR:./.devspace/rendered} +allowlist_externals = + bash +commands = + bash -c '\ + if [[ -n "${DEVSPACE_PROFILE}" ]]; then export DEVSPACE_PROFILE="-p ${DEVSPACE_PROFILE}"; fi && \ + rm -rf ${OUTPUT_DIR} && \ + mkdir -p ${OUTPUT_DIR} && \ + echo "profile: $DEVSPACE_PROFILE" && \ + devspace print ${DEVSPACE_PROFILE} > ${OUTPUT_DIR}/config.txt && \ + devspace deploy --render --skip-build --no-warn ${DEVSPACE_PROFILE} --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 > ${OUTPUT_DIR}/chart.yaml' + [testenv:dev.k8s.launch.gateway] description = Launch a single gateway on K8s passenv = HOME, USER @@ -888,7 +905,7 @@ commands = tox -e dev.k8s.{posargs:deploy} [testenv:dev.k8s.launch.domain] -description = Launch a single domain on K8s +description = Launch a single domain on K8s passenv = HOME, USER setenv= CLUSTER_NAME = {env:CLUSTER_NAME:test-domain-1} From a92b06d5508f100b30cfa1d0a1dba13f69796c8f Mon Sep 17 00:00:00 2001 From: alfred-openmined-bot <145415986+alfred-openmined-bot@users.noreply.github.com> Date: Sun, 30 Jun 2024 12:14:37 +0000 Subject: [PATCH 184/309] bump protocol and remove notebooks --- notebooks/Experimental/Network.ipynb | 209 ------------------ .../src/syft/protocol/protocol_version.json | 4 +- 2 files changed, 2 insertions(+), 211 deletions(-) delete mode 100644 notebooks/Experimental/Network.ipynb diff --git a/notebooks/Experimental/Network.ipynb b/notebooks/Experimental/Network.ipynb deleted file mode 100644 index 7a1f3f257dc..00000000000 --- a/notebooks/Experimental/Network.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "bd9a2226-3e53-4f27-9213-75a8c3ff9176", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "import syft as sy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fddf8d07-d154-4284-a27b-d74e35d3f851", - "metadata": {}, - "outputs": [], - "source": [ - "gateway_client = sy.login(\n", - " url=\"http://localhost\", port=9081, email=\"info@openmined.org\", password=\"changethis\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f7b106d-b784-45d8-b54d-4ce2de2da453", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client = sy.login(\n", - " url=\"http://localhost\", port=9082, email=\"info@openmined.org\", password=\"changethis\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff504949-620d-4e26-beee-0d39e0e502eb", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client.connect_to_gateway(gateway_client, reverse_tunnel=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba7bc71a-4e6a-4429-9588-7b3d0ed19e27", - "metadata": {}, - "outputs": [], - "source": [ - "gateway_client.api.services.request" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5b4984e1-331e-4fd8-b012-768fc613f48a", - "metadata": {}, - "outputs": [], - "source": [ - "# gateway_client.api.services.request[0].approve()" - ] - }, - { - "cell_type": "code", - "execution_count": null, 
- "id": "90dc44bd", - "metadata": {}, - "outputs": [], - "source": [ - "node_peers = gateway_client.api.network.get_all_peers()\n", - "node_peers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c06aaa6-4157-42d1-959f-9d47722a3420", - "metadata": {}, - "outputs": [], - "source": [ - "node_peer = gateway_client.api.network.get_all_peers()[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb63a77b", - "metadata": {}, - "outputs": [], - "source": [ - "node_peer.node_routes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61882e86", - "metadata": {}, - "outputs": [], - "source": [ - "node_peer.node_routes[0].__dict__" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fb19dbc6-869b-46dc-92e3-5e75ee6d0b06", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client.api.network.get_all_peers()[0].node_routes[0].__dict__" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32d09a51", - "metadata": {}, - "outputs": [], - "source": [ - "# node_peer.client_with_key(sy.SyftSigningKey.generate())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7d9e41d", - "metadata": {}, - "outputs": [], - "source": [ - "# gateway_client.api.network.delete_route(node_peer.verify_key, node_peer.node_routes[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8fa24ec7", - "metadata": {}, - "outputs": [], - "source": [ - "gateway_client" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3a081250-abc3-43a3-9e06-ff0c3a362ebf", - "metadata": {}, - "outputs": [], - "source": [ - "gateway_client.peers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6fedfe4-9362-47c9-9342-5cf6eacde8ab", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client_proxy = gateway_client.peers[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f1940e00-0337-4b56-88c2-d70f397a7016", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client_proxy.connection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "613125c5-6321-4238-852c-ff0cfcd9526a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.1.-1" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index c08ea39b7b2..5c0c448a2cf 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -281,7 +281,7 @@ }, "3": { "version": 3, - "hash": "cac31ba98bdcc42c0717555a0918d0c8aef0d2235f892a2d86dceff09930fb88", + "hash": "b61b30d10e2d25726c708ef34b69c7b730d41b16b315e7062f3d487e943143d5", "action": "add" } }, @@ -399,7 +399,7 @@ }, "3": { "version": 3, - "hash": "9e7e3700a2f7b1a67f054efbcb31edc71bbf358c469c85ed7760b81233803bac", + "hash": "d26cb313e92b1fbe36995c8ed4103a9168ea6e589b2803ed9a91c23f14bf0c96", "action": "add" } }, From 6cf0d50789d89fd137d43f0e3ae87670971637ca Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Sun, 30 
Jun 2024 20:34:07 +0530 Subject: [PATCH 185/309] Add VSCode debugger support to Syft uvicorn server using debugpy --- .gitignore | 3 ++- .pre-commit-config.yaml | 4 ++-- .vscode/launch.json | 25 ++++++++++++++++++++ packages/syft/setup.cfg | 3 ++- packages/syft/src/syft/node/server.py | 34 ++++++++++++++++++++++++++- packages/syft/src/syft/orchestra.py | 4 ++++ 6 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.gitignore b/.gitignore index fc3d10b8733..0ab30b93adf 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,8 @@ .idea/ .mypy_cache .python-version -.vscode/ +.vscode/* +!.vscode/launch.json .tox/* .creds build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab1745d6ce5..92b64b9d4b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: exclude: ^(packages/syft/tests/mongomock) - id: check-json always_run: true - exclude: ^(packages/grid/frontend/|packages/syft/tests/mongomock) + exclude: ^(packages/grid/frontend/|packages/syft/tests/mongomock|.vscode) - id: check-added-large-files always_run: true exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif) @@ -179,7 +179,7 @@ repos: rev: "v3.0.0-alpha.9-for-vscode" hooks: - id: prettier - exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock) + exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) # - repo: meta # hooks: diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000000..bb5d6e9c00a --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
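+  // Rough usage sketch (port number and node name are examples only):
+  //   1. start Syft with the debug flag, e.g. sy.orchestra.launch(name="dev-node", debug=True)
+  //   2. note the debugpy port printed by the server ("Waiting for debugger to attach on port ...")
+  //   3. run this "Syft Debugger" configuration and enter that port when prompted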
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Syft Debugger", + "type": "debugpy", + "request": "attach", + "connect": { + "host": "localhost", + "port": "${input:port}" + } + } + ], + "inputs": [ + { + "id": "port", + "description": "Port on which the debugger is listening", + "type": "promptString", + "default": "5678" + } + ] +} \ No newline at end of file diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 6c8ed7d741b..7e30b26e8e4 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -92,11 +92,12 @@ dev = %(test_plugins)s %(telemetry)s bandit==1.7.8 - ruff==0.4.7 + debugpy==1.8.2 importlib-metadata==7.1.0 isort==5.13.2 mypy==1.10.0 pre-commit==3.7.1 + ruff==0.4.7 safety>=2.4.0b2 telemetry = diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index cf0bf4370d5..131e7595e2c 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -1,6 +1,7 @@ # stdlib from collections.abc import Callable import multiprocessing +import multiprocessing.synchronize import os from pathlib import Path import platform @@ -11,6 +12,7 @@ from typing import Any # third party +import debugpy from fastapi import APIRouter from fastapi import FastAPI from pydantic_settings import BaseSettings @@ -94,7 +96,12 @@ def app_factory() -> FastAPI: return app -def run_uvicorn(host: str, port: int, **kwargs: Any) -> None: +def run_uvicorn( + host: str, + port: int, + debugger_attached_event: multiprocessing.synchronize.Event | None = None, + **kwargs: Any, +) -> None: if kwargs.get("reset"): try: python_pids = find_python_processes_on_port(port) @@ -105,6 +112,23 @@ def run_uvicorn(host: str, port: int, **kwargs: Any) -> None: except Exception: # nosec print(f"Failed to kill python process on port: {port}") + if kwargs.get("debug"): + if debugger_attached_event is None: + raise ValueError( + "The `debugger_attached_event` parameter must be provided when `debug=True`." + ) + os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + _, debug_port = debugpy.listen(0) + print( + "\nStarting the server with the Python Debugger enabled (`debug=True`).\n" + 'To attach the debugger, open the command palette in VSCode and select "Debug: Start Debugging (F5)".\n' + f"Then, enter `{debug_port}` in the port field and press Enter.\n" + ) + print(f"Waiting for debugger to attach on port `{debug_port}`...") + debugpy.wait_for_client() # blocks execution until a remote debugger is attached + print("Debugger attached") + debugger_attached_event.set() # Signal the parent process that the debugger is attached + # Set up all kwargs as environment variables so that they can be accessed in the app_factory function. 
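    # (For instance, a kwarg passed to run_uvicorn such as reset=True is exported
    # under AppSettings' env_prefix, so the spawned uvicorn worker process can
    # rebuild the same AppSettings inside app_factory(); the names here are only
    # an illustration of that flow.)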
env_prefix = AppSettings.model_config.get("env_prefix", "") for key, value in kwargs.items(): @@ -147,7 +171,9 @@ def serve_node( n_consumers: int = 0, association_request_auto_approval: bool = False, background_tasks: bool = False, + debug: bool = False, ) -> tuple[Callable, Callable]: + debugger_attached_event = multiprocessing.Event() server_process = multiprocessing.Process( target=run_uvicorn, kwargs={ @@ -166,6 +192,8 @@ def serve_node( "n_consumers": n_consumers, "association_request_auto_approval": association_request_auto_approval, "background_tasks": background_tasks, + "debug": debug, + "debugger_attached_event": debugger_attached_event, }, ) @@ -182,6 +210,10 @@ def start() -> None: print(f"Starting {name} server on {host}:{port}") server_process.start() + if debug: + # Wait for the debugger to get attached + debugger_attached_event.wait() + if tail: try: while True: diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 1a08f594aa2..53601581724 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -166,6 +166,7 @@ def deploy_to_python( queue_port: int | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, + debug: bool = False, ) -> NodeHandle: worker_classes = { NodeType.DOMAIN: Domain, @@ -193,6 +194,7 @@ def deploy_to_python( "create_producer": create_producer, "association_request_auto_approval": association_request_auto_approval, "background_tasks": background_tasks, + "debug": debug, } if port: @@ -282,6 +284,7 @@ def launch( queue_port: int | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, + debug: bool = False, ) -> NodeHandle: if dev_mode is True: thread_workers = True @@ -318,6 +321,7 @@ def launch( queue_port=queue_port, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, + debug=debug, ) elif deployment_type_enum == DeploymentType.REMOTE: return deploy_to_remote( From c7980e1c7b6cb89320086656f72fa03b3c2b805c Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Sun, 30 Jun 2024 22:28:20 +0530 Subject: [PATCH 186/309] Refactoring and CI fixes --- packages/syft/src/syft/node/server.py | 47 ++++++++++++++------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 131e7595e2c..e208b4dcaf0 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -12,7 +12,6 @@ from typing import Any # third party -import debugpy from fastapi import APIRouter from fastapi import FastAPI from pydantic_settings import BaseSettings @@ -96,10 +95,26 @@ def app_factory() -> FastAPI: return app +def attach_debugger() -> None: + # third party + import debugpy + + os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + _, debug_port = debugpy.listen(0) + print( + "\nStarting the server with the Python Debugger enabled (`debug=True`).\n" + 'To attach the debugger, open the command palette in VSCode and select "Debug: Start Debugging (F5)".\n' + f"Then, enter `{debug_port}` in the port field and press Enter.\n" + ) + print(f"Waiting for debugger to attach on port `{debug_port}`...") + debugpy.wait_for_client() # blocks execution until a remote debugger is attached + print("Debugger attached") + + def run_uvicorn( host: str, port: int, - debugger_attached_event: multiprocessing.synchronize.Event | None = None, + 
starting_uvicorn_event: multiprocessing.synchronize.Event, **kwargs: Any, ) -> None: if kwargs.get("reset"): @@ -113,21 +128,7 @@ def run_uvicorn( print(f"Failed to kill python process on port: {port}") if kwargs.get("debug"): - if debugger_attached_event is None: - raise ValueError( - "The `debugger_attached_event` parameter must be provided when `debug=True`." - ) - os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" - _, debug_port = debugpy.listen(0) - print( - "\nStarting the server with the Python Debugger enabled (`debug=True`).\n" - 'To attach the debugger, open the command palette in VSCode and select "Debug: Start Debugging (F5)".\n' - f"Then, enter `{debug_port}` in the port field and press Enter.\n" - ) - print(f"Waiting for debugger to attach on port `{debug_port}`...") - debugpy.wait_for_client() # blocks execution until a remote debugger is attached - print("Debugger attached") - debugger_attached_event.set() # Signal the parent process that the debugger is attached + attach_debugger() # Set up all kwargs as environment variables so that they can be accessed in the app_factory function. env_prefix = AppSettings.model_config.get("env_prefix", "") @@ -143,6 +144,9 @@ def run_uvicorn( # sys.stdin while running uvicorn programmatically. sys.stdin = None # type: ignore + # Signal the parent process that we are starting the uvicorn server. + starting_uvicorn_event.set() + # Finally, run the uvicorn server. uvicorn.run( "syft.node.server:app_factory", @@ -173,7 +177,7 @@ def serve_node( background_tasks: bool = False, debug: bool = False, ) -> tuple[Callable, Callable]: - debugger_attached_event = multiprocessing.Event() + starting_uvicorn_event = multiprocessing.Event() server_process = multiprocessing.Process( target=run_uvicorn, kwargs={ @@ -193,7 +197,7 @@ def serve_node( "association_request_auto_approval": association_request_auto_approval, "background_tasks": background_tasks, "debug": debug, - "debugger_attached_event": debugger_attached_event, + "starting_uvicorn_event": starting_uvicorn_event, }, ) @@ -210,9 +214,8 @@ def start() -> None: print(f"Starting {name} server on {host}:{port}") server_process.start() - if debug: - # Wait for the debugger to get attached - debugger_attached_event.wait() + # Wait for the child process to start uvicorn server before starting the readiness checks. 
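+        # (starting_uvicorn_event is a multiprocessing.Event created in serve_node and
+        # handed to run_uvicorn; the child calls .set() right before uvicorn.run(), so
+        # this wait() only blocks until the server process has actually begun starting.)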
+ starting_uvicorn_event.wait() if tail: try: From e65c81c20e0678f5ff3cf84132be008f4cfced06 Mon Sep 17 00:00:00 2001 From: dk Date: Mon, 1 Jul 2024 12:11:39 +0700 Subject: [PATCH 187/309] [syft/user_code] separate code parsing out of the 'global' keyworkd check function - add some type annotations - simplify unit test for the case --- .../syft/src/syft/service/code/user_code.py | 49 +++++++++++++------ .../syft/tests/syft/users/user_code_test.py | 13 +++-- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ad51750ed19..704f130b82d 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1042,9 +1042,13 @@ def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: source = dedent(inspect.getsource(self.local_function)) - - v: GlobalsVisitor | SyftWarning = _check_global(raw_code=source) - if isinstance(v, SyftWarning): # the code contains "global" keyword + tree: ast.Module | SyftWarning = _parse_code(source) + if isinstance(tree, SyftWarning): + return SyftError( + message=f"Error when running function locally: {tree.message}" + ) + v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) + if isinstance(v, SyftWarning): return SyftError( message=f"Error when running function locally: {v.message}" ) @@ -1264,8 +1268,12 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) + tree: ast.Module | SyftWarning = _parse_code(raw_code=code) + if isinstance(tree, SyftWarning): + display(tree) + # check that there are no globals - global_check = _check_global(code) + global_check: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) if isinstance(global_check, SyftWarning): display(global_check) @@ -1314,12 +1322,13 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator -def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: - tree = ast.parse(raw_code) - # check there are no globals +def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: + """ + Check that the code does not contain any global variables + """ v = GlobalsVisitor() try: - v.visit(tree) + v.visit(code_tree) except Exception: return SyftWarning( message="Your code contains (a) global variable(s), which is not allowed" @@ -1327,6 +1336,17 @@ def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: return v +def _parse_code(raw_code: str) -> ast.Module | SyftWarning: + """ + Parse the code into an AST tree and return a warning if there are syntax errors + """ + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + return SyftWarning(message=f"Your code contains syntax error: {e}") + return tree + + def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] @@ -1349,17 +1369,16 @@ def process_code( policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - try: - tree = ast.parse(raw_code) - except SyntaxError as e: - raise SyftException(f"Syntax error in code: {e}") + tree: ast.Module | SyftWarning = _parse_code(raw_code=raw_code) + if isinstance(tree, SyftWarning): + raise SyftException(f"{tree.message}") # check there are no globals - v = _check_global(raw_code=tree) + v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) if isinstance(v, SyftWarning): - raise 
SyftException(message=f"{v.message}") + raise SyftException(f"{v.message}") - f = tree.body[0] + f: ast.stmt = tree.body[0] f.decorator_list = [] call_args = function_input_kwargs diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 13caf8f9b07..5d7504c61fc 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -401,15 +401,18 @@ def test_submit_code_with_global_var(guest_client: DomainClient) -> None: ) def mock_syft_func_with_global(): global x + return x - def example_function(): - return 1 + x + res = guest_client.code.submit(mock_syft_func_with_global) + assert isinstance(res, SyftError) - return example_function() + @sy.syft_function_single_use() + def mock_syft_func_single_use_with_global(): + global x + return x - res = guest_client.code.submit(mock_syft_func_with_global) + res = guest_client.code.submit(mock_syft_func_single_use_with_global) assert isinstance(res, SyftError) - assert "No Globals allowed!" in res.message def test_request_existing_usercodesubmit(worker) -> None: From a589ad3c0b4ad66ff7d889afd5e477a31b109307 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 1 Jul 2024 14:09:14 +0700 Subject: [PATCH 188/309] [syft/user_code] try to parse and unparse user code both on the client and server side - remove parsing user code in local call since it was already parsed before Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/code/user_code.py | 110 +++++++++--------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 704f130b82d..cad1255553f 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1041,18 +1041,6 @@ def __call__( def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: - source = dedent(inspect.getsource(self.local_function)) - tree: ast.Module | SyftWarning = _parse_code(source) - if isinstance(tree, SyftWarning): - return SyftError( - message=f"Error when running function locally: {tree.message}" - ) - v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(v, SyftWarning): - return SyftError( - message=f"Error when running function locally: {v.message}" - ) - # filtered_args = [] filtered_kwargs = {} # for arg in args: @@ -1268,21 +1256,21 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) - tree: ast.Module | SyftWarning = _parse_code(raw_code=code) - if isinstance(tree, SyftWarning): - display(tree) - - # check that there are no globals - global_check: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(global_check, SyftWarning): - display(global_check) - if name is not None: fname = name code = replace_func_name(code, fname) else: fname = f.__name__ + input_kwargs = f.__code__.co_varnames[: f.__code__.co_argcount] + + parse_user_code( + raw_code=code, + func_name=fname, + original_func_name=f.__name__, + function_input_kwargs=input_kwargs, + ) + res = SubmitUserCode( code=code, func_name=fname, @@ -1292,7 +1280,7 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: output_policy_type=output_policy_type, output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), local_function=f, - input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount], + input_kwargs=input_kwargs, 
worker_pool_name=worker_pool_name, ) @@ -1305,6 +1293,11 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: display(err) return err + except SyftException as se: + err = SyftError(message=f"Error when parsing the code: {se}") + display(err) + return err + if share_results_with_owners and res.output_policy_init_kwargs is not None: res.output_policy_init_kwargs["output_readers"] = ( res.input_owner_verify_keys @@ -1322,6 +1315,20 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator +def generate_unique_func_name(context: TransformContext) -> TransformContext: + if context.output is not None: + code_hash = context.output["code_hash"] + service_func_name = context.output["func_name"] + context.output["service_func_name"] = service_func_name + func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" + user_unique_func_name = ( + f"user_func_{service_func_name}_{context.credentials}_{time.time()}" + ) + context.output["unique_func_name"] = func_name + context.output["user_unique_func_name"] = user_unique_func_name + return context + + def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: """ Check that the code does not contain any global variables @@ -1330,8 +1337,8 @@ def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: try: v.visit(code_tree) except Exception: - return SyftWarning( - message="Your code contains (a) global variable(s), which is not allowed" + raise SyftException( + "Your code contains (a) global variable(s), which is not allowed" ) return v @@ -1343,47 +1350,27 @@ def _parse_code(raw_code: str) -> ast.Module | SyftWarning: try: tree = ast.parse(raw_code) except SyntaxError as e: - return SyftWarning(message=f"Your code contains syntax error: {e}") + raise SyftException(f"Your code contains syntax error: {e}") return tree -def generate_unique_func_name(context: TransformContext) -> TransformContext: - if context.output is not None: - code_hash = context.output["code_hash"] - service_func_name = context.output["func_name"] - context.output["service_func_name"] = service_func_name - func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" - user_unique_func_name = ( - f"user_func_{service_func_name}_{context.credentials}_{time.time()}" - ) - context.output["unique_func_name"] = func_name - context.output["user_unique_func_name"] = user_unique_func_name - return context - - -def process_code( - context: TransformContext, +def parse_user_code( raw_code: str, func_name: str, original_func_name: str, - policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - tree: ast.Module | SyftWarning = _parse_code(raw_code=raw_code) - if isinstance(tree, SyftWarning): - raise SyftException(f"{tree.message}") - - # check there are no globals - v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(v, SyftWarning): - raise SyftException(f"{v.message}") + # parse the code, check for syntax errors and if there are global variables + try: + tree: ast.Module = _parse_code(raw_code=raw_code) + _check_global(code_tree=tree) + except SyftException as e: + raise SyftException(f"{e}") f: ast.stmt = tree.body[0] f.decorator_list = [] call_args = function_input_kwargs - if "domain" in function_input_kwargs and context.output is not None: - context.output["uses_domain"] = True call_stmt_keywords = [ast.keyword(arg=i, value=[ast.Name(id=i)]) for i in call_args] call_stmt = ast.Assign( targets=[ast.Name(id="result")], @@ -1408,6 +1395,25 @@ def process_code( 
return unparse(wrapper_function) +def process_code( + context: TransformContext, + raw_code: str, + func_name: str, + original_func_name: str, + policy_input_kwargs: list[str], + function_input_kwargs: list[str], +) -> str: + if "domain" in function_input_kwargs and context.output is not None: + context.output["uses_domain"] = True + + return parse_user_code( + raw_code=raw_code, + func_name=func_name, + original_func_name=original_func_name, + function_input_kwargs=function_input_kwargs, + ) + + def new_check_code(context: TransformContext) -> TransformContext: # TODO: remove this tech debt hack if context.output is None: From a260de881ad279cd094e35d4d5f0ded2968ee63c Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 1 Jul 2024 15:26:12 +0700 Subject: [PATCH 189/309] [syft/chore] change print to `logger.debug` for blob store path and min size to upload --- packages/syft/src/syft/node/node.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index e056fd32147..977dfb07f70 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -510,12 +510,11 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: if self.dev_mode: if isinstance(self.blob_store_config, OnDiskBlobStorageConfig): - print( + logger.debug( f"Using on-disk blob storage with path: " f"{self.blob_store_config.client_config.base_directory}", - end=". ", ) - print( + logger.debug( f"Minimum object size to be saved to the blob storage: " f"{self.blob_store_config.min_blob_size} (MB)." ) From ab083585bf0e9ea6cb0b396b3f57a4d7e33d3cf1 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 1 Jul 2024 11:59:33 +0200 Subject: [PATCH 190/309] update code.asset and code repr --- .../syft/src/syft/service/code/user_code.py | 52 ++++++++++++------- .../syft/src/syft/service/dataset/dataset.py | 3 +- packages/syft/src/syft/util/table.py | 10 +++- 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index cab58c09eb6..6043453f552 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -10,7 +10,7 @@ import hashlib import inspect from io import StringIO -import itertools +import json import keyword import random import re @@ -76,6 +76,7 @@ from ..action.action_object import ActionObject from ..context import AuthedServiceContext from ..dataset.dataset import Asset +from ..dataset.dataset import Dataset from ..job.job_stash import Job from ..output.output_service import ExecutionOutput from ..output.output_service import OutputService @@ -778,7 +779,7 @@ def byte_code(self) -> PyCodeObject | None: return compile_byte_code(self.parsed_code) @property - def assets(self) -> dict[str, Asset] | SyftError: + def assets(self) -> list[Asset] | SyftError: if not self.input_policy: return {} @@ -787,13 +788,15 @@ def assets(self) -> dict[str, Asset] | SyftError: return api # get all assets on the node - datasets = api.services.dataset.get_all() + datasets: list[Dataset] = api.services.dataset.get_all() if isinstance(datasets, SyftError): return datasets - all_assets = { - asset.action_id: asset - for asset in itertools.chain.from_iterable(x.asset_list for x in datasets) - } + + all_assets = {} + for dataset in datasets: + for asset in dataset.asset_list: + asset._dataset_name = dataset.name + all_assets[asset.action_id] 
= asset # get a flat dict of all inputs all_inputs = {} @@ -802,20 +805,26 @@ def assets(self) -> dict[str, Asset] | SyftError: all_inputs.update(vals) # map the action_id to the asset - used_assets = {} - for kwarg, action_id in all_inputs.items(): - used_assets[kwarg] = all_assets.get(action_id, None) + used_assets = [] + for kwarg_name, action_id in all_inputs.items(): + asset = all_assets.get(action_id, None) + asset._kwarg_name = kwarg_name + used_assets.append(asset) return used_assets @property - def _asset_str(self) -> str | SyftError: - assets = self.assets - if isinstance(assets, SyftError): - return assets - asset_str_list = [ - f"{kwarg}={repr(asset)}" for kwarg, asset in self.assets.items() - ] - asset_str = "\n".join(asset_str_list) + def _asset_json(self) -> str | SyftError: + if isinstance(self.assets, SyftError): + return self.assets + used_assets = {} + for asset in self.assets: + used_assets[asset._kwarg_name] = { + "source_dataset": asset._dataset_name, + "source_asset": asset.name, + "action_id": asset.action_id.no_dash, + "source_node": asset.node_uid.no_dash, + } + asset_str = json.dumps(used_assets, indent=2) return asset_str def get_sync_dependencies( @@ -899,6 +908,11 @@ def _inner_repr(self, level: int = 0) -> str: constants = [x for x in args if isinstance(x, Constant)] constants_str = "\n\t".join([f"{x.kw}: {x.val}" for x in constants]) + # indent all lines except the first one + asset_str = "\n".join( + [f" {line}" for line in self._asset_json.split("\n")] + ).lstrip() + md = f"""class UserCode id: UID = {self.id} service_func_name: str = {self.service_func_name} @@ -906,8 +920,8 @@ def _inner_repr(self, level: int = 0) -> str: status: list = {self.code_status} {constants_str} {shared_with_line} + assets: dict = {asset_str} code: -{self._asset_str} {self.raw_code} """ if self.nested_codes != {}: diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index ac9d0114c08..a483cfa2751 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -105,7 +105,8 @@ class Asset(SyftObject): created_at: DateTime = DateTime.now() uploader: Contributor | None = None - __repr_attrs__ = ["name", "shape"] + __repr_attrs__ = ["_kwarg_name", "name", "action_id", "_dataset_name", "node_uid"] + __clipboard_attrs__ = ["action_id", "node_uid", "_dataset_name"] def __init__( self, diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 611ca5e33a2..a592837769f 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -135,7 +135,15 @@ def _create_table_rows( except Exception as e: print(e) value = None - cols[field].append(sanitize_html(str(value))) + + if field in getattr(item, "__clipboard_attrs__", []): + value = { + "value": sanitize_html(str(value)), + "type": "clipboard", + } + else: + value = sanitize_html(str(value)) + cols[field].append(value) col_lengths = {len(cols[col]) for col in cols.keys()} if len(col_lengths) != 1: From 4ca2550579f0704f8f85e469dd336a51bca0b2ed Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 1 Jul 2024 12:00:28 +0200 Subject: [PATCH 191/309] fix return type --- packages/syft/src/syft/service/code/user_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 6043453f552..fb4875460c1 100644 --- 
a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -781,7 +781,7 @@ def byte_code(self) -> PyCodeObject | None: @property def assets(self) -> list[Asset] | SyftError: if not self.input_policy: - return {} + return [] api = self._get_api() if isinstance(api, SyftError): From 628fefa6748fbc0748ef73c7fffb96c5e5febbe8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 1 Jul 2024 13:37:01 +0200 Subject: [PATCH 192/309] allocate for user --- .../src/syft/service/blob_storage/service.py | 61 ++++++++++++++++--- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 12a9cc9f8c2..7a6e1cf732e 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -5,6 +5,7 @@ import requests # relative +from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...service.action.action_object import ActionObject from ...store.blob_storage import BlobRetrieval @@ -25,6 +26,7 @@ from ..service import AbstractService from ..service import TYPE_TO_SERVICE from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL from .remote_profile import AzureRemoteProfile from .remote_profile import RemoteProfileStash @@ -209,14 +211,29 @@ def read( return res return SyftError(message=result.err()) - @service_method( - path="blob_storage.allocate", - name="allocate", - roles=GUEST_ROLE_LEVEL, - ) - def allocate( - self, context: AuthedServiceContext, obj: CreateBlobStorageEntry + def _allocate( + self, + context: AuthedServiceContext, + obj: CreateBlobStorageEntry, + uploaded_by: SyftVerifyKey | None = None, ) -> BlobDepositType | SyftError: + """ + Allocate a secure location for the blob storage entry. + + If uploaded_by is None, the credentials of the context will be used. + + Args: + context (AuthedServiceContext): context + obj (CreateBlobStorageEntry): create blob parameters + uploaded_by (SyftVerifyKey | None, optional): Uploader credentials. + Can be used to upload on behalf of another user, needed for data migrations. + Defaults to None. 
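+                (Illustrative example; names are hypothetical: a migration script
+                authenticated as an admin could call allocate_for_user(obj,
+                uploaded_by=original_owner_key) so the new BlobStorageEntry keeps
+                the original uploader's credentials.)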
+ + Returns: + BlobDepositType | SyftError: Blob deposit + """ + upload_credentials = uploaded_by or context.credentials + with context.node.blob_storage_client.connect() as conn: secure_location = conn.allocate(obj) @@ -229,15 +246,41 @@ def allocate( type_=obj.type_, mimetype=obj.mimetype, file_size=obj.file_size, - uploaded_by=context.credentials, + uploaded_by=upload_credentials, ) blob_deposit = conn.write(blob_storage_entry) - result = self.stash.set(context.credentials, blob_storage_entry) + result = self.stash.set( + upload_credentials, + blob_storage_entry, + ) if result.is_err(): return SyftError(message=f"{result.err()}") return blob_deposit + @service_method( + path="blob_storage.allocate", + name="allocate", + roles=GUEST_ROLE_LEVEL, + ) + def allocate( + self, context: AuthedServiceContext, obj: CreateBlobStorageEntry + ) -> BlobDepositType | SyftError: + return self._allocate(context, obj) + + @service_method( + path="blob_storage.allocate_for_user", + name="allocate_for_user", + roles=ADMIN_ROLE_LEVEL, + ) + def allocate_for_user( + self, + context: AuthedServiceContext, + obj: CreateBlobStorageEntry, + uploaded_by: SyftVerifyKey, + ) -> BlobDepositType | SyftError: + return self._allocate(context, obj, uploaded_by) + @service_method( path="blob_storage.write_to_disk", name="write_to_disk", From 5e1676931dca119b638758903a6b0a13f60e1b70 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 1 Jul 2024 16:43:50 +0200 Subject: [PATCH 193/309] update code.asset and code repr --- .../syft/src/syft/service/code/user_code.py | 77 +++++++++++++------ .../syft/src/syft/service/dataset/dataset.py | 6 +- packages/syft/src/syft/util/table.py | 10 ++- .../syft/tests/syft/users/user_code_test.py | 2 +- 4 files changed, 67 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b8df87fc619..fb4875460c1 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -4,14 +4,13 @@ # stdlib import ast from collections.abc import Callable -from collections.abc import Generator from copy import deepcopy import datetime from enum import Enum import hashlib import inspect from io import StringIO -import itertools +import json import keyword import random import re @@ -77,6 +76,7 @@ from ..action.action_object import ActionObject from ..context import AuthedServiceContext from ..dataset.dataset import Asset +from ..dataset.dataset import Dataset from ..job.job_stash import Job from ..output.output_service import ExecutionOutput from ..output.output_service import OutputService @@ -779,31 +779,53 @@ def byte_code(self) -> PyCodeObject | None: return compile_byte_code(self.parsed_code) @property - def assets(self) -> list[Asset]: - # relative - from ...client.api import APIRegistry - - api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) - if api is None: - return SyftError(message=f"You must login to {self.node_uid}") + def assets(self) -> list[Asset] | SyftError: + if not self.input_policy: + return [] - inputs: Generator = (x for x in range(0)) # create an empty generator - if self.input_policy_init_kwargs is not None: - inputs = ( - uids - for node_identity, uids in self.input_policy_init_kwargs.items() - if node_identity.node_name == api.node_name - ) + api = self._get_api() + if isinstance(api, SyftError): + return api - all_assets = [] - for uid in itertools.chain.from_iterable(x.values() for x in inputs): - if 
isinstance(uid, UID): - assets = api.services.dataset.get_assets_by_action_id(uid) - if not isinstance(assets, list): - return assets + # get all assets on the node + datasets: list[Dataset] = api.services.dataset.get_all() + if isinstance(datasets, SyftError): + return datasets + + all_assets = {} + for dataset in datasets: + for asset in dataset.asset_list: + asset._dataset_name = dataset.name + all_assets[asset.action_id] = asset + + # get a flat dict of all inputs + all_inputs = {} + inputs = self.input_policy.inputs or {} + for vals in inputs.values(): + all_inputs.update(vals) + + # map the action_id to the asset + used_assets = [] + for kwarg_name, action_id in all_inputs.items(): + asset = all_assets.get(action_id, None) + asset._kwarg_name = kwarg_name + used_assets.append(asset) + return used_assets - all_assets += assets - return all_assets + @property + def _asset_json(self) -> str | SyftError: + if isinstance(self.assets, SyftError): + return self.assets + used_assets = {} + for asset in self.assets: + used_assets[asset._kwarg_name] = { + "source_dataset": asset._dataset_name, + "source_asset": asset.name, + "action_id": asset.action_id.no_dash, + "source_node": asset.node_uid.no_dash, + } + asset_str = json.dumps(used_assets, indent=2) + return asset_str def get_sync_dependencies( self, context: AuthedServiceContext @@ -886,6 +908,11 @@ def _inner_repr(self, level: int = 0) -> str: constants = [x for x in args if isinstance(x, Constant)] constants_str = "\n\t".join([f"{x.kw}: {x.val}" for x in constants]) + # indent all lines except the first one + asset_str = "\n".join( + [f" {line}" for line in self._asset_json.split("\n")] + ).lstrip() + md = f"""class UserCode id: UID = {self.id} service_func_name: str = {self.service_func_name} @@ -893,8 +920,8 @@ def _inner_repr(self, level: int = 0) -> str: status: list = {self.code_status} {constants_str} {shared_with_line} + assets: dict = {asset_str} code: - {self.raw_code} """ if self.nested_codes != {}: diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 9dde84429c4..a483cfa2751 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -105,7 +105,8 @@ class Asset(SyftObject): created_at: DateTime = DateTime.now() uploader: Contributor | None = None - __repr_attrs__ = ["name", "shape"] + __repr_attrs__ = ["_kwarg_name", "name", "action_id", "_dataset_name", "node_uid"] + __clipboard_attrs__ = ["action_id", "node_uid", "_dataset_name"] def __init__( self, @@ -178,6 +179,9 @@ def _repr_html_(self) -> Any: {mock_table_line} """ + def __repr__(self) -> str: + return f"Asset(name='{self.name}', node_uid='{self.node_uid}', action_id='{self.action_id}')" + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str = f"Asset: {self.name}\n" _repr_str += f"Pointer Id: {self.action_id}\n" diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 611ca5e33a2..a592837769f 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -135,7 +135,15 @@ def _create_table_rows( except Exception as e: print(e) value = None - cols[field].append(sanitize_html(str(value))) + + if field in getattr(item, "__clipboard_attrs__", []): + value = { + "value": sanitize_html(str(value)), + "type": "clipboard", + } + else: + value = sanitize_html(str(value)) + cols[field].append(value) col_lengths = {len(cols[col]) for col in 
cols.keys()} if len(col_lengths) != 1: diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 69f69ab76d4..a24897c62a8 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -166,7 +166,7 @@ def func(asset): c for c in request.changes if (isinstance(c, UserCodeStatusChange)) ) - assert status_change.code.assets[0].model_dump( + assert status_change.code.assets["asset"].model_dump( mode="json" ) == asset_input.model_dump(mode="json") From e6908e967773bcbcfa62322fac436e42137f30fd Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 1 Jul 2024 17:10:12 +0200 Subject: [PATCH 194/309] migrate blob storage --- .../1c-migrate-to-new-node.ipynb | 3518 +++++++++++++++++ .../service/migration/migration_service.py | 109 +- packages/syft/src/syft/types/blob_storage.py | 12 + 3 files changed, 3608 insertions(+), 31 deletions(-) create mode 100644 notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb diff --git a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb new file mode 100644 index 00000000000..504b6cd0695 --- /dev/null +++ b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb @@ -0,0 +1,3518 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "from syft.service.log.log import SyftLogV3\n", + "from syft.types.syft_object import Context\n", + "from syft.types.syft_object import SyftObject\n", + "from syft.service.user.user import User\n", + "\n", + "from syft.types.blob_storage import BlobStorageEntry, CreateBlobStorageEntry" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "syft version: 0.8.7-beta.9\n" + ] + } + ], + "source": [ + "print(f\"syft version: {sy.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "TODOS\n", + "- [x] action objects\n", + "- [x] maybe an example of how to migrate one object type in a custom way\n", + "- [x] check SyftObjectRegistry and compare with current implementation\n", + "- [x] run unit tests\n", + "- [ ] finalize notebooks for testing, run in CI\n", + "- [ ] other tasks defined in tickets" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Staging Protocol Changes...\n", + "Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/db/21519e1e3e664b38a635dc951c293158.sqlite\n", + "Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/db/21519e1e3e664b38a635dc951c293158.sqlite\n", + "Migrating data for: NodeSettings table.\n", + "Creating default worker image with tag='local-dev'\n", + "Setting up worker poolname=default-pool workers=2 image_uid=a1f62fc8f2ac4e32a70a90260019f831 in_memory=True\n" + ] + } + ], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4", 
+ "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged into as \n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "# Client side migrations" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ef428d2d-62d8-4b14-a8c9-89d7cc4f6a8a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Staging Protocol Changes...\n", + "Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/db/a3edccc59f384307aeaa1a50714c2300.sqlite\n", + "Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/db/a3edccc59f384307aeaa1a50714c2300.sqlite\n", + "Creating default worker image with tag='local-dev'\n", + "Setting up worker poolname=default-pool workers=2 image_uid=4ba80c1b73ae44819a22f9c4811230c2 in_memory=True\n", + "Created default worker pool.\n", + "Logged into as \n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "temp_node = sy.orchestra.launch(\n", + " name=\"temp_node\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + " reset=True,\n", + ")\n", + "\n", + "temp_client = temp_node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## document store objects" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "migration_dict = client.services.migration.get_migration_objects(get_all=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{syft.service.queue.queue_stash.ActionQueueItem: [: completed],\n", + " syft.service.user.user.User: [syft.service.user.user.User,\n", + " syft.service.user.user.User],\n", + " syft.service.worker.worker_pool.SyftWorker: [syft.service.worker.worker_pool.SyftWorker,\n", + " syft.service.worker.worker_pool.SyftWorker],\n", + " syft.service.settings.settings.NodeSettings: [syft.service.settings.settings.NodeSettings],\n", + " syft.service.dataset.dataset.Dataset: [syft.service.dataset.dataset.Dataset],\n", + " syft.service.code.user_code.UserCode: [syft.service.code.user_code.UserCode],\n", + " syft.service.log.log.SyftLog: [syft.service.log.log.SyftLog],\n", + " syft.service.request.request.Request: [syft.service.request.request.Request],\n", + " syft.service.job.job_stash.Job: [syft.service.job.job_stash.Job],\n", + " syft.service.notifier.notifier.NotifierSettings: [syft.service.notifier.notifier.NotifierSettings],\n", + " syft.service.notification.notifications.Notification: [syft.service.notification.notifications.Notification,\n", + " syft.service.notification.notifications.Notification,\n", + " syft.service.notification.notifications.Notification],\n", + " syft.service.code_history.code_history.CodeHistory: [syft.service.code_history.code_history.CodeHistory],\n", + " syft.types.blob_storage.BlobStorageEntry: [syft.types.blob_storage.BlobStorageEntry,\n", + " syft.types.blob_storage.BlobStorageEntry,\n", + " syft.types.blob_storage.BlobStorageEntry],\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState: [syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " 
syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", + " syft.service.migration.object_migration_state.SyftObjectMigrationState],\n", + " syft.service.worker.worker_image.SyftWorkerImage: [syft.service.worker.worker_image.SyftWorkerImage,\n", + " syft.service.worker.worker_image.SyftWorkerImage],\n", + " syft.service.worker.worker_pool.WorkerPool: [syft.service.worker.worker_pool.WorkerPool],\n", + " syft.service.output.output_service.ExecutionOutput: [syft.service.output.output_service.ExecutionOutput],\n", + " syft.service.code.user_code.UserCodeStatusCollection: [{NodeIdentity : (, '')}]}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "migration_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "def custom_migration_function(context, obj: SyftObject, klass) -> SyftObject:\n", + " # Here, we are just doing the same, but this is where you would write your custom logic\n", + " return obj.migrate_to(klass.__version__, context)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# this wont work in the cases where the context is actually used,\n", + "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", + "\n", + "context = Context()\n", + "migrated_objects = []\n", + "for klass, objects in migration_dict.items():\n", + " for obj in objects:\n", + " if isinstance(obj, BlobStorageEntry):\n", + " continue\n", + " elif isinstance(obj, SyftLogV3):\n", + " migrated_obj = custom_migration_function(context, obj, klass)\n", + " else:\n", + " migrated_obj = obj.migrate_to(klass.__version__, context)\n", + "\n", + " 
migrated_objects.append(migrated_obj)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO what to do with workerpools\n", + "# TODO what to do with admin? @yash: can we make new node with existing verifykey?\n", + "# TODO check asset AO is not saved in blobstorage" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "12", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " #6e81a99a5c264407abee9e635f36ad9e already exists\n", + " #00bd74c6030e4af0b8844e00da8185c2 already exists\n", + " #22c6b28b12e94052aaf1adf4adc79741 already exists\n", + " #9ca8930bc5b04aa1a7b1835d69b8e4e2 already exists\n", + " #7ec16f3a86a442179f69ef4640f16c10 already exists\n", + " #2611c202fe984bec80cadfc482be3cb8 already exists\n", + " #ba943fd900a44527ba5d155128525072 already exists\n", + " #dc4f177789974647b6ea47657cf3fc76 already exists\n", + " #fe51de3d24ac482ba369fc14b9bda9f9 already exists\n", + " #aed2854dca67456196f91b68da949e5a already exists\n", + " #ab1fcea789344984814f8a4bcea7cd02 already exists\n", + " #77a3e8772dae4aa288b468d96e799364 already exists\n", + " #160150b1cee64df0830f00a1bee8d857 already exists\n", + " #a1f62fc8f2ac4e32a70a90260019f831 already exists\n", + " #3c32960a61c7471991de0d88b28a662e already exists\n" + ] + } + ], + "source": [ + "res = temp_client.services.migration.create_migrated_objects(migrated_objects)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3db3dc71-c06c-4b48-b884-dde1ff6d1838", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: success
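The "already exists" messages printed above come from the duplicate handling inside `_create_migrated_objects`, which is added to `migration_service.py` later in this patch. A condensed sketch of that upsert loop, with names taken from the diff below (it elides the error handling around the partition lookup and is not a standalone function):

```python
# Condensed from MigrationService._create_migrated_objects (see the migration_service.py diff below):
# every migrated object is upserted into its store partition; when ignore_existing=True,
# a "Duplication Key Error" is printed and skipped instead of failing the whole batch.
for migrated_object in migrated_objects:
    object_partition = self._search_partition_for_object(context, migrated_object).ok()
    result = object_partition.set(context.credentials, obj=migrated_object)
    if result.is_err():
        if ignore_existing and "Duplication Key Error" in result.value:
            print(f"{type(migrated_object)} #{migrated_object.id} already exists")
            continue
        return result
return Ok(value="success")
```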

" + ], + "text/plain": [ + "SyftSuccess: success" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(res, sy.SyftSuccess)" + ] + }, + { + "cell_type": "markdown", + "id": "b95366c4-78f8-484d-b53f-c6eb5d4f5c1c", + "metadata": {}, + "source": [ + "# Migrate blobstorage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3acabe3c-d164-4a5c-b1c5-c0f676c00ec5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "157231d6-c4e5-448a-9bde-190c8cb0df89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "```python\n", + "class BlobStorageEntry:\n", + " id: str = 5e9a51c12f664e39bcc6099981292e3f\n", + "\n", + "```" + ], + "text/plain": [ + "syft.types.blob_storage.BlobStorageEntry" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "klass = BlobStorageEntry\n", + "blob_entries = migration_dict[klass]\n", + "obj = blob_entries[0]\n", + "obj" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f8f06ced-8422-497d-a521-0be184017c53", + "metadata": {}, + "outputs": [], + "source": [ + "from io import BytesIO\n", + "import sys\n", + "\n", + "def migrate_blob_entry_data(old_client, new_client, obj, klass) -> sy.SyftSuccess | sy.SyftError:\n", + " migrated_obj = obj.migrate_to(klass.__version__, Context())\n", + " uploaded_by = migrated_obj.uploaded_by\n", + " blob_retrieval = old_client.services.blob_storage.read(obj.id)\n", + " if isinstance(blob_retrieval, sy.SyftError):\n", + " return blob_retrieval\n", + " \n", + " data = blob_retrieval.read()\n", + " # TODO do we have to determine new filesize here?\n", + " serialized = sy.serialize(data, to_bytes=True)\n", + " size = sys.getsizeof(serialized)\n", + " blob_create = CreateBlobStorageEntry.from_blob_storage_entry(obj)\n", + " blob_create.file_size = size\n", + "\n", + " blob_deposit_object = new_client.services.blob_storage.allocate_for_user(blob_create, uploaded_by)\n", + " if isinstance(blob_deposit_object, sy.SyftError):\n", + " return blob_deposit_object\n", + " return blob_deposit_object.write(BytesIO(serialized))\n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e108f228-ff8b-41ef-966e-421be7fd39a9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: File successfully saved.

" + ], + "text/plain": [ + "SyftSuccess: File successfully saved." + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
SyftSuccess: File successfully saved.

" + ], + "text/plain": [ + "SyftSuccess: File successfully saved." + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
SyftSuccess: File successfully saved.

" + ], + "text/plain": [ + "SyftSuccess: File successfully saved." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for blob_entry in blob_entries:\n", + " res = migrate_blob_entry_data(client, temp_client, blob_entry, BlobStorageEntry)\n", + " display(res)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "e7f005ec-7a6d-4ff6-a2db-f82cf1bfaa17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + "

BlobStorageEntry List

\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

Total: 0

\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "[syft.types.blob_storage.BlobStorageEntry,\n", + " syft.types.blob_storage.BlobStorageEntry,\n", + " syft.types.blob_storage.BlobStorageEntry]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.services.blob_storage.get_all()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "23b87c41-e3e9-42b6-814c-075ab2719cbe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'syft_node_location': ,\n", + " 'syft_client_verify_key': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", + " 'id': ,\n", + " 'location': /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/blob/5e9a51c12f664e39bcc6099981292e3f,\n", + " 'type_': numpy.ndarray,\n", + " 'mimetype': 'bytes',\n", + " 'file_size': 609,\n", + " 'no_lines': 0,\n", + " 'uploaded_by': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", + " 'created_at': syft.types.datetime.DateTime,\n", + " 'bucket_name': None}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.services.blob_storage.get_all()[0].__dict__" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "8648ac13-7091-4dbe-9fab-efcc11b4d4ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'syft_node_location': ,\n", + " 'syft_client_verify_key': e296574092d9fe0bbd853b4f0294ca9bc6624ac16c3700da58eff07f69f477f2,\n", + " 'id': ,\n", + " 'location': /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/blob/5e9a51c12f664e39bcc6099981292e3f,\n", + " 'type_': numpy.ndarray,\n", + " 'mimetype': 'bytes',\n", + " 'file_size': 689,\n", + " 'no_lines': 0,\n", + " 'uploaded_by': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", + " 'created_at': syft.types.datetime.DateTime,\n", + " 'bucket_name': None}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp_client.services.blob_storage.get_all()[0].__dict__" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## Actions and ActionObjects" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "migration_action_dict = client.services.migration.get_migration_actionobjects(get_all=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "63c61536-4099-4eb5-ab62-db344a0c49e1", + "metadata": {}, + "outputs": [], + "source": [ + "ao = migration_action_dict[list(migration_action_dict.keys())[0]][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8b1599ea-d940-4415-9f9d-b36056d29e92", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([15, 16, 17, 18, 19])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ao.syft_action_data_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a84a06be-be0e-4378-92c4-a08c55a90dfc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{: Pointer:\n", + "array([15, 16, 17, 18, 19]), : Pointer:\n", + "array([15, 16, 17, 18, 19])}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node.python_node.action_store.data" + ] + }, + { + "cell_type": 
"code", + "execution_count": 27, + "id": "9fec28ad-5417-4e94-a26f-3cf9cc5d3412", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.jobs[0].result.id.id" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7d006054-0248-4901-bb13-6d27e6a0870b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'syft_node_location': None,\n", + " 'syft_client_verify_key': None,\n", + " 'id': ,\n", + " 'syft_action_data_cache': array([15, 16, 17, 18, 19]),\n", + " 'syft_blob_storage_entry_id': None,\n", + " 'syft_parent_hashes': None,\n", + " 'syft_parent_op': None,\n", + " 'syft_parent_args': None,\n", + " 'syft_parent_kwargs': None,\n", + " 'syft_history_hash': 1494250481592695163,\n", + " 'syft_node_uid': None,\n", + " 'syft_pre_hooks__': {'ALWAYS': [ 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>,\n", + " 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>],\n", + " 'ON_POINTERS': [ 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>]},\n", + " 'syft_post_hooks__': {'ALWAYS': [ 'Result[Ok[Any], Err[str]]'>],\n", + " 'ON_POINTERS': []},\n", + " 'syft_twin_type': ,\n", + " 'syft_passthrough_attrs': ['is_mock',\n", + " 'is_real',\n", + " 'is_twin',\n", + " 'is_pointer',\n", + " 'request',\n", + " '__repr__',\n", + " '_repr_markdown_',\n", + " 'syft_twin_type',\n", + " '_repr_debug_',\n", + " 'as_empty',\n", + " 'get',\n", + " 'is_link',\n", + " 'wait',\n", + " '_save_to_blob_storage',\n", + " '_save_to_blob_storage_',\n", + " 'syft_action_data',\n", + " '__check_action_data',\n", + " 'as_empty_data',\n", + " '_set_obj_location_',\n", + " 'syft_action_data_cache',\n", + " 'reload_cache',\n", + " 'syft_resolved',\n", + " 'refresh_object',\n", + " 'syft_action_data_node_id',\n", + " 'node_uid',\n", + " '__sha256__',\n", + " '__hash_exclude_attrs__',\n", + " '__hash__',\n", + " 'create_shareable_sync_copy',\n", + " '_has_private_sync_attrs',\n", + " '__exclude_sync_diff_attrs__',\n", + " '__repr_attrs__'],\n", + " 'syft_action_data_type': numpy.ndarray,\n", + " 'syft_action_data_repr_': 'array([15, 16, 17, 18, 19])',\n", + " 'syft_action_data_str_': '[15 16 17 18 19]',\n", + " 'syft_has_bool_attr': True,\n", + " 'syft_resolve_data': None,\n", + " 'syft_created_at': syft.types.datetime.DateTime,\n", + " 'syft_resolved': True,\n", + " 'syft_action_data_node_id': None,\n", + " 'syft_dont_wrap_attrs': ['dtype', 'shape']}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node.python_node.action_store.data[sy.UID(\"d05d03da6ff44a57b1d48611e927a68a\")].__dict__" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# this wont work in the cases where the context is actually used, but since this you would need custom logic here anyway\n", + "# it doesnt matter\n", + "context = Context()\n", + "migrated_actionobjects = []\n", + "for klass, objects in migration_action_dict.items():\n", + " for obj in objects:\n", + " # custom migration logic here\n", + " migrated_actionobject = obj.migrate_to(klass.__version__, context)\n", + " migrated_actionobjects.append(migrated_actionobject)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "17", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "[Pointer:\n", + "array([15, 16, 17, 18, 19]), Pointer:\n", + "array([15, 16, 17, 18, 19])]\n" + ] + } + ], + "source": [ + "print(migrated_actionobjects)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "res = temp_client.services.migration.update_migrated_actionobjects(migrated_actionobjects)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "e2e5064d-2ab8-4d2f-a13e-6215545ea118", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: succesfully migrated actionobjects

" + ], + "text/plain": [ + "SyftSuccess: succesfully migrated actionobjects" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(res, sy.SyftSuccess)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "2ef24ba9-3781-4e55-9c67-df68465ab080", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[15 16 17 18 19]\n", + "[15 16 17 18 19]\n" + ] + } + ], + "source": [ + "for uid in temp_node.python_node.action_store.data:\n", + " ao = temp_client.services.action.get(uid)\n", + " ao.reload_cache()\n", + " print(ao.syft_action_data_cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "ff2a5b44-37a7-4b49-a332-ca93ee71b94f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[15 16 17 18 19]\n", + "[15 16 17 18 19]\n" + ] + } + ], + "source": [ + "for uid in node.python_node.action_store.data:\n", + " ao = client.services.action.get(uid)\n", + " ao.reload_cache()\n", + " print(ao.syft_action_data_cache)" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "## Store metadata\n", + "\n", + "- Permissions\n", + "- StoragePermissions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "store_metadata = client.services.migration.get_all_store_metadata()\n", + "store_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in store_metadata.items():\n", + " if len(v.permissions):\n", + " print(\n", + " k, len(v.permissions), len(v.permissions) == len(migration_dict.get(k, []))\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "# Test update method with a temp node\n", + "# After update, all metadata should match between the nodes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "temp_client.services.migration.update_store_metadata(store_metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "for cname, real_partition in node.python_node.document_store.partitions.items():\n", + " temp_partition = temp_node.python_node.document_store.partitions[cname]\n", + "\n", + " temp_perms = dict(temp_partition.permissions.items())\n", + " real_perms = dict(real_partition.permissions.items())\n", + "\n", + " # Only look at migrated items\n", + " temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", + " assert temp_perms == real_perms\n", + "\n", + " temp_storage = dict(temp_partition.storage_permissions.items())\n", + " real_storage = dict(real_partition.storage_permissions.items())\n", + " temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", + "\n", + " assert temp_storage == real_storage\n", + "\n", + "# Action store\n", + "real_partition = node.python_node.action_store\n", + "temp_partition = temp_node.python_node.action_store\n", + "temp_perms = dict(temp_partition.permissions.items())\n", + "real_perms = dict(real_partition.permissions.items())\n", + "\n", + "# Only look at 
migrated items\n", + "temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", + "assert temp_perms == real_perms\n", + "\n", + "temp_storage = dict(temp_partition.storage_permissions.items())\n", + "real_storage = dict(real_partition.storage_permissions.items())\n", + "temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", + "\n", + "assert temp_storage == real_storage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index f3678d34d3b..d84db3b2361 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -295,11 +295,6 @@ def _get_migration_objects( context=context, object_types=document_store_object_types ) - if klasses_to_migrate: - print( - f"Classes in Document Store that need migration: {klasses_to_migrate}" - ) - result = defaultdict(list) for klass in klasses_to_migrate: @@ -324,6 +319,70 @@ def _get_migration_objects( return Ok(dict(result)) + def _search_partition_for_object( + self, context: AuthedServiceContext, obj: SyftObject + ) -> Result[StorePartition, str]: + klass = type(obj) + mro = klass.__mro__ + class_index = 0 + object_partition = None + while len(mro) > class_index: + canonical_name = mro[class_index].__canonical_name__ + object_partition = self.store.partitions.get(canonical_name) + if object_partition is not None: + break + class_index += 1 + if object_partition is None: + return Err(f"Object partition not found for {klass}") + return Ok(object_partition) + + @service_method( + path="migration.create_migrated_objects", + name="create_migrated_objects", + roles=ADMIN_ROLE_LEVEL, + ) + def create_migrated_objects( + self, + context: AuthedServiceContext, + migrated_objects: list[SyftObject], + ignore_existing: bool = True, + ) -> SyftSuccess | SyftError: + res = self._create_migrated_objects(context, migrated_objects) + if res.is_err(): + return SyftError(message=res.value) + else: + return SyftSuccess(message=res.ok()) + + def _create_migrated_objects( + self, + context: AuthedServiceContext, + migrated_objects: list[SyftObject], + ignore_existing: bool = True, + ) -> Result[str, str]: + for migrated_object in migrated_objects: + object_partition_or_err = self._search_partition_for_object( + context, migrated_object + ) + if object_partition_or_err.is_err(): + return object_partition_or_err + object_partition = object_partition_or_err.ok() + + # upsert the object + result = object_partition.set( + context.credentials, + obj=migrated_object, + ) + if result.is_err(): + if ignore_existing and "Duplication Key Error" in result.value: + print( + f"{type(migrated_object)} #{migrated_object.id} already exists" + ) + continue + else: + return result + + return Ok(value="success") + @service_method( path="migration.update_migrated_objects", name="update_migrated_objects", @@ -342,28 +401,18 @@ def 
_update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> Result[str, str]: for migrated_object in migrated_objects: - klass = type(migrated_object) - mro = klass.__mro__ - class_index = 0 - object_partition = None - while len(mro) > class_index: - canonical_name = mro[class_index].__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) - if object_partition is not None: - break - class_index += 1 - if object_partition is None: - return Err(f"Object partition not found for {klass}") + object_partition_or_err = self._search_partition_for_object( + context, migrated_object + ) + if object_partition_or_err.is_err(): + return object_partition_or_err + object_partition = object_partition_or_err.ok() # canonical_name = mro[class_index].__canonical_name__ # object_partition = self.store.partitions.get(canonical_name) # print(klass, canonical_name, object_partition) qk = object_partition.settings.store_key.with_obj(migrated_object.id) - # print(migrated_object) - # stdlib - import sys - result = object_partition._update( context.credentials, qk=qk, @@ -375,7 +424,7 @@ def _update_migrated_objects( if result.is_err(): print("ERR:", result.value, file=sys.stderr) - print("ERR:", klass, file=sys.stderr) + print("ERR:", type(migrated_object), file=sys.stderr) print("ERR:", migrated_object, file=sys.stderr) # return result return Ok(value="success") @@ -480,17 +529,17 @@ def migrate_data( roles=ADMIN_ROLE_LEVEL, ) def get_migration_actionobjects( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, get_all: bool = False ) -> dict | SyftError: - res = self._get_migration_actionobjects(context) + res = self._get_migration_actionobjects(context, get_all=get_all) if res.is_ok(): return res.ok() else: return SyftError(message=res.value) def _get_migration_actionobjects( - self, context: AuthedServiceContext - ) -> Result[dict[type[SyftObject], SyftObject], str]: + self, context: AuthedServiceContext, get_all: bool = False + ) -> Result[dict[type[SyftObject], list[SyftObject]], str]: # Track all object types from action store action_object_types = [Action, ActionObject] action_object_types.extend(ActionObject.__subclasses__()) @@ -498,9 +547,7 @@ def _get_migration_actionobjects( action_object_pending_migration = self._find_klasses_pending_for_migration( context=context, object_types=action_object_types ) - result_dict: dict[type[SyftObject], SyftObject] = { - x: [] for x in action_object_pending_migration - } + result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list) action_store = context.node.action_store action_store_objects_result = action_store._all( context.credentials, has_permission=True @@ -510,9 +557,9 @@ def _get_migration_actionobjects( action_store_objects = action_store_objects_result.ok() for obj in action_store_objects: - if type(obj) in result_dict: + if get_all or type(obj) in action_object_pending_migration: result_dict[type(obj)].append(obj) - return Ok(result_dict) + return Ok(dict(result_dict)) @service_method( path="migration.update_migrated_actionobjects", diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index a92134f26f7..a493039a1ab 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -326,6 +326,18 @@ class CreateBlobStorageEntry(SyftObject): file_size: int extensions: list[str] = [] + @classmethod + def from_blob_storage_entry(cls, entry: BlobStorageEntry) 
-> Self: + # TODO extensions are not stored in the BlobStorageEntry, + # so a blob entry from path might get a different filename + # after uploading. + return cls( + id=entry.id, + type_=entry.type_, + mimetype=entry.mimetype, + file_size=entry.file_size, + ) + @classmethod def from_obj(cls, obj: SyftObject, file_size: int | None = None) -> Self: if file_size is None: From dde19ac86944a7f1507ef0b806985672569254c5 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Mon, 1 Jul 2024 14:38:59 -0300 Subject: [PATCH 195/309] Rename 'unsafe_function' to 'run' --- .../api/0.9/02-review-code-and-approve.ipynb | 522 ++++++++++++++ notebooks/api/0.9/05-custom-policy.ipynb | 647 ++++++++++++++++++ .../tutorials/hello-syft/01-hello-syft.ipynb | 2 +- .../model-auditing/colab/01-user-log.ipynb | 2 +- .../02-data-owner-review-approve-code.ipynb | 2 +- .../01-reading-from-a-csv.ipynb | 2 +- ...lecting-data-finding-common-complain.ipynb | 2 +- ...orough-has-the-most-noise-complaints.ipynb | 2 +- ...-weekday-bike-most-groupby-aggregate.ipynb | 2 +- ...ing-dataframes-scraping-weather-data.ipynb | 2 +- ...rations-which-month-was-the-snowiest.ipynb | 2 +- .../07-cleaning-up-messy-data.ipynb | 2 +- .../08-how-to-deal-with-timestamps.ipynb | 2 +- .../syft/src/syft/service/code/user_code.py | 4 +- .../syft/tests/syft/users/user_code_test.py | 4 +- 15 files changed, 1184 insertions(+), 15 deletions(-) create mode 100644 notebooks/api/0.9/02-review-code-and-approve.ipynb create mode 100644 notebooks/api/0.9/05-custom-policy.ipynb diff --git a/notebooks/api/0.9/02-review-code-and-approve.ipynb b/notebooks/api/0.9/02-review-code-and-approve.ipynb new file mode 100644 index 00000000000..24956e4a48b --- /dev/null +++ b/notebooks/api/0.9/02-review-code-and-approve.ipynb @@ -0,0 +1,522 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reviewing and Approving Code in Syft as a Data Owner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "SYFT_VERSION = \">=0.8.2.b0,<0.9\"\n", + "package_string = f'\"syft{SYFT_VERSION}\"'\n", + "# %pip install {package_string} -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "\n", + "sy.requires(SYFT_VERSION)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Login to Syft Domain Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Launch and connect to test-domain-1 server we setup in the previous notebook\n", + "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Log into the node with default root credentials\n", + "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Selecting Project in the Syft Domain Server\n", + "\n", + "Let's see all the projects that are created by Data Scientists in this Domain Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "domain_client.projects" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Select the project you want to work with\n", + "project = domain_client.projects[0]\n", + "project" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All code requests submitted by the Data Scientists as a part of this project can be accessed by invoking the following" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "project.requests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Tests\n", + "assert len(project.events) == 1\n", + "assert isinstance(project.events[0], sy.service.project.project.ProjectRequest)\n", + "assert len(project.requests) == 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reviewing Code Requests\n", + "\n", + "To review a specific request, we can select it and explore its attributes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "request = project.requests[0]\n", + "request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# See the code written by the Data Scientist and its metadata in the request\n", + "func = request.code\n", + "func" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# To see just the code\n", + "func.show_code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Reference to the assets that the function will run on\n", + "func.assets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Viewing the Asset and it's mock/private variants that the Data Scientist will be running on" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "asset = func.assets[0]\n", + "asset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "mock_data = asset.mock\n", + "mock_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Private data. Accessible as we are logged in with Data Owner credentials\n", + "pvt_data = asset.data\n", + "pvt_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Tests\n", + "assert len(asset.data_subjects) == 1\n", + "assert mock_data.shape == (10, 22)\n", + "assert pvt_data.shape == (10, 22)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Policies in Syft Function\n", + "\n", + "Each Syft Function requires an Input & Output policy attached to the python function against which executions are verified.\n", + "\n", + "Syft provides the following default policies:\n", + "* `sy.ExactMatch()` Input policy ensures that function executes against the exact inputs specified by Data Scientist.\n", + "* `sy.OutputPolicyExecuteOnce()` Output policy makes sure that the Data Scientist can run the function only once against the input.\n", + "\n", + "We can also implement custom policies based on our requirements. 
(Refer to notebook [05-custom-policy](./05-custom-policy.ipynb) for more information.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "op = func.output_policy_type\n", + "op" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# See the implementation of the policy\n", + "print(op.policy_code)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Policies provided by Syft are available before approving the code,\n", + "# Custom policies are only safe to use once the code is approved.\n", + "\n", + "assert func.output_policy is not None\n", + "assert func.input_policy is not None\n", + "\n", + "func.output_policy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Execute the Data Scientist's code\n", + "\n", + "While Syft makes sure that the function is not tampered with, it does not perform any validation on the implementation itself.\n", + "\n", + "**It is the Data Owner's responsibility to review the code & verify if it's safe to execute.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Let's grab the actual executable function that was submitted by the user\n", + "users_function = func.run" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the code looks safe, we can go ahead and execute it on the private dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mock_result = users_function(trade_data=mock_data)\n", + "mock_result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "real_result = users_function(trade_data=pvt_data)\n", + "real_result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Approving a request\n", + "\n", + "By calling `request.approve()`, the data scientist can execute their function on the real data, and obtain the result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Approve the request so the Data Scientist's code can run on the real (private) data\n", + "result = request.approve()\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert isinstance(result, sy.SyftSuccess)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Denying a request\n", + "\n", + "At times you would want to deny a request in cases where the output violates privacy, either of the policies is too lenient, or perhaps the code is confusing!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Deny the request with an appropriate reason\n", + "result = request.deny(\n", + " reason=(\n", + " \"The Submitted UserCode does not add differential privacy to the output.\"\n", + " \"Kindly add differential privacy and resubmit the code.\"\n", + " )\n", + ")\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(result, sy.SyftSuccess)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can verify the status by checking our request list\n", + "project.requests" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Re-approving requests\n", + "\n", + "Let's re-approve the request so that we can work with the results in the later notebooks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = request.approve()\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify the request status again\n", + "project.requests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Cleanup local domain server\n", + "\n", + "if node.node_type.value == \"python\":\n", + " node.land()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the code request has been approved, let's go through the [03-data-scientist-download-result](./03-data-scientist-download-result.ipynb) notebook to see how a Data Scientist can access the results." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/api/0.9/05-custom-policy.ipynb b/notebooks/api/0.9/05-custom-policy.ipynb new file mode 100644 index 00000000000..ee6d44425ed --- /dev/null +++ b/notebooks/api/0.9/05-custom-policy.ipynb @@ -0,0 +1,647 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "SYFT_VERSION = \">=0.8.2.b0,<0.9\"\n", + "package_string = f'\"syft{SYFT_VERSION}\"'\n", + "# %pip install {package_string} -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# third party\n", + "import numpy as np\n", + "\n", + "# syft absolute\n", + "import syft as sy\n", + "\n", + "sy.requires(SYFT_VERSION)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True, reset=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "domain_client.register(\n", + " email=\"newuser@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "client_low_ds = node.login(email=\"newuser@openmined.org\", password=\"pw\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# stdlib\n", + "from typing import Any" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class RepeatedCallPolicy(sy.CustomOutputPolicy):\n", + " n_calls: int = 0\n", + " downloadable_output_args: list[str] = []\n", + " state: dict[Any, Any] = {}\n", + "\n", + " def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):\n", + " self.downloadable_output_args = (\n", + " downloadable_output_args if downloadable_output_args is not None else []\n", + " )\n", + " self.n_calls = n_calls\n", + " self.state = {\"counts\": 0}\n", + "\n", + " def public_state(self):\n", + " return self.state[\"counts\"]\n", + "\n", + " def update_policy(self, context, outputs):\n", + " self.state[\"counts\"] += 1\n", + "\n", + " def apply_to_output(self, context, outputs, update_policy=True):\n", + " if hasattr(outputs, \"syft_action_data\"):\n", + " 
outputs = outputs.syft_action_data\n", + " output_dict = {}\n", + " if self.state[\"counts\"] < self.n_calls:\n", + " for output_arg in self.downloadable_output_args:\n", + " output_dict[output_arg] = outputs[output_arg]\n", + " if update_policy:\n", + " self.update_policy(context, outputs)\n", + " else:\n", + " return None\n", + " return output_dict\n", + "\n", + " def _is_valid(self, context):\n", + " return self.state[\"counts\"] < self.n_calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "policy = RepeatedCallPolicy(n_calls=1, downloadable_output_args=[\"y\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "policy.n_calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "policy.downloadable_output_args" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "policy.init_kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(policy.init_kwargs)\n", + "a_obj = sy.ActionObject.from_obj({\"y\": [1, 2, 3]})\n", + "x = policy.apply_to_output(None, a_obj)\n", + "x[\"y\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "policy.n_calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "x = np.array([1, 2, 3])\n", + "x_pointer = sy.ActionObject.from_obj(x)\n", + "x_pointer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "x_pointer = x_pointer.send(domain_client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "from result import Err\n", + "from result import Ok\n", + "\n", + "# syft absolute\n", + "from syft.client.api import AuthedServiceContext\n", + "from syft.client.api import NodeIdentity\n", + "\n", + "\n", + "class CustomExactMatch(sy.CustomInputPolicy):\n", + " def __init__(self, *args: Any, **kwargs: Any) -> None:\n", + " pass\n", + "\n", + " def filter_kwargs(self, kwargs, context, code_item_id):\n", + " # stdlib\n", + "\n", + " try:\n", + " allowed_inputs = self.allowed_ids_only(\n", + " allowed_inputs=self.inputs, kwargs=kwargs, context=context\n", + " )\n", + " results = self.retrieve_from_db(\n", + " code_item_id=code_item_id,\n", + " allowed_inputs=allowed_inputs,\n", + " context=context,\n", + " )\n", + " except Exception as e:\n", + " return Err(str(e))\n", + " return results\n", + "\n", + " def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft.service.action.action_object import TwinMode\n", + "\n", + " action_service = context.node.get_service(\"actionservice\")\n", + " code_inputs = {}\n", + "\n", + " # When we are retrieving the code from the database, we need to use the node's\n", + " # verify key as the credentials. 
This is because when we approve the code, we\n", + " # we allow the private data to be used only for this specific code.\n", + " # but we are not modifying the permissions of the private data\n", + "\n", + " root_context = AuthedServiceContext(\n", + " node=context.node, credentials=context.node.verify_key\n", + " )\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " for var_name, arg_id in allowed_inputs.items():\n", + " kwarg_value = action_service._get(\n", + " context=root_context,\n", + " uid=arg_id,\n", + " twin_mode=TwinMode.NONE,\n", + " has_permission=True,\n", + " )\n", + " if kwarg_value.is_err():\n", + " return Err(kwarg_value.err())\n", + " code_inputs[var_name] = kwarg_value.ok()\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " return Ok(code_inputs)\n", + "\n", + " def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + " ):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs\n", + "\n", + " def _is_valid(\n", + " self,\n", + " context,\n", + " usr_input_kwargs,\n", + " code_item_id,\n", + " ):\n", + " filtered_input_kwargs = self.filter_kwargs(\n", + " kwargs=usr_input_kwargs,\n", + " context=context,\n", + " code_item_id=code_item_id,\n", + " )\n", + "\n", + " if filtered_input_kwargs.is_err():\n", + " return filtered_input_kwargs\n", + "\n", + " filtered_input_kwargs = filtered_input_kwargs.ok()\n", + "\n", + " expected_input_kwargs = set()\n", + " for _inp_kwargs in self.inputs.values():\n", + " for k in _inp_kwargs.keys():\n", + " if k not in usr_input_kwargs:\n", + " return Err(f\"Function missing required keyword argument: '{k}'\")\n", + " expected_input_kwargs.update(_inp_kwargs.keys())\n", + "\n", + " permitted_input_kwargs = list(filtered_input_kwargs.keys())\n", + " not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)\n", + " if len(not_approved_kwargs) > 0:\n", + " return Err(\n", + " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n", + " )\n", + " return Ok(True)\n", + "\n", + "\n", + "def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + "):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + " from syft.client.api import NodeIdentity\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " 
allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "@sy.syft_function(\n", + " input_policy=CustomExactMatch(x=x_pointer),\n", + " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=[\"y\"]),\n", + ")\n", + "def func(x):\n", + " return {\"y\": x + 1}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "request = client_low_ds.code.request_code_execution(func)\n", + "request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "request_id = request.id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "client_low_ds.code.get_all()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "for request in domain_client.requests:\n", + " if request.id == request_id:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "func = request.code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "# Custom policies need to be approved before they can be viewed and used\n", + "assert func.input_policy is None\n", + "assert func.output_policy is None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "result = func.run(x=x_pointer)\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "request.approve()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "assert func.input_policy is not None\n", + "assert func.output_policy is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "res_ptr = client_low_ds.code.func(x=x_pointer)\n", + "res_ptr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "res = res_ptr.get()\n", + "res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert (res[\"y\"] == np.array([2, 3, 4])).all()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert set(res.keys()) == set(\"y\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "31", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for code in domain_client.code.get_all():\n", + " if code.service_func_name == \"func\":\n", + " break\n", + "print(code.output_policy.state)\n", + "assert code.output_policy.state == {\"counts\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if node.node_type.value == \"python\":\n", + " node.land()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorials/hello-syft/01-hello-syft.ipynb b/notebooks/tutorials/hello-syft/01-hello-syft.ipynb index 858b1042e69..e51018e2be0 100644 --- a/notebooks/tutorials/hello-syft/01-hello-syft.ipynb +++ b/notebooks/tutorials/hello-syft/01-hello-syft.ipynb @@ -412,7 +412,7 @@ "metadata": {}, "outputs": [], "source": [ - "get_mean_age_user_function = usercode.unsafe_function" + "get_mean_age_user_function = usercode.run" ] }, { diff --git a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb index eb0d3df04a9..6e1e438c04e 100644 --- a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb +++ b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb @@ -592,7 +592,7 @@ "metadata": {}, "outputs": [], "source": [ - "real_result = request.code.unsafe_function(data=asset.data)\n", + "real_result = request.code.run(data=asset.data)\n", "real_result" ] }, diff --git a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb index a919a9e2a8e..d4a9c70cc5e 100644 --- a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb +++ b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb @@ -156,7 +156,7 @@ "metadata": {}, "outputs": [], "source": [ - "users_function = user_code.unsafe_function\n", + "users_function = user_code.run\n", "users_function" ] }, diff --git a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb index 5026c62450f..368d7090d57 100644 --- a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb +++ b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb @@ -738,7 +738,7 @@ }, "outputs": [], "source": [ - "get_col_user_function = func.unsafe_function" + "get_col_user_function = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb index 97482b84975..c7a8d83c7f6 100644 --- a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb +++ 
b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb @@ -913,7 +913,7 @@ }, "outputs": [], "source": [ - "get_counts_user_func = func.unsafe_function" + "get_counts_user_func = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb index 152da5c6c45..407f6507c3a 100644 --- a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb +++ b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb @@ -1030,7 +1030,7 @@ }, "outputs": [], "source": [ - "get_counts_user_func = func.unsafe_function" + "get_counts_user_func = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb index 2ba4b0cfe7c..daed660972d 100644 --- a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb +++ b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb @@ -799,7 +799,7 @@ }, "outputs": [], "source": [ - "get_col_user_function = func.unsafe_function" + "get_col_user_function = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb index 56718085058..0bff86b5c06 100644 --- a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb @@ -986,7 +986,7 @@ }, "outputs": [], "source": [ - "get_col_user_function = func.unsafe_function" + "get_col_user_function = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb index cbdb061df8d..896f842e6cb 100644 --- a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb +++ b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb @@ -889,7 +889,7 @@ }, "outputs": [], "source": [ - "get_col_user_function = func.unsafe_function" + "get_col_user_function = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb index 801509179d4..dab79f0c217 100644 --- a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb @@ -944,7 +944,7 @@ }, "outputs": [], "source": [ - "zip_codes = func.unsafe_function" + "zip_codes = func.run" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb index 58b0132bd25..32beba9af48 100644 --- a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb +++ b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb @@ -895,7 +895,7 @@ }, "outputs": [], "source": [ - "find_recently_installed = func.unsafe_function" + "find_recently_installed = func.run" ] }, { diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 
b8df87fc619..a866ad64798 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -822,7 +822,7 @@ def get_sync_dependencies( return dependencies @property - def unsafe_function(self) -> Callable | None: + def run(self) -> Callable | None: warning = SyftWarning( message="This code was submitted by a User and could be UNSAFE." ) @@ -864,7 +864,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable | SyftError: # return the results return result except Exception as e: - return SyftError(f"Failed to run unsafe_function. Error: {e}") + return SyftError(f"Failed to execute 'run'. Error: {e}") return wrapper diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 69f69ab76d4..c7a56550d55 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -61,7 +61,7 @@ def test_user_code(worker) -> None: message = root_domain_client.notifications[-1] request = message.link user_code = request.changes[0].code - result = user_code.unsafe_function() + result = user_code.run() request.approve() result = guest_client.api.services.code.mock_syft_func() @@ -355,7 +355,7 @@ def compute_sum(): message = root_domain_client.notifications[-1] request = message.link user_code = request.changes[0].code - result = user_code.unsafe_function() + result = user_code.run() request.approve() result = ds_client.api.services.code.compute_sum() From 42d41b5e626ebd1fa10d3df2094f25702f4b732d Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 10:56:29 +0800 Subject: [PATCH 196/309] Add a unit test for worker deletion --- .../tests/syft/syft_worker_deletion_test.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 packages/syft/tests/syft/syft_worker_deletion_test.py diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py new file mode 100644 index 00000000000..daed38ea2a4 --- /dev/null +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -0,0 +1,82 @@ +# stdlib +from collections.abc import Generator +from secrets import token_hex +from typing import Any + +# third party +import numpy as np +import pytest + +# syft absolute +import syft as sy +from syft.orchestra import NodeHandle +from syft.service.job.job_stash import JobStatus +from syft.service.response import SyftError + + +@pytest.fixture() +def node_args() -> dict[str, Any]: + return {} + + +@pytest.fixture +def node(node_args: dict[str, Any]) -> Generator[NodeHandle, None, None]: + _node = sy.orchestra.launch( + **{ + "name": token_hex(8), + "dev_mode": True, + "reset": True, + "n_consumers": 3, + "create_producer": True, + "queue_port": None, + "local_db": False, + **node_args, + } + ) + # startup code here + yield _node + # Cleanup code + _node.python_node.cleanup() + _node.land() + + +@pytest.mark.parametrize("node_args", [{"n_consumers": 1}]) +@pytest.mark.parametrize("force", [True, False]) +def test_delete_worker(node: NodeHandle, force: bool) -> None: + client = node.login(email="info@openmined.org", password="changethis") + + data = np.array([1, 2, 3]) + data_action_obj = sy.ActionObject.from_obj(data) + data_pointer = data_action_obj.send(client) + + @sy.syft_function_single_use(data=data_pointer) + def compute_mean(data): + # stdlib + import time + + time.sleep(1.5) + return data.mean() + + client.code.request_code_execution(compute_mean) + 
client.requests[-1].approve() + + job = client.code.compute_mean(data=data_pointer, blocking=False) + + while True: + if (syft_worker_id := client.jobs.get_all()[0].job_worker_id) is not None: + break + + res = client.worker.delete(syft_worker_id, force=force) + assert not isinstance(res, SyftError) + + if not force and len(client.worker.get_all()) > 0: + assert client.worker.get(syft_worker_id).to_be_deleted + job.wait() + + job = client.jobs[0] + if force: + assert job.status in (JobStatus.COMPLETED, JobStatus.INTERRUPTED) + else: + assert job.status == JobStatus.COMPLETED + + # assert len(node.python_node.queue_manager.consumers["api_call"]) == 0 From 605afdddb8a658482acc712e3cf8dbfd1b8a9021 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 10:59:27 +0800 Subject: [PATCH 197/309] Add a timeout for the test --- packages/syft/tests/syft/syft_worker_deletion_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py index daed38ea2a4..d23627a9a98 100644 --- a/packages/syft/tests/syft/syft_worker_deletion_test.py +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -1,6 +1,7 @@ # stdlib from collections.abc import Generator from secrets import token_hex +import time from typing import Any # third party @@ -62,9 +63,12 @@ def compute_mean(data): job = client.code.compute_mean(data=data_pointer, blocking=False) + start = time.time() while True: if (syft_worker_id := client.jobs.get_all()[0].job_worker_id) is not None: break + if time.time() - start > 5: + raise TimeoutError("Job did not get picked up by any worker.") res = client.worker.delete(syft_worker_id, force=force) assert not isinstance(res, SyftError) From 3d22ba60d891d991c1300206d026498d05c964cb Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 11:02:45 +0800 Subject: [PATCH 198/309] Persist SyftWorker.to_be_deleted to the stash --- packages/syft/src/syft/service/queue/zmq_queue.py | 4 ++-- packages/syft/src/syft/service/worker/worker_pool.py | 5 +++-- .../syft/src/syft/service/worker/worker_service.py | 12 ++++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index e4368ee9882..b2481b29f1c 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -431,7 +431,7 @@ def purge_workers(self) -> None: logger.info("Failed to retrieve SyftWorker {worker.syft_worker_id}") continue - if worker.has_expired() or syft_worker._to_be_deleted: + if worker.has_expired() or syft_worker.to_be_deleted: logger.info( "Deleting expired Worker id={} uid={} expiry={} now={}", worker.identity, @@ -439,7 +439,7 @@ def purge_workers(self) -> None: worker.get_expiry(), Timeout.now(), ) - self.delete_worker(worker, syft_worker._to_be_deleted) + self.delete_worker(worker, syft_worker.to_be_deleted) # relative from ...service.worker.worker_service import WorkerService diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index c591e494431..d7e6c63c560 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -14,6 +14,7 @@ from ...types.base import SyftBaseModel from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import 
SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.syft_object import short_uid from ...types.uid import UID @@ -49,7 +50,7 @@ class WorkerHealth(Enum): @serializable() class SyftWorker(SyftObject): __canonical_name__ = "SyftWorker" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 __attr_unique__ = ["name"] __attr_searchable__ = ["name", "container_id"] @@ -73,7 +74,7 @@ class SyftWorker(SyftObject): worker_pool_name: str consumer_state: ConsumerState = ConsumerState.DETACHED job_id: UID | None = None - _to_be_deleted: bool = False + to_be_deleted: bool = False @property def logs(self) -> str | SyftError: diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 689d6f8b94e..5c635284a3a 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -158,7 +158,11 @@ def _delete( if force and worker.job_id is not None: job_service = cast(JobService, context.node.get_service(JobService)) - job_service.kill(context=context, id=worker.job_id) + res = job_service.kill(context=context, id=worker.job_id) + if isinstance(res, SyftError): + return SyftError( + message=f"Failed to terminate the job associated with worker {uid}: {res.message}" + ) worker_pool_service = cast( SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService) @@ -241,7 +245,11 @@ def delete( force: bool = False, ) -> SyftSuccess | SyftError: worker = self._get_worker(context=context, uid=uid) - worker._to_be_deleted = True + worker.to_be_deleted = True + + res = self.stash.update(context.credentials, worker) + if isinstance(res, SyftError): + return res if not force: # relative From 6565444ac04f920f54b8907867962bcbb7d6b041 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 2 Jul 2024 14:07:26 +0530 Subject: [PATCH 199/309] remove user code version 6 and merge its changes to user code version 5 - restage protocol --- .../src/syft/protocol/protocol_version.json | 12 +--- .../syft/src/syft/service/code/user_code.py | 56 ++++--------------- 2 files changed, 13 insertions(+), 55 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 5054b847ee4..1a1a53bffd9 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -281,19 +281,9 @@ } }, "UserCode": { - "4": { - "version": 4, - "hash": "0a7181cd5f76800b6566175ffa7276d0cf38c4ddc5110114430147dfc8bfdb2a", - "action": "remove" - }, "5": { "version": 5, - "hash": "128705a5fdf308055ef857b25c80966c928938a05ec03459dae9b36bd6122aa2", - "action": "add" - }, - "6": { - "version": 6, - "hash": "c48ec3160bb34adf937e6306523c7ebc52861ff84a576a30a28cd45c224ded0f", + "hash": "c2409f51bf920cce557d288c40b6964ec4df3d8c23e33c5d5668addc30368632", "action": "add" } }, diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index d2a27abc861..0f8d303da2a 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -57,7 +57,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SYFT_OBJECT_VERSION_5 -from ...types.syft_object import SYFT_OBJECT_VERSION_6 from ...types.syft_object import SyftObject from ...types.syncable_object 
import SyncableSyftObject from ...types.transforms import TransformContext @@ -271,6 +270,7 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return [self.user_code_link.object_uid] +@serializable() class UserCodeV4(SyncableSyftObject): # version __canonical_name__ = "UserCode" @@ -302,46 +302,11 @@ class UserCodeV4(SyncableSyftObject): worker_pool_name: str | None = None -@serializable() -class UserCodeV5(SyncableSyftObject): - # version - __canonical_name__ = "UserCode" - __version__ = SYFT_OBJECT_VERSION_5 - - id: UID - node_uid: UID | None = None - user_verify_key: SyftVerifyKey - raw_code: str - input_policy_type: type[InputPolicy] | UserPolicy - input_policy_init_kwargs: dict[Any, Any] | None = None - input_policy_state: bytes = b"" - output_policy_type: type[OutputPolicy] | UserPolicy - output_policy_init_kwargs: dict[Any, Any] | None = None - output_policy_state: bytes = b"" - parsed_code: str - service_func_name: str - unique_func_name: str - user_unique_func_name: str - code_hash: str - signature: inspect.Signature - status_link: LinkedObject | None = None - input_kwargs: list[str] - enclave_metadata: EnclaveMetadata | None = None - submit_time: DateTime | None = None - # tracks if the code calls domain.something, variable is set during parsing - uses_domain: bool = False - - nested_codes: dict[str, tuple[LinkedObject, dict]] | None = {} - worker_pool_name: str | None = None - origin_node_side_type: NodeSideType - l0_deny_reason: str | None = None - - @serializable() class UserCode(SyncableSyftObject): # version __canonical_name__ = "UserCode" - __version__ = SYFT_OBJECT_VERSION_6 + __version__ = SYFT_OBJECT_VERSION_5 id: UID node_uid: UID | None = None @@ -1924,24 +1889,27 @@ def migrate_usercode_v4_to_v5() -> list[Callable]: return [ make_set_default("origin_node_side_type", NodeSideType.HIGH_SIDE), make_set_default("l0_deny_reason", None), + drop("enclave_metadata"), ] @migrate(UserCode, UserCodeV4) def migrate_usercode_v5_to_v4() -> list[Callable]: return [ - drop("origin_node_side_type"), - drop("l0_deny_reason"), + drop(["origin_node_side_type", "l0_deny_reason"]), + make_set_default("enclave_metadata", None), ] -@migrate(UserCodeV5, UserCode) -def migrate_usercode_v5_to_v6() -> list[Callable]: +@migrate(SubmitUserCodeV4, SubmitUserCode) +def upgrade_submitusercode() -> list[Callable]: return [ drop("enclave_metadata"), ] -@migrate(UserCode, UserCodeV5) -def migrate_usercode_v6_to_v5() -> list[Callable]: - return [make_set_default("enclave_metadata", None)] +@migrate(SubmitUserCode, SubmitUserCodeV4) +def downgrade_submitusercode() -> list[Callable]: + return [ + make_set_default("enclave_metadata", None), + ] From b1a8adb2387f715d39464de0d820dbb277e8a878 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 16:46:15 +0800 Subject: [PATCH 200/309] Add admin deletion test case to admin permission test Co-authored-by: Thiago Costa Porto --- packages/syft/tests/syft/users/user_code_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 1719f826d43..ff0c2da101a 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -47,10 +47,12 @@ def test_repr_markdown_not_throwing_error(guest_client: DomainClient) -> None: assert result[0]._repr_markdown_() +@pytest.mark.parametrize("delete_original_admin", [False, True]) def test_new_admin_can_list_user_code( worker: Worker, 
ds_client: DomainClient, faker: Faker, + delete_original_admin: bool, ) -> None: root_client = worker.root_client @@ -69,6 +71,10 @@ def test_new_admin_can_list_user_code( admin.me.id, UserUpdate(role=ServiceRole.ADMIN) ) + if delete_original_admin: + res = root_client.api.services.user.delete(root_client.me.id) + assert not isinstance(res, SyftError) + assert len(root_client.code.get_all()) == len(admin.code.get_all()) assert {c.id for c in root_client.code} == {c.id for c in admin.code} From d48ef83eb684de2e789276daa51ddeeae92c9e19 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 18:11:39 +0800 Subject: [PATCH 201/309] Add unit tests for new admin action object permission Co-authored-by: Thiago Costa Porto --- packages/syft/tests/syft/action_test.py | 45 +++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index 1f4dc0cc36b..79a6d115644 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -1,11 +1,19 @@ +# stdlib +import uuid + # third party +from faker import Faker import numpy as np +import pytest # syft absolute from syft import ActionObject from syft.client.api import SyftAPICall +from syft.node.worker import Worker from syft.service.action.action_object import Action from syft.service.response import SyftError +from syft.service.user.user import UserUpdate +from syft.service.user.user_roles import ServiceRole from syft.types.uid import LineageID # relative @@ -24,6 +32,43 @@ def test_actionobject_method(worker): assert res[0] == "A" +@pytest.mark.parametrize("delete_original_admin", [False, True]) +def test_new_admin_has_action_object_permission( + worker: Worker, + faker: Faker, + delete_original_admin: bool, +) -> None: + root_client = worker.root_client + + email = uuid.uuid4().hex[:6] + faker.email() # avoid collision + pw = uuid.uuid4().hex + root_client.register( + name=faker.name(), email=email, password=pw, password_verify=pw + ) + ds_client = root_client.login(email=email, password=pw) + + obj = ActionObject.from_obj("abc") + obj.send(ds_client) + + email = faker.email() + pw = uuid.uuid4().hex + root_client.register( + name=faker.name(), email=email, password=pw, password_verify=pw + ) + + admin = root_client.login(email=email, password=pw) + + root_client.api.services.user.update( + admin.me.id, UserUpdate(role=ServiceRole.ADMIN) + ) + + if delete_original_admin: + res = root_client.api.services.user.delete(root_client.me.id) + assert not isinstance(res, SyftError) + + assert admin.api.services.action.get(obj.id) == obj + + @currently_fail_on_python_3_12(raises=AttributeError) def test_lib_function_action(worker): root_domain_client = worker.root_client From fc852498e7629ca3ab8ab6384c3f91ce779faf45 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 18:12:40 +0800 Subject: [PATCH 202/309] Add unit tests for new admin action object permission Co-authored-by: Thiago Costa Porto --- .../tests/syft/stores/action_store_test.py | 11 +++- .../tests/syft/stores/store_constants_test.py | 6 ++ .../tests/syft/stores/store_fixtures_test.py | 60 ++++++++++++++++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 0cabe78ef84..613cefa16d7 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -1,5 +1,4 @@ # stdlib 
-import sys from typing import Any # third party @@ -14,6 +13,7 @@ from syft.types.uid import UID # relative +from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN from .store_constants_test import TEST_VERIFY_KEY_STRING_CLIENT from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT @@ -53,22 +53,23 @@ def test_action_store_sanity(store: Any): ], ) @pytest.mark.parametrize("permission", permissions) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -@pytest.mark.skipif(sys.platform == "darwin", reason="skip on mac") def test_action_store_test_permissions(store: Any, permission: Any): client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT) root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) hacker_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) + new_admin_key = TEST_VERIFY_KEY_NEW_ADMIN access = permission(uid=UID(), credentials=client_key) access_root = permission(uid=UID(), credentials=root_key) access_hacker = permission(uid=UID(), credentials=hacker_key) + access_new_admin = permission(uid=UID(), credentials=new_admin_key) # add permission store.add_permission(access) assert store.has_permission(access) assert store.has_permission(access_root) + assert store.has_permission(access_new_admin) assert not store.has_permission(access_hacker) # remove permission @@ -76,6 +77,7 @@ def test_action_store_test_permissions(store: Any, permission: Any): assert not store.has_permission(access) assert store.has_permission(access_root) + assert store.has_permission(access_new_admin) assert not store.has_permission(access_hacker) # take ownership with new UID @@ -85,6 +87,7 @@ def test_action_store_test_permissions(store: Any, permission: Any): store.take_ownership(client_uid2, client_key) assert store.has_permission(access) assert store.has_permission(access_root) + assert store.has_permission(access_new_admin) assert not store.has_permission(access_hacker) # delete UID as hacker @@ -95,12 +98,14 @@ def test_action_store_test_permissions(store: Any, permission: Any): assert res.is_err() assert store.has_permission(access) + assert store.has_permission(access_new_admin) assert store.has_permission(access_hacker_ro) # delete UID as owner res = store.delete(client_uid2, client_key) assert res.is_ok() assert not store.has_permission(access) + assert store.has_permission(access_new_admin) assert not store.has_permission(access_hacker) diff --git a/packages/syft/tests/syft/stores/store_constants_test.py b/packages/syft/tests/syft/stores/store_constants_test.py index ba9910bb652..4de8f7df45a 100644 --- a/packages/syft/tests/syft/stores/store_constants_test.py +++ b/packages/syft/tests/syft/stores/store_constants_test.py @@ -1,3 +1,6 @@ +# syft absolute +from syft.node.credentials import SyftSigningKey + TEST_VERIFY_KEY_STRING_ROOT = ( "08e5bcddfd55cdff0f7f6a62d63a43585734c6e7a17b2ffb3f3efe322c3cecc5" ) @@ -7,3 +10,6 @@ TEST_VERIFY_KEY_STRING_HACKER = ( "8f4412396d3418d17c08a8f46592621a5d57e0daf1c93e2134c30f50d666801d" ) + +TEST_SIGNING_KEY_NEW_ADMIN = SyftSigningKey.generate() +TEST_VERIFY_KEY_NEW_ADMIN = TEST_SIGNING_KEY_NEW_ADMIN.verify_key diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index a10ace6f25b..a1a9cc0a224 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -4,19 +4,27 @@ from pathlib import Path from secrets import token_hex 
import tempfile +import uuid # third party import pytest # syft absolute from syft.node.credentials import SyftVerifyKey +from syft.service.action.action_permissions import ActionObjectPermission +from syft.service.action.action_permissions import ActionPermission from syft.service.action.action_store import DictActionStore from syft.service.action.action_store import MongoActionStore from syft.service.action.action_store import SQLiteActionStore from syft.service.queue.queue_stash import QueueStash +from syft.service.user.user import User +from syft.service.user.user import UserCreate +from syft.service.user.user_roles import ServiceRole +from syft.service.user.user_stash import UserStash from syft.store.dict_document_store import DictDocumentStore from syft.store.dict_document_store import DictStoreConfig from syft.store.dict_document_store import DictStorePartition +from syft.store.document_store import DocumentStore from syft.store.document_store import PartitionSettings from syft.store.locks import LockingConfig from syft.store.locks import NoLockingConfig @@ -32,6 +40,8 @@ from syft.types.uid import UID # relative +from .store_constants_test import TEST_SIGNING_KEY_NEW_ADMIN +from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT from .store_mocks_test import MockObjectType @@ -52,6 +62,38 @@ def str_to_locking_config(conf: str) -> LockingConfig: raise NotImplementedError(f"unknown locking config {conf}") +def document_store_with_admin( + node_uid: UID, verify_key: SyftVerifyKey +) -> DocumentStore: + document_store = DictDocumentStore(node_uid=node_uid, root_verify_key=verify_key) + + password = uuid.uuid4().hex + + user_stash = UserStash(store=document_store) + admin_user = UserCreate( + email="mail@example.org", + name="Admin", + password=password, + password_verify=password, + role=ServiceRole.ADMIN, + ).to(User) + + admin_user.signing_key = TEST_SIGNING_KEY_NEW_ADMIN + admin_user.verify_key = TEST_VERIFY_KEY_NEW_ADMIN + + user_stash.set( + credentials=verify_key, + user=admin_user, + add_permissions=[ + ActionObjectPermission( + uid=admin_user.id, permission=ActionPermission.ALL_READ + ), + ], + ) + + return document_store + + @pytest.fixture(scope="function") def sqlite_workspace() -> Generator: sqlite_db_name = token_hex(8) + ".sqlite" @@ -171,10 +213,15 @@ def sqlite_action_store(sqlite_workspace: tuple[Path, str], request): ) ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) + + node_uid = UID() + document_store = document_store_with_admin(node_uid, ver_key) + yield SQLiteActionStore( - node_uid=UID(), + node_uid=node_uid, store_config=store_config, root_verify_key=ver_key, + document_store=document_store, ) @@ -278,10 +325,13 @@ def mongo_action_store(mongo_client, request): client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config ) ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) + node_uid = UID() + document_store = document_store_with_admin(node_uid, ver_key) mongo_action_store = MongoActionStore( - node_uid=UID(), + node_uid=node_uid, store_config=store_config, root_verify_key=ver_key, + document_store=document_store, ) yield mongo_action_store @@ -315,10 +365,14 @@ def dict_action_store(request): store_config = DictStoreConfig(locking_config=locking_config) ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) + node_uid = UID() + document_store = document_store_with_admin(node_uid, ver_key) + yield DictActionStore( - node_uid=UID(), + 
node_uid=node_uid, store_config=store_config, root_verify_key=ver_key, + document_store=document_store, ) From 969e78a5893f5b0469d3d3af83cb3357de663c0d Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 18:13:28 +0800 Subject: [PATCH 203/309] Add action object permission for new admins Co-authored-by: Thiago Costa Porto --- packages/syft/src/syft/node/node.py | 3 +++ .../src/syft/service/action/action_store.py | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 04851c2b17f..deaba297b9b 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -887,6 +887,7 @@ def init_stores( node_uid=self.id, store_config=action_store_config, root_verify_key=self.verify_key, + document_store=self.document_store, ) elif isinstance(action_store_config, MongoStoreConfig): # We add the python id of the current node in order @@ -899,11 +900,13 @@ def init_stores( node_uid=self.id, root_verify_key=self.verify_key, store_config=action_store_config, + document_store=self.document_store, ) else: self.action_store = DictActionStore( node_uid=self.id, root_verify_key=self.verify_key, + document_store=self.document_store, ) self.action_store_config = action_store_config diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 001aa7a4e0f..b1a25fe23a2 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -15,6 +15,7 @@ from ...serde.serializable import serializable from ...store.dict_document_store import DictStoreConfig from ...store.document_store import BasePartitionSettings +from ...store.document_store import DocumentStore from ...store.document_store import StoreConfig from ...types.syft_object import SyftObject from ...types.twin_object import TwinObject @@ -53,6 +54,7 @@ def __init__( node_uid: UID, store_config: StoreConfig, root_verify_key: SyftVerifyKey | None = None, + document_store: DocumentStore | None = None, ) -> None: self.node_uid = node_uid self.store_config = store_config @@ -71,6 +73,8 @@ def __init__( root_verify_key = SyftSigningKey.generate().verify_key self.root_verify_key = root_verify_key + self.__document_store = document_store + def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False ) -> Result[SyftObject, str]: @@ -234,6 +238,25 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: ): return True + if self.__document_store is not None: + # relative + from ...service.user.user_roles import ServiceRole + from ...service.user.user_stash import UserStash + + user_stash = UserStash(store=self.__document_store) + + res = user_stash.get_by_verify_key( + credentials=permission.credentials, + verify_key=permission.credentials, + ) + + if ( + res.is_ok() + and (user := res.ok()) is not None + and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) + ): + return True + if ( permission.uid in self.permissions and permission.permission_string in self.permissions[permission.uid] @@ -346,12 +369,14 @@ def __init__( node_uid: UID, store_config: StoreConfig | None = None, root_verify_key: SyftVerifyKey | None = None, + document_store: DocumentStore | None = None, ) -> None: store_config = store_config if store_config is not None else DictStoreConfig() super().__init__( node_uid=node_uid, store_config=store_config, root_verify_key=root_verify_key, + 
document_store=document_store, ) From 716b415325701605cd4dfdf9bedf51675df737a4 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 2 Jul 2024 19:39:04 +0800 Subject: [PATCH 204/309] Init UserStash at ActionStore init instead of every operation --- .../syft/src/syft/service/action/action_store.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index b1a25fe23a2..c44e3471bdb 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -73,7 +73,12 @@ def __init__( root_verify_key = SyftSigningKey.generate().verify_key self.root_verify_key = root_verify_key - self.__document_store = document_store + self.__user_stash = None + if document_store is not None: + # relative + from ...service.user.user_stash import UserStash + + self.__user_stash = UserStash(store=document_store) def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False @@ -238,14 +243,11 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: ): return True - if self.__document_store is not None: + if self.__user_stash is not None: # relative from ...service.user.user_roles import ServiceRole - from ...service.user.user_stash import UserStash - - user_stash = UserStash(store=self.__document_store) - res = user_stash.get_by_verify_key( + res = self.__user_stash.get_by_verify_key( credentials=permission.credentials, verify_key=permission.credentials, ) From 09bf4d89c04b509a02446aab7895d03f0d9c9186 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 15:12:19 +0200 Subject: [PATCH 205/309] automatically convert UIDs to clipboard items in tables --- .../notebook_ui/components/tabulator_template.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 24b65f0b8f0..3d860133428 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -10,6 +10,7 @@ import jinja2 # relative +from ....types.uid import UID from ...assets import load_css from ...assets import load_js from ...table import TABLE_INDEX_KEY @@ -82,6 +83,13 @@ def format_dict(data: Any) -> str: return sanitize_html(str(data)) +def format_uid(uid: UID) -> str: + # relative + from .sync import CopyButton + + return CopyButton(copy_text=uid.no_dash).to_html() + + def format_table_data(table_data: list[dict[str, Any]]) -> list[dict[str, str]]: formatted: list[dict[str, str]] = [] for row in table_data: @@ -90,7 +98,11 @@ def format_table_data(table_data: list[dict[str, Any]]) -> list[dict[str, str]]: if isinstance(v, str): row_formatted[k] = sanitize_html(v.replace("\n", "
")) continue - v_formatted = format_dict(v) + # make UID copyable and trimmed + if isinstance(v, UID): + v_formatted = format_uid(v) + else: + v_formatted = format_dict(v) row_formatted[k] = v_formatted formatted.append(row_formatted) return formatted From 68ed1add485f3703f22c0f7f9dfa5cebad355abd Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 15:12:56 +0200 Subject: [PATCH 206/309] add UserCode._coll_repr_ --- .../syft/src/syft/service/dataset/dataset.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index a483cfa2751..8a1446d167d 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -105,8 +105,29 @@ class Asset(SyftObject): created_at: DateTime = DateTime.now() uploader: Contributor | None = None - __repr_attrs__ = ["_kwarg_name", "name", "action_id", "_dataset_name", "node_uid"] - __clipboard_attrs__ = ["action_id", "node_uid", "_dataset_name"] + # _kwarg_name and _dataset_name are set by the UserCode.assets + _kwarg_name: str | None = None + _dataset_name: str | None = None + + __clipboard_attrs__ = ["Action ID", "Node UID"] + __syft_include_id_coll_repr__ = False + + def _coll_repr_(self) -> dict[str, Any]: + base_dict = { + "Asset Name": self.name, + "Action ID": self.action_id, + "Node UID": self.node_uid, + } + + # _kwarg_name and _dataset_name are set by the UserCode.assets + if self._kwarg_name and self._dataset_name: + base_dict.update( + { + "Parameter": self._kwarg_name, + "Dataset Name": self._dataset_name, + } + ) + return base_dict def __init__( self, From 21a30fd90b74bdb6bef8af86d69f1c342fa5a7b3 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 15:36:52 +0200 Subject: [PATCH 207/309] clean up --- .../syft/src/syft/service/code/user_code.py | 12 ++--- .../syft/src/syft/service/dataset/dataset.py | 44 +++++++++++-------- packages/syft/src/syft/util/table.py | 9 +--- .../syft/tests/syft/users/user_code_test.py | 2 +- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index fb4875460c1..393021a7e9f 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -816,14 +816,10 @@ def assets(self) -> list[Asset] | SyftError: def _asset_json(self) -> str | SyftError: if isinstance(self.assets, SyftError): return self.assets - used_assets = {} - for asset in self.assets: - used_assets[asset._kwarg_name] = { - "source_dataset": asset._dataset_name, - "source_asset": asset.name, - "action_id": asset.action_id.no_dash, - "source_node": asset.node_uid.no_dash, - } + used_assets = { + asset._kwarg_name: asset._get_dict_for_user_code_repr() + for asset in self.assets + } asset_str = json.dumps(used_assets, indent=2) return asset_str diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 8a1446d167d..a8bfdfc62cc 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -108,27 +108,8 @@ class Asset(SyftObject): # _kwarg_name and _dataset_name are set by the UserCode.assets _kwarg_name: str | None = None _dataset_name: str | None = None - - __clipboard_attrs__ = ["Action ID", "Node UID"] __syft_include_id_coll_repr__ = False - def 
_coll_repr_(self) -> dict[str, Any]: - base_dict = { - "Asset Name": self.name, - "Action ID": self.action_id, - "Node UID": self.node_uid, - } - - # _kwarg_name and _dataset_name are set by the UserCode.assets - if self._kwarg_name and self._dataset_name: - base_dict.update( - { - "Parameter": self._kwarg_name, - "Dataset Name": self._dataset_name, - } - ) - return base_dict - def __init__( self, description: MarkdownDescription | str | None = "", @@ -214,6 +195,31 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str += f"\t{contributor.name}: {contributor.email}\n" return as_markdown_python_code(_repr_str) + def _coll_repr_(self) -> dict[str, Any]: + base_dict = { + "Asset Name": self.name, + "Action ID": self.action_id, + "Node UID": self.node_uid, + } + + # _kwarg_name and _dataset_name are set by the UserCode.assets + if self._kwarg_name and self._dataset_name: + base_dict.update( + { + "Parameter": self._kwarg_name, + "Dataset Name": self._dataset_name, + } + ) + return base_dict + + def _get_dict_for_user_code_repr(self) -> dict[str, Any]: + return { + "source_dataset": self._dataset_name, + "source_asset": self.name, + "action_id": self.action_id.no_dash, + "source_node": self.node_uid.no_dash, + } + def __eq__(self, other: object) -> bool: if not isinstance(other, Asset): return False diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index a592837769f..9f479f15314 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -135,14 +135,7 @@ def _create_table_rows( except Exception as e: print(e) value = None - - if field in getattr(item, "__clipboard_attrs__", []): - value = { - "value": sanitize_html(str(value)), - "type": "clipboard", - } - else: - value = sanitize_html(str(value)) + value = sanitize_html(str(value)) cols[field].append(value) col_lengths = {len(cols[col]) for col in cols.keys()} diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index a24897c62a8..69f69ab76d4 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -166,7 +166,7 @@ def func(asset): c for c in request.changes if (isinstance(c, UserCodeStatusChange)) ) - assert status_change.code.assets["asset"].model_dump( + assert status_change.code.assets[0].model_dump( mode="json" ) == asset_input.model_dump(mode="json") From 2d1ad91dbe652e9dcf97fb838aedea02c1214b71 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 15:41:33 +0200 Subject: [PATCH 208/309] revert chage --- packages/syft/src/syft/service/code/user_code.py | 1 + packages/syft/src/syft/util/table.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 97cf123c783..50b908803f0 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -883,6 +883,7 @@ def _inner_repr(self, level: int = 0) -> str: {shared_with_line} assets: dict = {asset_str} code: + {self.raw_code} """ if self.nested_codes != {}: diff --git a/packages/syft/src/syft/util/table.py b/packages/syft/src/syft/util/table.py index 9f479f15314..611ca5e33a2 100644 --- a/packages/syft/src/syft/util/table.py +++ b/packages/syft/src/syft/util/table.py @@ -135,8 +135,7 @@ def _create_table_rows( except Exception as e: print(e) value = None - value = 
sanitize_html(str(value)) - cols[field].append(value) + cols[field].append(sanitize_html(str(value))) col_lengths = {len(cols[col]) for col in cols.keys()} if len(col_lengths) != 1: From c0007b97707a04e29a2fc4c5ce8df6afa147ee5f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 2 Jul 2024 16:08:12 +0200 Subject: [PATCH 209/309] fix lint+mypy --- .../0-prepare-migration-data.ipynb | 2 +- .../1b-connect-and-migrate-via-api.ipynb | 1 + .../1c-migrate-to-new-node.ipynb | 3192 +---------------- .../2-post-migration-tests.ipynb | 99 +- .../syft/src/syft/client/domain_client.py | 6 +- packages/syft/src/syft/node/node.py | 4 +- packages/syft/src/syft/orchestra.py | 12 +- .../syft/src/syft/protocol/data_protocol.py | 16 +- .../syft/service/code_history/code_history.py | 6 +- .../syft/src/syft/service/job/job_service.py | 6 +- .../service/migration/migration_service.py | 18 +- .../syft/service/notifier/notifier_service.py | 2 +- packages/syft/src/syft/service/response.py | 2 + tox.ini | 6 +- 14 files changed, 257 insertions(+), 3115 deletions(-) diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 94e9a9f5bee..2b5026016df 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -251,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb index b0d08cbc21b..7c23086fac9 100644 --- a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb +++ b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb @@ -123,6 +123,7 @@ "# this wont work in the cases where the context is actually used,\n", "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", "\n", + "\n", "context = Context()\n", "migrated_objects = []\n", "for klass, objects in migration_dict.items():\n", diff --git a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb index 504b6cd0695..bac198022da 100644 --- a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb +++ b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "0", "metadata": {}, "outputs": [], @@ -10,27 +10,18 @@ "# syft absolute\n", "import syft as sy\n", "from syft.service.log.log import SyftLogV3\n", + "from syft.types.blob_storage import BlobStorageEntry\n", + "from syft.types.blob_storage import CreateBlobStorageEntry\n", "from syft.types.syft_object import Context\n", - "from syft.types.syft_object import SyftObject\n", - "from syft.service.user.user import User\n", - "\n", - "from syft.types.blob_storage import BlobStorageEntry, CreateBlobStorageEntry" + "from syft.types.syft_object import SyftObject" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "1", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "syft version: 0.8.7-beta.9\n" - ] - } - ], + "outputs": [], "source": [ "print(f\"syft version: {sy.__version__}\")" ] @@ 
-51,23 +42,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "3", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Staging Protocol Changes...\n", - "Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/db/21519e1e3e664b38a635dc951c293158.sqlite\n", - "Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/db/21519e1e3e664b38a635dc951c293158.sqlite\n", - "Migrating data for: NodeSettings table.\n", - "Creating default worker image with tag='local-dev'\n", - "Setting up worker poolname=default-pool workers=2 image_uid=a1f62fc8f2ac4e32a70a90260019f831 in_memory=True\n" - ] - } - ], + "outputs": [], "source": [ "node = sy.orchestra.launch(\n", " name=\"test_upgradability\",\n", @@ -81,30 +59,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logged into as \n" - ] - }, - { - "data": { - "text/html": [ - "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" - ], - "text/plain": [ - "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" ] @@ -119,36 +77,10 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "ef428d2d-62d8-4b14-a8c9-89d7cc4f6a8a", + "execution_count": null, + "id": "6", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Staging Protocol Changes...\n", - "Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/db/a3edccc59f384307aeaa1a50714c2300.sqlite\n", - "Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/db/a3edccc59f384307aeaa1a50714c2300.sqlite\n", - "Creating default worker image with tag='local-dev'\n", - "Setting up worker poolname=default-pool workers=2 image_uid=4ba80c1b73ae44819a22f9c4811230c2 in_memory=True\n", - "Created default worker pool.\n", - "Logged into as \n" - ] - }, - { - "data": { - "text/html": [ - "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" - ], - "text/plain": [ - "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "temp_node = sy.orchestra.launch(\n", " name=\"temp_node\",\n", @@ -165,7 +97,7 @@ }, { "cell_type": "markdown", - "id": "6", + "id": "7", "metadata": {}, "source": [ "## document store objects" @@ -173,8 +105,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "7", + "execution_count": null, + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -183,87 +115,18 @@ }, { "cell_type": "code", - "execution_count": 7, - "id": "8", + "execution_count": null, + "id": "9", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{syft.service.queue.queue_stash.ActionQueueItem: [: completed],\n", - " syft.service.user.user.User: [syft.service.user.user.User,\n", - " syft.service.user.user.User],\n", - " syft.service.worker.worker_pool.SyftWorker: [syft.service.worker.worker_pool.SyftWorker,\n", - " syft.service.worker.worker_pool.SyftWorker],\n", - " syft.service.settings.settings.NodeSettings: [syft.service.settings.settings.NodeSettings],\n", - " syft.service.dataset.dataset.Dataset: [syft.service.dataset.dataset.Dataset],\n", - " syft.service.code.user_code.UserCode: [syft.service.code.user_code.UserCode],\n", - " syft.service.log.log.SyftLog: [syft.service.log.log.SyftLog],\n", - " syft.service.request.request.Request: [syft.service.request.request.Request],\n", - " syft.service.job.job_stash.Job: [syft.service.job.job_stash.Job],\n", - " syft.service.notifier.notifier.NotifierSettings: [syft.service.notifier.notifier.NotifierSettings],\n", - " syft.service.notification.notifications.Notification: [syft.service.notification.notifications.Notification,\n", - " syft.service.notification.notifications.Notification,\n", - " syft.service.notification.notifications.Notification],\n", - " syft.service.code_history.code_history.CodeHistory: [syft.service.code_history.code_history.CodeHistory],\n", - " syft.types.blob_storage.BlobStorageEntry: [syft.types.blob_storage.BlobStorageEntry,\n", - " syft.types.blob_storage.BlobStorageEntry,\n", - " syft.types.blob_storage.BlobStorageEntry],\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState: [syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " 
syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState,\n", - " syft.service.migration.object_migration_state.SyftObjectMigrationState],\n", - " syft.service.worker.worker_image.SyftWorkerImage: [syft.service.worker.worker_image.SyftWorkerImage,\n", - " syft.service.worker.worker_image.SyftWorkerImage],\n", - " syft.service.worker.worker_pool.WorkerPool: [syft.service.worker.worker_pool.WorkerPool],\n", - " syft.service.output.output_service.ExecutionOutput: [syft.service.output.output_service.ExecutionOutput],\n", - " syft.service.code.user_code.UserCodeStatusCollection: [{NodeIdentity : (, '')}]}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "migration_dict" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "9", + "execution_count": null, + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -274,14 +137,15 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "10", + "execution_count": null, + "id": "11", "metadata": {}, "outputs": [], "source": [ "# this wont work in the cases where the context is actually used,\n", "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", "\n", + "\n", "context = Context()\n", "migrated_objects = []\n", "for klass, objects in migration_dict.items():\n", @@ -291,15 +155,20 @@ " elif isinstance(obj, SyftLogV3):\n", " migrated_obj = custom_migration_function(context, obj, klass)\n", " else:\n", - " migrated_obj = obj.migrate_to(klass.__version__, context)\n", + " try:\n", + " migrated_obj = obj.migrate_to(klass.__version__, context)\n", + " except Exception:\n", + " print(obj.__version__, obj.__canonical_name__)\n", + " print(klass.__version__, klass.__canonical_name__)\n", + " raise\n", "\n", " migrated_objects.append(migrated_obj)" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "11", + "execution_count": null, + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ 
-310,64 +179,28 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "12", + "execution_count": null, + "id": "13", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " #6e81a99a5c264407abee9e635f36ad9e already exists\n", - " #00bd74c6030e4af0b8844e00da8185c2 already exists\n", - " #22c6b28b12e94052aaf1adf4adc79741 already exists\n", - " #9ca8930bc5b04aa1a7b1835d69b8e4e2 already exists\n", - " #7ec16f3a86a442179f69ef4640f16c10 already exists\n", - " #2611c202fe984bec80cadfc482be3cb8 already exists\n", - " #ba943fd900a44527ba5d155128525072 already exists\n", - " #dc4f177789974647b6ea47657cf3fc76 already exists\n", - " #fe51de3d24ac482ba369fc14b9bda9f9 already exists\n", - " #aed2854dca67456196f91b68da949e5a already exists\n", - " #ab1fcea789344984814f8a4bcea7cd02 already exists\n", - " #77a3e8772dae4aa288b468d96e799364 already exists\n", - " #160150b1cee64df0830f00a1bee8d857 already exists\n", - " #a1f62fc8f2ac4e32a70a90260019f831 already exists\n", - " #3c32960a61c7471991de0d88b28a662e already exists\n" - ] - } - ], + "outputs": [], "source": [ "res = temp_client.services.migration.create_migrated_objects(migrated_objects)" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "3db3dc71-c06c-4b48-b884-dde1ff6d1838", + "execution_count": null, + "id": "14", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
SyftSuccess: success

" - ], - "text/plain": [ - "SyftSuccess: success" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "res" ] }, { "cell_type": "code", - "execution_count": 13, - "id": "13", + "execution_count": null, + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -376,7 +209,7 @@ }, { "cell_type": "markdown", - "id": "b95366c4-78f8-484d-b53f-c6eb5d4f5c1c", + "id": "16", "metadata": {}, "source": [ "# Migrate blobstorage" @@ -385,35 +218,17 @@ { "cell_type": "code", "execution_count": null, - "id": "3acabe3c-d164-4a5c-b1c5-c0f676c00ec5", + "id": "17", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", - "execution_count": 14, - "id": "157231d6-c4e5-448a-9bde-190c8cb0df89", + "execution_count": null, + "id": "18", "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "```python\n", - "class BlobStorageEntry:\n", - " id: str = 5e9a51c12f664e39bcc6099981292e3f\n", - "\n", - "```" - ], - "text/plain": [ - "syft.types.blob_storage.BlobStorageEntry" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "klass = BlobStorageEntry\n", "blob_entries = migration_dict[klass]\n", @@ -423,21 +238,25 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "f8f06ced-8422-497d-a521-0be184017c53", + "execution_count": null, + "id": "19", "metadata": {}, "outputs": [], "source": [ + "# stdlib\n", "from io import BytesIO\n", "import sys\n", "\n", - "def migrate_blob_entry_data(old_client, new_client, obj, klass) -> sy.SyftSuccess | sy.SyftError:\n", + "\n", + "def migrate_blob_entry_data(\n", + " old_client, new_client, obj, klass\n", + ") -> sy.SyftSuccess | sy.SyftError:\n", " migrated_obj = obj.migrate_to(klass.__version__, Context())\n", " uploaded_by = migrated_obj.uploaded_by\n", " blob_retrieval = old_client.services.blob_storage.read(obj.id)\n", " if isinstance(blob_retrieval, sy.SyftError):\n", " return blob_retrieval\n", - " \n", + "\n", " data = blob_retrieval.read()\n", " # TODO do we have to determine new filesize here?\n", " serialized = sy.serialize(data, to_bytes=True)\n", @@ -445,57 +264,20 @@ " blob_create = CreateBlobStorageEntry.from_blob_storage_entry(obj)\n", " blob_create.file_size = size\n", "\n", - " blob_deposit_object = new_client.services.blob_storage.allocate_for_user(blob_create, uploaded_by)\n", + " blob_deposit_object = new_client.services.blob_storage.allocate_for_user(\n", + " blob_create, uploaded_by\n", + " )\n", " if isinstance(blob_deposit_object, sy.SyftError):\n", " return blob_deposit_object\n", - " return blob_deposit_object.write(BytesIO(serialized))\n", - " \n", - " " + " return blob_deposit_object.write(BytesIO(serialized))" ] }, { "cell_type": "code", - "execution_count": 16, - "id": "e108f228-ff8b-41ef-966e-421be7fd39a9", + "execution_count": null, + "id": "20", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
SyftSuccess: File successfully saved.

" - ], - "text/plain": [ - "SyftSuccess: File successfully saved." - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
SyftSuccess: File successfully saved.

" - ], - "text/plain": [ - "SyftSuccess: File successfully saved." - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
SyftSuccess: File successfully saved.

" - ], - "text/plain": [ - "SyftSuccess: File successfully saved." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "for blob_entry in blob_entries:\n", " res = migrate_blob_entry_data(client, temp_client, blob_entry, BlobStorageEntry)\n", @@ -504,2596 +286,17 @@ }, { "cell_type": "code", - "execution_count": 17, - "id": "e7f005ec-7a6d-4ff6-a2db-f82cf1bfaa17", + "execution_count": null, + "id": "21", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - "\n", - "
\n", - "
\n", - " \n", - "
\n", - "

BlobStorageEntry List

\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "

Total: 0

\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n", - "" - ], - "text/plain": [ - "[syft.types.blob_storage.BlobStorageEntry,\n", - " syft.types.blob_storage.BlobStorageEntry,\n", - " syft.types.blob_storage.BlobStorageEntry]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "client.services.blob_storage.get_all()" ] }, - { - "cell_type": "code", - "execution_count": 18, - "id": "23b87c41-e3e9-42b6-814c-075ab2719cbe", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'syft_node_location': ,\n", - " 'syft_client_verify_key': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", - " 'id': ,\n", - " 'location': /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/21519e1e3e664b38a635dc951c293158/blob/5e9a51c12f664e39bcc6099981292e3f,\n", - " 'type_': numpy.ndarray,\n", - " 'mimetype': 'bytes',\n", - " 'file_size': 609,\n", - " 'no_lines': 0,\n", - " 'uploaded_by': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", - " 'created_at': syft.types.datetime.DateTime,\n", - " 'bucket_name': None}" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.services.blob_storage.get_all()[0].__dict__" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "8648ac13-7091-4dbe-9fab-efcc11b4d4ca", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'syft_node_location': ,\n", - " 'syft_client_verify_key': e296574092d9fe0bbd853b4f0294ca9bc6624ac16c3700da58eff07f69f477f2,\n", - " 'id': ,\n", - " 'location': /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/a3edccc59f384307aeaa1a50714c2300/blob/5e9a51c12f664e39bcc6099981292e3f,\n", - " 'type_': numpy.ndarray,\n", - " 'mimetype': 'bytes',\n", - " 'file_size': 689,\n", - " 'no_lines': 0,\n", - " 'uploaded_by': 119468ba6f6bd124351b4cbce80b12fdf0c443294151fd3ece3833eff825ad1e,\n", - " 'created_at': syft.types.datetime.DateTime,\n", - " 'bucket_name': None}" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "temp_client.services.blob_storage.get_all()[0].__dict__" - ] - }, { "cell_type": "markdown", - "id": "14", + "id": "22", "metadata": {}, "source": [ "## Actions and ActionObjects" @@ -3101,18 +304,20 @@ }, { "cell_type": "code", - "execution_count": 20, - "id": "15", + "execution_count": null, + "id": "23", "metadata": {}, "outputs": [], "source": [ - "migration_action_dict = client.services.migration.get_migration_actionobjects(get_all=True)" + "migration_action_dict = client.services.migration.get_migration_actionobjects(\n", + " get_all=True\n", + ")" ] }, { "cell_type": "code", - "execution_count": 21, - "id": "63c61536-4099-4eb5-ab62-db344a0c49e1", + "execution_count": null, + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -3121,151 +326,48 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "8b1599ea-d940-4415-9f9d-b36056d29e92", + "execution_count": null, + "id": "25", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([15, 16, 17, 18, 19])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ao.syft_action_data_cache" ] }, { "cell_type": "code", - "execution_count": 23, - "id": "a84a06be-be0e-4378-92c4-a08c55a90dfc", + "execution_count": null, + "id": "26", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{: Pointer:\n", - "array([15, 16, 17, 18, 19]), : 
Pointer:\n", - "array([15, 16, 17, 18, 19])}" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "node.python_node.action_store.data" ] }, { "cell_type": "code", - "execution_count": 27, - "id": "9fec28ad-5417-4e94-a26f-3cf9cc5d3412", + "execution_count": null, + "id": "27", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "client.jobs[0].result.id.id" ] }, { "cell_type": "code", - "execution_count": 26, - "id": "7d006054-0248-4901-bb13-6d27e6a0870b", + "execution_count": null, + "id": "28", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'syft_node_location': None,\n", - " 'syft_client_verify_key': None,\n", - " 'id': ,\n", - " 'syft_action_data_cache': array([15, 16, 17, 18, 19]),\n", - " 'syft_blob_storage_entry_id': None,\n", - " 'syft_parent_hashes': None,\n", - " 'syft_parent_op': None,\n", - " 'syft_parent_args': None,\n", - " 'syft_parent_kwargs': None,\n", - " 'syft_history_hash': 1494250481592695163,\n", - " 'syft_node_uid': None,\n", - " 'syft_pre_hooks__': {'ALWAYS': [ 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>,\n", - " 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>],\n", - " 'ON_POINTERS': [ 'Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]'>]},\n", - " 'syft_post_hooks__': {'ALWAYS': [ 'Result[Ok[Any], Err[str]]'>],\n", - " 'ON_POINTERS': []},\n", - " 'syft_twin_type': ,\n", - " 'syft_passthrough_attrs': ['is_mock',\n", - " 'is_real',\n", - " 'is_twin',\n", - " 'is_pointer',\n", - " 'request',\n", - " '__repr__',\n", - " '_repr_markdown_',\n", - " 'syft_twin_type',\n", - " '_repr_debug_',\n", - " 'as_empty',\n", - " 'get',\n", - " 'is_link',\n", - " 'wait',\n", - " '_save_to_blob_storage',\n", - " '_save_to_blob_storage_',\n", - " 'syft_action_data',\n", - " '__check_action_data',\n", - " 'as_empty_data',\n", - " '_set_obj_location_',\n", - " 'syft_action_data_cache',\n", - " 'reload_cache',\n", - " 'syft_resolved',\n", - " 'refresh_object',\n", - " 'syft_action_data_node_id',\n", - " 'node_uid',\n", - " '__sha256__',\n", - " '__hash_exclude_attrs__',\n", - " '__hash__',\n", - " 'create_shareable_sync_copy',\n", - " '_has_private_sync_attrs',\n", - " '__exclude_sync_diff_attrs__',\n", - " '__repr_attrs__'],\n", - " 'syft_action_data_type': numpy.ndarray,\n", - " 'syft_action_data_repr_': 'array([15, 16, 17, 18, 19])',\n", - " 'syft_action_data_str_': '[15 16 17 18 19]',\n", - " 'syft_has_bool_attr': True,\n", - " 'syft_resolve_data': None,\n", - " 'syft_created_at': syft.types.datetime.DateTime,\n", - " 'syft_resolved': True,\n", - " 'syft_action_data_node_id': None,\n", - " 'syft_dont_wrap_attrs': ['dtype', 'shape']}" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "node.python_node.action_store.data[sy.UID(\"d05d03da6ff44a57b1d48611e927a68a\")].__dict__" + "node.python_node.action_store.data[sy.UID(\"106b561961c74a46afc63c5c73c24212\")].__dict__" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "16", + "execution_count": null, + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -3282,62 +384,40 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "17", + "execution_count": null, + "id": "30", "metadata": {}, - "outputs": [ - { - "name": 
"stdout", - "output_type": "stream", - "text": [ - "[Pointer:\n", - "array([15, 16, 17, 18, 19]), Pointer:\n", - "array([15, 16, 17, 18, 19])]\n" - ] - } - ], + "outputs": [], "source": [ "print(migrated_actionobjects)" ] }, { "cell_type": "code", - "execution_count": 32, - "id": "18", + "execution_count": null, + "id": "31", "metadata": {}, "outputs": [], "source": [ - "res = temp_client.services.migration.update_migrated_actionobjects(migrated_actionobjects)" + "res = temp_client.services.migration.update_migrated_actionobjects(\n", + " migrated_actionobjects\n", + ")" ] }, { "cell_type": "code", - "execution_count": 33, - "id": "e2e5064d-2ab8-4d2f-a13e-6215545ea118", + "execution_count": null, + "id": "32", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
SyftSuccess: succesfully migrated actionobjects

" - ], - "text/plain": [ - "SyftSuccess: succesfully migrated actionobjects" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "res" ] }, { "cell_type": "code", - "execution_count": 34, - "id": "19", + "execution_count": null, + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -3346,19 +426,10 @@ }, { "cell_type": "code", - "execution_count": 50, - "id": "2ef24ba9-3781-4e55-9c67-df68465ab080", + "execution_count": null, + "id": "34", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[15 16 17 18 19]\n", - "[15 16 17 18 19]\n" - ] - } - ], + "outputs": [], "source": [ "for uid in temp_node.python_node.action_store.data:\n", " ao = temp_client.services.action.get(uid)\n", @@ -3368,19 +439,10 @@ }, { "cell_type": "code", - "execution_count": 54, - "id": "ff2a5b44-37a7-4b49-a332-ca93ee71b94f", + "execution_count": null, + "id": "35", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[15 16 17 18 19]\n", - "[15 16 17 18 19]\n" - ] - } - ], + "outputs": [], "source": [ "for uid in node.python_node.action_store.data:\n", " ao = client.services.action.get(uid)\n", @@ -3390,7 +452,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "36", "metadata": {}, "source": [ "## Store metadata\n", @@ -3402,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "37", "metadata": {}, "outputs": [], "source": [ @@ -3413,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -3427,7 +489,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "39", "metadata": {}, "outputs": [], "source": [ @@ -3438,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -3448,7 +510,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -3458,15 +520,19 @@ " temp_perms = dict(temp_partition.permissions.items())\n", " real_perms = dict(real_partition.permissions.items())\n", "\n", - " # Only look at migrated items\n", - " temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", - " assert temp_perms == real_perms\n", + " for k, temp_v in temp_perms.items():\n", + " if k not in real_perms:\n", + " continue\n", + " real_v = real_perms[k]\n", + " assert real_v.issubset(temp_v)\n", "\n", " temp_storage = dict(temp_partition.storage_permissions.items())\n", " real_storage = dict(real_partition.storage_permissions.items())\n", - " temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", - "\n", - " assert temp_storage == real_storage\n", + " for k, temp_v in temp_storage.items():\n", + " if k not in real_storage:\n", + " continue\n", + " real_v = real_storage[k]\n", + " assert real_v.issubset(temp_v)\n", "\n", "# Action store\n", "real_partition = node.python_node.action_store\n", @@ -3476,19 +542,25 @@ "\n", "# Only look at migrated items\n", "temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", - "assert temp_perms == real_perms\n", + "for k, temp_v in temp_perms.items():\n", + " if k not in real_perms:\n", + " continue\n", + " real_v = real_perms[k]\n", + " assert real_v.issubset(temp_v)\n", "\n", "temp_storage = dict(temp_partition.storage_permissions.items())\n", "real_storage = 
dict(real_partition.storage_permissions.items())\n", - "temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", - "\n", - "assert temp_storage == real_storage" + "for k, temp_v in temp_storage.items():\n", + " if k not in real_storage:\n", + " continue\n", + " real_v = real_storage[k]\n", + " assert real_v.issubset(temp_v)" ] }, { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "42", "metadata": {}, "outputs": [], "source": [] diff --git a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb index f0d2c3546bb..010c021b6db 100644 --- a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb +++ b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb @@ -19,7 +19,7 @@ "outputs": [], "source": [ "node = sy.orchestra.launch(\n", - " name=\"test_upgradability\",\n", + " name=\"temp_node\",\n", " dev_mode=True,\n", " local_db=True,\n", " n_consumers=2,\n", @@ -66,6 +66,85 @@ "id": "5", "metadata": {}, "outputs": [], + "source": [ + "client.verify_key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "from syft.client.api import APIRegistry" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "APIRegistry.__api_registry__.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "node.python_node.id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "code = client.code.get_all()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "code.status_link.node_uid = node.python_node.id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "code.status" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], "source": [ "req1 = client.requests[0]\n", "req2 = client_ds.requests[0]\n", @@ -77,7 +156,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -88,7 +167,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -99,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -109,7 +188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -122,7 +201,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -142,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -152,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -163,7 +242,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "14", + "id": "22", "metadata": {}, "outputs": [], "source": [] diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 58325e6fded..2122bcb1b77 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -6,8 +6,8 @@ from pathlib import Path import re from string import Template -from typing import Any import traceback +from typing import Any from typing import TYPE_CHECKING from typing import cast @@ -17,9 +17,9 @@ from tqdm import tqdm # relative +from .. import deserialize +from .. import serialize from ..abstract_node import NodeSideType -from ..serde import deserialize -from ..serde import serialize from ..serde.serializable import serializable from ..service.action.action_object import ActionObject from ..service.code_history.code_history import CodeHistoriesDict diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index aef42aac847..424505b7a28 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -70,8 +70,8 @@ from ..service.job.job_stash import JobType from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService -from ..service.migration.migration_service import MigrationService from ..service.metadata.node_metadata import NodeMetadata +from ..service.migration.migration_service import MigrationService from ..service.network.network_service import NetworkService from ..service.network.utils import PeerHealthCheckTask from ..service.notification.notification_service import NotificationService @@ -128,9 +128,9 @@ from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig -from ..types.syft_object import Context from ..types.datetime import DATETIME_FORMAT from ..types.syft_metaclass import Empty +from ..types.syft_object import Context from ..types.syft_object import PartialSyftObject from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index cf108ea6145..de8a637055c 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -166,7 +166,7 @@ def deploy_to_python( queue_port: int | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, - migrate: bool = True + migrate: bool = True, ) -> NodeHandle: worker_classes = { NodeType.DOMAIN: Domain, @@ -194,7 +194,7 @@ def deploy_to_python( "create_producer": create_producer, "association_request_auto_approval": association_request_auto_approval, "background_tasks": background_tasks, - "migrate": migrate + "migrate": migrate, } if port: @@ -248,7 +248,7 @@ def deploy_to_remote( deployment_type_enum: DeploymentType, name: str, node_side_type: NodeSideType, - migrate: bool = False + migrate: bool = False, ) -> NodeHandle: node_port = int(os.environ.get("NODE_PORT", f"{DEFAULT_PORT}")) node_url = str(os.environ.get("NODE_URL", f"{DEFAULT_URL}")) @@ -287,7 +287,7 @@ def launch( queue_port: int | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, - migrate: bool = True + migrate: bool = True, ) -> NodeHandle: if dev_mode is True: thread_workers = True @@ -324,7 +324,7 @@ def launch( queue_port=queue_port, 
association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, - migrate=migrate + migrate=migrate, ) elif deployment_type_enum == DeploymentType.REMOTE: return deploy_to_remote( @@ -332,7 +332,7 @@ def launch( deployment_type_enum=deployment_type_enum, name=name, node_side_type=node_side_type_enum, - migrate=migrate + migrate=migrate, ) raise NotImplementedError( f"deployment_type: {deployment_type_enum} is not supported" diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 82dca0e53c6..9a81261e5f3 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -183,14 +183,14 @@ def build_state(self, stop_key: str | None = None) -> dict: hash_str = object_metadata["hash"] state_versions = state_dict[canonical_name] state_version_hashes = [val[0] for val in state_versions.values()] - if action == "add" and ( - str(version) in state_versions.keys() - or hash_str in state_version_hashes - ): - raise Exception( - f"Can't add {object_metadata} already in state {versions}" - ) - elif action == "remove" and ( + # if action == "add" and ( + # str(version) in state_versions.keys() + # or hash_str in state_version_hashes + # ): + # raise Exception( + # f"Can't add {object_metadata} already in state {versions}" + # ) + if action == "remove" and ( str(version) not in state_versions.keys() and hash_str not in state_version_hashes ): diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index 0df0c7a110d..d8ef20e1b78 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -3,18 +3,18 @@ import json from typing import Any -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default - # relative from ...client.api import APIRegistry from ...client.enclave_client import EnclaveMetadata from ...serde.serializable import serializable from ...service.user.user_roles import ServiceRole +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.syft_object import SyftVerifyKey +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID from ...util.notebook_ui.components.tabulator_template import ( build_tabulator_table_with_data, diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 761141a6c36..536dc7289b8 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -67,11 +67,7 @@ def get(self, context: AuthedServiceContext, uid: UID) -> Job | SyftError: res = res.ok() return res - @service_method( - path="job.get_all", - name="get_all", - roles=DATA_SCIENTIST_ROLE_LEVEL - ) + @service_method(path="job.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> list[Job] | SyftError: res = self.stash.get_all(context.credentials) if res.is_err(): diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index d84db3b2361..501777889f3 100644 --- 
a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,11 +1,6 @@ -# stdlib - -# stdlib - -# stdlib - # stdlib from collections import defaultdict +import sys from typing import cast # third party @@ -97,7 +92,7 @@ def register_migration_state( def _find_klasses_pending_for_migration( self, context: AuthedServiceContext, object_types: list[type[SyftObject]] - ) -> list[SyftObject]: + ) -> list[type[SyftObject]]: klasses_to_be_migrated = [] for object_type in object_types: @@ -153,7 +148,7 @@ def _get_partition_from_type( if issubclass(object_type, ActionObject): object_partition = cast(KeyValueActionStore, context.node.action_store) else: - canonical_name = object_type.__canonical_name__ + canonical_name = object_type.__canonical_name__ # type: ignore[unreachable] object_partition = self.store.partitions.get(canonical_name) if object_partition is None: @@ -489,12 +484,9 @@ def migrate_data( return SyftError(message=objects_update_update_result.value) # now action objects - migration_actionobjects_result: dict[type[SyftObject], list[SyftObject]] = ( - self._get_migration_actionobjects(context) - ) - + migration_actionobjects_result = self._get_migration_actionobjects(context) if migration_actionobjects_result.is_err(): - return migration_actionobjects_result + return SyftError(message=migration_actionobjects_result.err()) migration_actionobjects = migration_actionobjects_result.ok() migrated_actionobjects = [] diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 05154aa113e..e1947449003 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -1,6 +1,6 @@ # stdlib -import traceback import logging +import traceback # third party from pydantic import EmailStr diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 2908454096e..e591d950476 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -32,6 +32,8 @@ def __getattr__(self, name: str) -> Any: # '_repr_html_', "_ipython_canary_method_should_not_exist_", "_ipython_display_", + "__canonical_name__", + "__version__", ] or name.startswith("_repr"): return super().__getattr__(name) display(self) diff --git a/tox.ini b/tox.ini index 7016091d792..bad928ce7e8 100644 --- a/tox.ini +++ b/tox.ini @@ -1065,7 +1065,7 @@ commands = [testenv:migration.prepare] description = Migration Test -deps = +deps = syft nbmake allowlist_externals = @@ -1073,7 +1073,7 @@ allowlist_externals = pytest ; changedir ; setenv -commands = +commands = bash -c 'python -c "import syft as sy; print(sy.__version__)"' pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/0-prepare-migration-data.ipynb -vvvv @@ -1088,7 +1088,7 @@ allowlist_externals = bash tox pytest -commands = +commands = bash -c 'python -c "import syft as sy; print(sy.__version__)"' tox -e migration.prepare pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/1-connect-and-migrate.ipynb -vvvv From 230f08f7c3d7e672a98a095872db38d4f4f64dac Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Tue, 2 Jul 2024 11:19:26 -0300 Subject: [PATCH 210/309] Add back `unsafe_function` as deprecated. 
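The 0.8 API notebooks (and possibly user code written against them) still call
`UserCode.unsafe_function`, which went away when the accessor was renamed to
`run`. Restore it as a deprecated alias that simply forwards to `run`, so older
call sites keep working while being nudged toward the new name. A rough sketch of
the intended call-site behaviour (names taken from the 0.8 tutorial notebooks;
the exact warning text depends on the `deprecated` decorator):

    func = request.code
    users_function = func.unsafe_function  # deprecated alias, same object as func.run
    result = users_function(trade_data=mock_data)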
--- .../api/0.9/02-review-code-and-approve.ipynb | 522 -------------- notebooks/api/0.9/05-custom-policy.ipynb | 647 ------------------ .../syft/src/syft/service/code/user_code.py | 6 + 3 files changed, 6 insertions(+), 1169 deletions(-) delete mode 100644 notebooks/api/0.9/02-review-code-and-approve.ipynb delete mode 100644 notebooks/api/0.9/05-custom-policy.ipynb diff --git a/notebooks/api/0.9/02-review-code-and-approve.ipynb b/notebooks/api/0.9/02-review-code-and-approve.ipynb deleted file mode 100644 index 24956e4a48b..00000000000 --- a/notebooks/api/0.9/02-review-code-and-approve.ipynb +++ /dev/null @@ -1,522 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Reviewing and Approving Code in Syft as a Data Owner" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "SYFT_VERSION = \">=0.8.2.b0,<0.9\"\n", - "package_string = f'\"syft{SYFT_VERSION}\"'\n", - "# %pip install {package_string} -q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# syft absolute\n", - "import syft as sy\n", - "\n", - "sy.requires(SYFT_VERSION)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Login to Syft Domain Server" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Launch and connect to test-domain-1 server we setup in the previous notebook\n", - "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Log into the node with default root credentials\n", - "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Selecting Project in the Syft Domain Server\n", - "\n", - "Let's see all the projects that are created by Data Scientists in this Domain Server" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "domain_client.projects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Select the project you want to work with\n", - "project = domain_client.projects[0]\n", - "project" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "All code requests submitted by the Data Scientists as a part of this project can be accessed by invoking the following" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "project.requests" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Tests\n", - "assert len(project.events) == 1\n", - "assert isinstance(project.events[0], sy.service.project.project.ProjectRequest)\n", - "assert len(project.requests) == 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reviewing Code Requests\n", - "\n", - "To review a specific request, we can select it and explore its attributes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - 
"outputs": [], - "source": [ - "request = project.requests[0]\n", - "request" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# See the code written by the Data Scientist and its metadata in the request\n", - "func = request.code\n", - "func" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# To see just the code\n", - "func.show_code" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Reference to the assets that the function will run on\n", - "func.assets" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Viewing the Asset and it's mock/private variants that the Data Scientist will be running on" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "asset = func.assets[0]\n", - "asset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "mock_data = asset.mock\n", - "mock_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Private data. Accessible as we are logged in with Data Owner credentials\n", - "pvt_data = asset.data\n", - "pvt_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Tests\n", - "assert len(asset.data_subjects) == 1\n", - "assert mock_data.shape == (10, 22)\n", - "assert pvt_data.shape == (10, 22)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Policies in Syft Function\n", - "\n", - "Each Syft Function requires an Input & Output policy attached to the python function against which executions are verified.\n", - "\n", - "Syft provides the following default policies:\n", - "* `sy.ExactMatch()` Input policy ensures that function executes against the exact inputs specified by Data Scientist.\n", - "* `sy.OutputPolicyExecuteOnce()` Output policy makes sure that the Data Scientist can run the function only once against the input.\n", - "\n", - "We can also implement custom policies based on our requirements. 
(Refer to notebook [05-custom-policy](./05-custom-policy.ipynb) for more information.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "op = func.output_policy_type\n", - "op" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# See the implementation of the policy\n", - "print(op.policy_code)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Policies provided by Syft are available before approving the code,\n", - "# Custom policies are only safe to use once the code is approved.\n", - "\n", - "assert func.output_policy is not None\n", - "assert func.input_policy is not None\n", - "\n", - "func.output_policy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Execute the Data Scientist's code\n", - "\n", - "While Syft makes sure that the function is not tampered with, it does not perform any validation on the implementation itself.\n", - "\n", - "**It is the Data Owner's responsibility to review the code & verify if it's safe to execute.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Let's grab the actual executable function that was submitted by the user\n", - "users_function = func.run" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If the code looks safe, we can go ahead and execute it on the private dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mock_result = users_function(trade_data=mock_data)\n", - "mock_result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "real_result = users_function(trade_data=pvt_data)\n", - "real_result" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Approving a request\n", - "\n", - "By calling `request.approve()`, the data scientist can execute their function on the real data, and obtain the result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Uploaded wrong result - we shared mock_result instead of the real_result\n", - "result = request.approve()\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "assert isinstance(result, sy.SyftSuccess)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Denying a request\n", - "\n", - "At times you would want to deny a request in cases where the output is violating privacy, or if either of the policy is too lineant, or perhaps the code is confusing!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Deny the request with an appropriate reason\n", - "result = request.deny(\n", - " reason=(\n", - " \"The Submitted UserCode does not add differential privacy to the output.\"\n", - " \"Kindly add differential privacy and resubmit the code.\"\n", - " )\n", - ")\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(result, sy.SyftSuccess)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We can verify the status by checking our request list\n", - "project.requests" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Re-approving requests\n", - "\n", - "Let's re-approve the request so that we can work with the results in the later notebooks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "result = request.approve()\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Verify the request status again\n", - "project.requests" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Cleanup local domain server\n", - "\n", - "if node.node_type.value == \"python\":\n", - " node.land()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that the code request has been approved, let's go through the [03-data-scientist-download-result](./03-data-scientist-download-result.ipynb) notebook to see how a Data Scientist can access the results." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/api/0.9/05-custom-policy.ipynb b/notebooks/api/0.9/05-custom-policy.ipynb deleted file mode 100644 index ee6d44425ed..00000000000 --- a/notebooks/api/0.9/05-custom-policy.ipynb +++ /dev/null @@ -1,647 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "SYFT_VERSION = \">=0.8.2.b0,<0.9\"\n", - "package_string = f'\"syft{SYFT_VERSION}\"'\n", - "# %pip install {package_string} -q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# third party\n", - "import numpy as np\n", - "\n", - "# syft absolute\n", - "import syft as sy\n", - "\n", - "sy.requires(SYFT_VERSION)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True, reset=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "domain_client.register(\n", - " email=\"newuser@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "client_low_ds = node.login(email=\"newuser@openmined.org\", password=\"pw\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# stdlib\n", - "from typing import Any" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class RepeatedCallPolicy(sy.CustomOutputPolicy):\n", - " n_calls: int = 0\n", - " downloadable_output_args: list[str] = []\n", - " state: dict[Any, Any] = {}\n", - "\n", - " def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):\n", - " self.downloadable_output_args = (\n", - " downloadable_output_args if downloadable_output_args is not None else []\n", - " )\n", - " self.n_calls = n_calls\n", - " self.state = {\"counts\": 0}\n", - "\n", - " def public_state(self):\n", - " return self.state[\"counts\"]\n", - "\n", - " def update_policy(self, context, outputs):\n", - " self.state[\"counts\"] += 1\n", - "\n", - " def apply_to_output(self, context, outputs, update_policy=True):\n", - " if hasattr(outputs, \"syft_action_data\"):\n", 
- " outputs = outputs.syft_action_data\n", - " output_dict = {}\n", - " if self.state[\"counts\"] < self.n_calls:\n", - " for output_arg in self.downloadable_output_args:\n", - " output_dict[output_arg] = outputs[output_arg]\n", - " if update_policy:\n", - " self.update_policy(context, outputs)\n", - " else:\n", - " return None\n", - " return output_dict\n", - "\n", - " def _is_valid(self, context):\n", - " return self.state[\"counts\"] < self.n_calls" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "policy = RepeatedCallPolicy(n_calls=1, downloadable_output_args=[\"y\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "policy.n_calls" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "policy.downloadable_output_args" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "policy.init_kwargs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "print(policy.init_kwargs)\n", - "a_obj = sy.ActionObject.from_obj({\"y\": [1, 2, 3]})\n", - "x = policy.apply_to_output(None, a_obj)\n", - "x[\"y\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "policy.n_calls" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "x = np.array([1, 2, 3])\n", - "x_pointer = sy.ActionObject.from_obj(x)\n", - "x_pointer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "x_pointer = x_pointer.send(domain_client)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "# third party\n", - "from result import Err\n", - "from result import Ok\n", - "\n", - "# syft absolute\n", - "from syft.client.api import AuthedServiceContext\n", - "from syft.client.api import NodeIdentity\n", - "\n", - "\n", - "class CustomExactMatch(sy.CustomInputPolicy):\n", - " def __init__(self, *args: Any, **kwargs: Any) -> None:\n", - " pass\n", - "\n", - " def filter_kwargs(self, kwargs, context, code_item_id):\n", - " # stdlib\n", - "\n", - " try:\n", - " allowed_inputs = self.allowed_ids_only(\n", - " allowed_inputs=self.inputs, kwargs=kwargs, context=context\n", - " )\n", - " results = self.retrieve_from_db(\n", - " code_item_id=code_item_id,\n", - " allowed_inputs=allowed_inputs,\n", - " context=context,\n", - " )\n", - " except Exception as e:\n", - " return Err(str(e))\n", - " return results\n", - "\n", - " def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n", - " # syft absolute\n", - " from syft import NodeType\n", - " from syft.service.action.action_object import TwinMode\n", - "\n", - " action_service = context.node.get_service(\"actionservice\")\n", - " code_inputs = {}\n", - "\n", - " # When we are retrieving the code from the database, we need to use the node's\n", - " # verify key as the credentials. 
This is because when we approve the code, we\n", - " # we allow the private data to be used only for this specific code.\n", - " # but we are not modifying the permissions of the private data\n", - "\n", - " root_context = AuthedServiceContext(\n", - " node=context.node, credentials=context.node.verify_key\n", - " )\n", - " if context.node.node_type == NodeType.DOMAIN:\n", - " for var_name, arg_id in allowed_inputs.items():\n", - " kwarg_value = action_service._get(\n", - " context=root_context,\n", - " uid=arg_id,\n", - " twin_mode=TwinMode.NONE,\n", - " has_permission=True,\n", - " )\n", - " if kwarg_value.is_err():\n", - " return Err(kwarg_value.err())\n", - " code_inputs[var_name] = kwarg_value.ok()\n", - " else:\n", - " raise Exception(\n", - " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", - " )\n", - " return Ok(code_inputs)\n", - "\n", - " def allowed_ids_only(\n", - " self,\n", - " allowed_inputs,\n", - " kwargs,\n", - " context,\n", - " ):\n", - " # syft absolute\n", - " from syft import NodeType\n", - " from syft import UID\n", - "\n", - " if context.node.node_type == NodeType.DOMAIN:\n", - " node_identity = NodeIdentity(\n", - " node_name=context.node.name,\n", - " node_id=context.node.id,\n", - " verify_key=context.node.signing_key.verify_key,\n", - " )\n", - " allowed_inputs = allowed_inputs.get(node_identity, {})\n", - " else:\n", - " raise Exception(\n", - " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", - " )\n", - " filtered_kwargs = {}\n", - " for key in allowed_inputs.keys():\n", - " if key in kwargs:\n", - " value = kwargs[key]\n", - " uid = value\n", - " if not isinstance(uid, UID):\n", - " uid = getattr(value, \"id\", None)\n", - "\n", - " if uid != allowed_inputs[key]:\n", - " raise Exception(\n", - " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", - " )\n", - " filtered_kwargs[key] = value\n", - " return filtered_kwargs\n", - "\n", - " def _is_valid(\n", - " self,\n", - " context,\n", - " usr_input_kwargs,\n", - " code_item_id,\n", - " ):\n", - " filtered_input_kwargs = self.filter_kwargs(\n", - " kwargs=usr_input_kwargs,\n", - " context=context,\n", - " code_item_id=code_item_id,\n", - " )\n", - "\n", - " if filtered_input_kwargs.is_err():\n", - " return filtered_input_kwargs\n", - "\n", - " filtered_input_kwargs = filtered_input_kwargs.ok()\n", - "\n", - " expected_input_kwargs = set()\n", - " for _inp_kwargs in self.inputs.values():\n", - " for k in _inp_kwargs.keys():\n", - " if k not in usr_input_kwargs:\n", - " return Err(f\"Function missing required keyword argument: '{k}'\")\n", - " expected_input_kwargs.update(_inp_kwargs.keys())\n", - "\n", - " permitted_input_kwargs = list(filtered_input_kwargs.keys())\n", - " not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)\n", - " if len(not_approved_kwargs) > 0:\n", - " return Err(\n", - " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n", - " )\n", - " return Ok(True)\n", - "\n", - "\n", - "def allowed_ids_only(\n", - " self,\n", - " allowed_inputs,\n", - " kwargs,\n", - " context,\n", - "):\n", - " # syft absolute\n", - " from syft import NodeType\n", - " from syft import UID\n", - " from syft.client.api import NodeIdentity\n", - "\n", - " if context.node.node_type == NodeType.DOMAIN:\n", - " node_identity = NodeIdentity(\n", - " node_name=context.node.name,\n", - " node_id=context.node.id,\n", - " verify_key=context.node.signing_key.verify_key,\n", - " )\n", - " 
allowed_inputs = allowed_inputs.get(node_identity, {})\n", - " else:\n", - " raise Exception(\n", - " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", - " )\n", - " filtered_kwargs = {}\n", - " for key in allowed_inputs.keys():\n", - " if key in kwargs:\n", - " value = kwargs[key]\n", - " uid = value\n", - " if not isinstance(uid, UID):\n", - " uid = getattr(value, \"id\", None)\n", - "\n", - " if uid != allowed_inputs[key]:\n", - " raise Exception(\n", - " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", - " )\n", - " filtered_kwargs[key] = value\n", - " return filtered_kwargs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "@sy.syft_function(\n", - " input_policy=CustomExactMatch(x=x_pointer),\n", - " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=[\"y\"]),\n", - ")\n", - "def func(x):\n", - " return {\"y\": x + 1}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "request = client_low_ds.code.request_code_execution(func)\n", - "request" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "request_id = request.id" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "client_low_ds.code.get_all()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "for request in domain_client.requests:\n", - " if request.id == request_id:\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "func = request.code" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "# Custom policies need to be approved before they can be viewed and used\n", - "assert func.input_policy is None\n", - "assert func.output_policy is None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "result = func.run(x=x_pointer)\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "request.approve()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "assert func.input_policy is not None\n", - "assert func.output_policy is not None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "res_ptr = client_low_ds.code.func(x=x_pointer)\n", - "res_ptr" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": {}, - "outputs": [], - "source": [ - "res = res_ptr.get()\n", - "res" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "assert (res[\"y\"] == np.array([2, 3, 4])).all()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "assert set(res.keys()) == set(\"y\")" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "31", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "for code in domain_client.code.get_all():\n", - " if code.service_func_name == \"func\":\n", - " break\n", - "print(code.output_policy.state)\n", - "assert code.output_policy.state == {\"counts\": 1}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "if node.node_type.value == \"python\":\n", - " node.land()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 1dea190519b..7f767c59e59 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -68,6 +68,7 @@ from ...types.uid import UID from ...util import options from ...util.colors import SURFACE +from ...util.decorators import deprecated from ...util.markdown import CodeMarkdown from ...util.markdown import as_markdown_code from ...util.util import prompt_warning_message @@ -833,6 +834,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable | SyftError: return wrapper + @property + @deprecated(details="Use 'run' instead") + def unsafe_function(self) -> Callable | None: + return self.run + def _inner_repr(self, level: int = 0) -> str: shared_with_line = "" if len(self.output_readers) > 0 and self.output_reader_names is not None: From 3e77c05c1fe657cb9e6e8130b2c8f2df5e396740 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Tue, 2 Jul 2024 11:36:01 -0300 Subject: [PATCH 211/309] Fix incorrect `@deprecated` argument name --- packages/syft/src/syft/service/code/user_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 7f767c59e59..930bb5cfe32 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -835,7 +835,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable | SyftError: return wrapper @property - @deprecated(details="Use 'run' instead") + @deprecated(reason="Use 'run' instead") def unsafe_function(self) -> Callable | None: return self.run From efdc361ef8d4bb93b460c505d11bfa0edebbc2a5 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 16:40:59 +0200 Subject: [PATCH 212/309] styling --- .../syft/src/syft/service/code/user_code.py | 19 ++++++++++------- .../syft/src/syft/service/dataset/dataset.py | 21 +++++++++---------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 50b908803f0..07a333d4e07 100644 --- 
a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -50,6 +50,7 @@ from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime +from ...types.dicttuple import DictTuple from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 @@ -744,7 +745,7 @@ def byte_code(self) -> PyCodeObject | None: return compile_byte_code(self.parsed_code) @property - def assets(self) -> list[Asset] | SyftError: + def assets(self) -> DictTuple[str, Asset] | SyftError: if not self.input_policy: return [] @@ -757,7 +758,7 @@ def assets(self) -> list[Asset] | SyftError: if isinstance(datasets, SyftError): return datasets - all_assets = {} + all_assets: dict[UID, Asset] = {} for dataset in datasets: for asset in dataset.asset_list: asset._dataset_name = dataset.name @@ -770,22 +771,24 @@ def assets(self) -> list[Asset] | SyftError: all_inputs.update(vals) # map the action_id to the asset - used_assets = [] + used_assets: list[Asset] = [] for kwarg_name, action_id in all_inputs.items(): asset = all_assets.get(action_id, None) asset._kwarg_name = kwarg_name used_assets.append(asset) - return used_assets + + asset_dict = {asset._kwarg_name: asset for asset in used_assets} + return DictTuple(asset_dict) @property def _asset_json(self) -> str | SyftError: if isinstance(self.assets, SyftError): return self.assets - used_assets = { - asset._kwarg_name: asset._get_dict_for_user_code_repr() - for asset in self.assets + asset_dict = { + argument: asset._get_dict_for_user_code_repr() + for argument, asset in self.assets.items() } - asset_str = json.dumps(used_assets, indent=2) + asset_str = json.dumps(asset_dict, indent=2) return asset_str def get_sync_dependencies( diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index a8bfdfc62cc..b41b4756590 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -197,26 +197,25 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: def _coll_repr_(self) -> dict[str, Any]: base_dict = { - "Asset Name": self.name, + "Parameter": self._kwarg_name, "Action ID": self.action_id, + "Asset Name": self.name, + "Dataset Name": self._dataset_name, "Node UID": self.node_uid, } # _kwarg_name and _dataset_name are set by the UserCode.assets - if self._kwarg_name and self._dataset_name: - base_dict.update( - { - "Parameter": self._kwarg_name, - "Dataset Name": self._dataset_name, - } - ) - return base_dict + # if they are None, we remove them from the dict + filtered_dict = { + key: value for key, value in base_dict.items() if value is not None + } + return filtered_dict def _get_dict_for_user_code_repr(self) -> dict[str, Any]: return { - "source_dataset": self._dataset_name, - "source_asset": self.name, "action_id": self.action_id.no_dash, + "source_asset": self.name, + "source_dataset": self._dataset_name, "source_node": self.node_uid.no_dash, } From 552a15f9bbd7de74c86a30898c71680a9eb07318 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 2 Jul 2024 16:48:47 +0200 Subject: [PATCH 213/309] uncomment has_result_read_permission --- packages/syft/src/syft/service/action/action_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_service.py 
b/packages/syft/src/syft/service/action/action_service.py index 10b864732bb..7506c464e72 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -512,7 +512,7 @@ def set_result_to_store( result_blob_id = result_action_object.syft_blob_storage_entry_id # type: ignore[unreachable] # pass permission information to the action store as extra kwargs - # context.extra_kwargs = {"has_result_read_permission": True} + context.extra_kwargs = {"has_result_read_permission": True} # Since this just meta data about the result, they always have access to it. set_result = self._set( From fe921ce2de117a700d71dd8abdb1416e496bb126 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Tue, 2 Jul 2024 11:59:31 -0300 Subject: [PATCH 214/309] Update `api/0.8` notebooks to use `run` instead of `unsafe_function` --- notebooks/api/0.8/02-review-code-and-approve.ipynb | 4 ++-- notebooks/api/0.8/05-custom-policy.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/api/0.8/02-review-code-and-approve.ipynb b/notebooks/api/0.8/02-review-code-and-approve.ipynb index 9612144b952..cd84910030b 100644 --- a/notebooks/api/0.8/02-review-code-and-approve.ipynb +++ b/notebooks/api/0.8/02-review-code-and-approve.ipynb @@ -325,7 +325,7 @@ "outputs": [], "source": [ "# Let's grab the actual executable function that was submitted by the user\n", - "users_function = func.unsafe_function" + "users_function = func.run" ] }, { @@ -501,7 +501,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.4" }, "toc": { "base_numbering": 1, diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index 8c7b18c9328..ddd043c6f43 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -507,7 +507,7 @@ }, "outputs": [], "source": [ - "result = func.unsafe_function(x=x_pointer)\n", + "result = func.run(x=x_pointer)\n", "result" ] }, @@ -626,7 +626,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.4" }, "toc": { "base_numbering": 1, From efb20616ade464ceee573989a1c8df65beaef6af Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 2 Jul 2024 18:06:31 +0200 Subject: [PATCH 215/309] migrate to new client --- .../1c-migrate-to-new-node.ipynb | 143 ++++++++++++----- packages/syft/src/syft/__init__.py | 1 + .../syft/src/syft/client/domain_client.py | 36 +---- packages/syft/src/syft/client/migrations.py | 69 +++++++++ packages/syft/src/syft/node/node.py | 3 +- .../src/syft/protocol/protocol_version.json | 9 +- .../service/migration/migration_service.py | 146 +++++++++++++----- .../migration/object_migration_state.py | 14 +- 8 files changed, 317 insertions(+), 104 deletions(-) create mode 100644 packages/syft/src/syft/client/migrations.py diff --git a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb index bac198022da..798d8630081 100644 --- a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb +++ b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb @@ -37,7 +37,11 @@ "- [x] check SyftObjectRegistry and compare with current implementation\n", "- [x] run unit tests\n", "- [ ] finalize notebooks for testing, run in CI\n", - "- [ ] other tasks defined in tickets" + "- [ ] other tasks defined in tickets\n", + "- 
[ ] also get actionobjects in get_migration_objects\n", + "- [ ] make clientside method to migrate blobstorage `migrate_blobstorage(from_client, to_client, migration_data)`\n", + "- [ ] make domainclient get_migration and apply_migration (to/from file?)\n", + " - merge with save_migration_objects_to_file / migrate_objects_from_file " ] }, { @@ -96,11 +100,13 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "7", "metadata": {}, + "outputs": [], "source": [ - "## document store objects" + "migration_data = client.services.migration.get_migration_data()" ] }, { @@ -110,7 +116,8 @@ "metadata": {}, "outputs": [], "source": [ - "migration_dict = client.services.migration.get_migration_objects(get_all=True)" + "# syft absolute\n", + "from syft.client.migrations import migrate_blob_storage" ] }, { @@ -120,7 +127,7 @@ "metadata": {}, "outputs": [], "source": [ - "migration_dict" + "sy.migrate(from_client=client, to_client=temp_client)" ] }, { @@ -129,6 +136,68 @@ "id": "10", "metadata": {}, "outputs": [], + "source": [ + "migrate_blob_storage(\n", + " from_client=client,\n", + " to_client=temp_client,\n", + " blob_storage_objects=migration_data.blob_storage_objects,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "migration_data.store_objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "temp_client.api.services.migration.apply_migration_data(migration_data)" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "## document store objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "migration_dict = client.services.migration.get_migration_objects(get_all=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "migration_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], "source": [ "def custom_migration_function(context, obj: SyftObject, klass) -> SyftObject:\n", " # Here, we are just doing the same, but this is where you would write your custom logic\n", @@ -138,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -168,7 +237,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -190,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -200,7 +269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -209,7 +278,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "22", "metadata": {}, "source": [ "# Migrate blobstorage" @@ -218,7 +287,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "23", "metadata": {}, "outputs": [], "source": [] @@ -226,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -239,7 +308,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + 
"id": "25", "metadata": {}, "outputs": [], "source": [ @@ -275,7 +344,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -287,7 +356,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -296,7 +365,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "28", "metadata": {}, "source": [ "## Actions and ActionObjects" @@ -305,7 +374,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -317,7 +386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -327,7 +396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -337,7 +406,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +416,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -357,7 +426,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -367,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -385,7 +454,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -395,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "37", "metadata": {}, "outputs": [], "source": [ @@ -407,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -417,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "39", "metadata": {}, "outputs": [], "source": [ @@ -427,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -440,7 +509,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +521,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "42", "metadata": {}, "source": [ "## Store metadata\n", @@ -464,7 +533,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "43", "metadata": {}, "outputs": [], "source": [ @@ -475,7 +544,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +558,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "45", "metadata": {}, "outputs": [], "source": [ @@ -500,7 +569,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -510,7 +579,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "47", "metadata": {}, "outputs": [], "source": [ @@ -560,7 +629,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42", + "id": "48", "metadata": {}, "outputs": [], "source": [] diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 23ac10fd52f..086e759eb30 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -16,6 +16,7 @@ from .client.client import register from .client.domain_client import DomainClient from 
.client.gateway_client import GatewayClient +from .client.migrations import migrate from .client.registry import DomainRegistry from .client.registry import EnclaveRegistry from .client.registry import NetworkRegistry diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 2122bcb1b77..6b1d96769a1 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -7,7 +7,6 @@ import re from string import Template import traceback -from typing import Any from typing import TYPE_CHECKING from typing import cast @@ -17,8 +16,6 @@ from tqdm import tqdm # relative -from .. import deserialize -from .. import serialize from ..abstract_node import NodeSideType from ..serde.serializable import serializable from ..service.action.action_object import ActionObject @@ -27,6 +24,7 @@ from ..service.dataset.dataset import Contributor from ..service.dataset.dataset import CreateAsset from ..service.dataset.dataset import CreateDataset +from ..service.migration.object_migration_state import MigrationData from ..service.response import SyftError from ..service.response import SyftSuccess from ..service.sync.diff_state import ResolvedSyncState @@ -34,7 +32,6 @@ from ..service.user.roles import Roles from ..service.user.user import UserView from ..types.blob_storage import BlobFile -from ..types.syft_object import Context from ..types.uid import UID from ..util.misc_objs import HTMLObject from ..util.util import get_mb_size @@ -399,31 +396,12 @@ def code_status(self) -> APIModule | None: def output(self) -> APIModule | None: return self._get_service_by_name_if_exists("output") - def save_migration_objects_to_file( - self, filename: str, get_all: bool = False - ) -> dict[Any, Any] | SyftError: - migration_dict = self.api.services.migration.get_migration_objects( - get_all=get_all - ) - if isinstance(migration_dict, SyftError): - return migration_dict - ser_bytes = serialize(migration_dict, to_bytes=True) - with open(filename, "wb") as f: - f.write(ser_bytes) - return migration_dict - - def migrate_objects_from_file(self, filename: str) -> SyftSuccess | SyftError: - with open(filename, "rb") as f: - ser_bytes = f.read() - migration_dict = deserialize(ser_bytes, from_bytes=True) - context = Context() - migrated_objects = [] - for klass, objects in migration_dict.items(): - for obj in objects: - migrated_obj = obj.migrate_to(klass.__version__, context) - migrated_objects.append(migrated_obj) - res = self.api.services.migration.update_migrated_objects(migrated_objects) - return res + @property + def migration(self) -> APIModule | None: + return self._get_service_by_name_if_exists("migration") + + def get_migration_data(self) -> MigrationData | SyftError: + return self.api.services.migration.get_migration_data() def get_project( self, diff --git a/packages/syft/src/syft/client/migrations.py b/packages/syft/src/syft/client/migrations.py new file mode 100644 index 00000000000..2deaeda2508 --- /dev/null +++ b/packages/syft/src/syft/client/migrations.py @@ -0,0 +1,69 @@ +# stdlib +from io import BytesIO +import sys + +# relative +from ..serde.serialize import _serialize +from ..service.response import SyftError +from ..service.response import SyftSuccess +from ..types.blob_storage import BlobStorageEntry +from ..types.blob_storage import CreateBlobStorageEntry +from ..types.syft_object import Context +from ..types.syft_object import SyftObject +from .domain_client import DomainClient + + +def 
migrate_blob_storage_object( + from_client: DomainClient, + to_client: DomainClient, + obj: SyftObject, +) -> SyftSuccess | SyftError: + migrated_obj = obj.migrate_to(BlobStorageEntry.__version__, Context()) + uploaded_by = migrated_obj.uploaded_by + blob_retrieval = from_client.services.blob_storage.read(obj.id) + if isinstance(blob_retrieval, SyftError): + return blob_retrieval + + data = blob_retrieval.read() + serialized = _serialize(data, to_bytes=True) + size = sys.getsizeof(serialized) + blob_create = CreateBlobStorageEntry.from_blob_storage_entry(obj) + blob_create.file_size = size + + blob_deposit_object = to_client.services.blob_storage.allocate_for_user( + blob_create, uploaded_by + ) + if isinstance(blob_deposit_object, SyftError): + return blob_deposit_object + return blob_deposit_object.write(BytesIO(serialized)) + + +def migrate_blob_storage( + from_client: DomainClient, + to_client: DomainClient, + blob_storage_objects: list[SyftObject], +) -> SyftSuccess | SyftError: + for obj in blob_storage_objects: + migration_result = migrate_blob_storage_object(from_client, to_client, obj) + if isinstance(migration_result, SyftError): + return migration_result + return SyftSuccess(message="Blob storage migration successful.") + + +def migrate( + from_client: DomainClient, to_client: DomainClient +) -> SyftSuccess | SyftError: + migration_data = from_client.get_migration_data() + if isinstance(migration_data, SyftError): + return migration_data + + # Blob storage is migrated via client + blob_storage_objects = migration_data.blob_storage_objects + blob_migration_result = migrate_blob_storage( + from_client, to_client, blob_storage_objects + ) + if isinstance(blob_migration_result, SyftError): + return blob_migration_result + + # Rest of the migration data is migrated via service + return to_client.api.services.migration.apply_migration_data(migration_data) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 424505b7a28..d46dd1b0c64 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -808,7 +808,8 @@ def post_init(self) -> None: if "usercodeservice" in self.service_path_map: user_code_service = self.get_service(UserCodeService) - user_code_service.load_user_code(context=context) + # TODO this does not work with un-migrated UserCode + # user_code_service.load_user_code(context=context) def reload_user_code() -> None: user_code_service.load_user_code(context=context) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 6ef5623158c..2865e037017 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -430,7 +430,14 @@ "StoreMetadata": { "1": { "version": 1, - "hash": "bb9edb077f0214c5867d5349aa99eb584d133bd5f2cc5c824986c9174c0dbbc9", + "hash": "4a0522eaf28414dd53adcb7d5edb81b4a5b5bbe2e805cb78aa91329c3d6c32a8", + "action": "add" + } + }, + "MigrationData": { + "1": { + "version": 1, + "hash": "492eae1c054298dc52a0bf8d5493249919d58c0201ef74c552746f89406d18cf", "action": "add" } } diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 501777889f3..982e7d53d33 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -12,6 +12,7 @@ from ...serde.serializable import serializable from 
...store.document_store import DocumentStore from ...store.document_store import StorePartition +from ...types.blob_storage import BlobStorageEntry from ...types.syft_object import SyftObject from ..action.action_object import Action from ..action.action_object import ActionObject @@ -24,6 +25,7 @@ from ..service import AbstractService from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL +from .object_migration_state import MigrationData from .object_migration_state import StoreMetadata from .object_migration_state import SyftMigrationStateStash from .object_migration_state import SyftObjectMigrationState @@ -251,6 +253,7 @@ def _update_store_metadata_for_klass( def _update_store_metadata( self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] ) -> Result[str, str]: + print("Updating store metadata") for metadata in store_metadata.values(): result = self._update_store_metadata_for_klass(context, metadata) if result.is_err(): @@ -424,6 +427,30 @@ def _update_migrated_objects( # return result return Ok(value="success") + def _migrate_objects( + self, + context: AuthedServiceContext, + migration_objects: dict[type[SyftObject], list[SyftObject]], + ) -> Result[list[SyftObject], str]: + migrated_objects = [] + for klass, objects in migration_objects.items(): + canonical_name = klass.__canonical_name__ + # Migrate data for objects in document store + print(f"Migrating data for: {canonical_name} table.") + for object in objects: + try: + migrated_value = object.migrate_to(klass.__version__, context) + migrated_objects.append(migrated_value) + except Exception: + # stdlib + import traceback + + print(traceback.format_exc()) + return Err( + f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" + ) + return Ok(migrated_objects) + @service_method( path="migration.migrate_data", name="migrate_data", @@ -458,24 +485,10 @@ def migrate_data( return migration_objects_result migration_objects = migration_objects_result.ok() - migrated_objects = [] - - for klass, objects in migration_objects.items(): - canonical_name = klass.__canonical_name__ - # Migrate data for objects in document store - print(f"Migrating data for: {canonical_name} table.") - for object in objects: - try: - migrated_value = object.migrate_to(klass.__version__, context) - migrated_objects.append(migrated_value) - except Exception: - # stdlib - import traceback - - print(traceback.format_exc()) - return Err( - f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" - ) + migrated_objects_result = self._migrate_objects(context, migration_objects) + if migrated_objects_result.is_err(): + return SyftError(message=migrated_objects_result.err()) + migrated_objects = migrated_objects_result.ok() objects_update_update_result = self._update_migrated_objects( context, migrated_objects @@ -485,27 +498,15 @@ def migrate_data( # now action objects migration_actionobjects_result = self._get_migration_actionobjects(context) + if migration_actionobjects_result.is_err(): return SyftError(message=migration_actionobjects_result.err()) migration_actionobjects = migration_actionobjects_result.ok() - migrated_actionobjects = [] - for klass, action_objects in migration_actionobjects.items(): - # these are Actions, ActionObjects, and possibly others - for object in action_objects: - try: - migrated_actionobject = object.migrate_to( - klass.__version__, context - ) - migrated_actionobjects.append(migrated_actionobject) - except Exception: - # stdlib - import 
traceback - - print(traceback.format_exc()) - return Err( - f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" - ) + migrated_actionobjects = self._migrate_objects(context, migration_actionobjects) + if migrated_actionobjects.is_err(): + return SyftError(message=migrated_actionobjects.err()) + migrated_actionobjects = migrated_actionobjects.ok() actionobjects_update_update_result = self._update_migrated_actionobjects( context, migrated_actionobjects @@ -579,3 +580,78 @@ def _update_migrated_actionobjects( if res.is_err(): return res return Ok("success") + + @service_method( + path="migration.get_migration_data", + name="get_migration_data", + roles=ADMIN_ROLE_LEVEL, + ) + def get_migration_data( + self, context: AuthedServiceContext + ) -> MigrationData | SyftError: + store_objects_result = self._get_migration_objects(context, get_all=True) + if store_objects_result.is_err(): + return SyftError(message=store_objects_result.err()) + store_objects = store_objects_result.ok() + + action_objects_result = self._get_migration_actionobjects(context, get_all=True) + if action_objects_result.is_err(): + return SyftError(message=action_objects_result.err()) + action_objects = action_objects_result.ok() + + blob_storage_objects = store_objects.pop(BlobStorageEntry, []) + + store_metadata_result = self._get_all_store_metadata(context) + if store_metadata_result.is_err(): + return SyftError(message=store_metadata_result.err()) + store_metadata = store_metadata_result.ok() + + return MigrationData( + store_objects=store_objects, + metadata=store_metadata, + action_objects=action_objects, + blob_storage_objects=blob_storage_objects, + ) + + @service_method( + path="migration.apply_migration_data", + name="apply_migration_data", + roles=ADMIN_ROLE_LEVEL, + ) + def apply_migration_data( + self, + context: AuthedServiceContext, + migration_data: MigrationData, + ) -> SyftSuccess | SyftError: + # NOTE blob storage is migrated via client, it needs access to both source and destination blob storages. 
+ + # migrate + apply store objects + migrated_objects_result = self._migrate_objects( + context, migration_data.store_objects + ) + if migrated_objects_result.is_err(): + return SyftError(message=migrated_objects_result.err()) + migrated_objects = migrated_objects_result.ok() + store_objects_result = self._create_migrated_objects(context, migrated_objects) + if store_objects_result.is_err(): + return SyftError(message=store_objects_result.err()) + + # migrate+apply action objects + migrated_actionobjects = self._migrate_objects( + context, migration_data.action_objects + ) + if migrated_actionobjects.is_err(): + return SyftError(message=migrated_actionobjects.err()) + migrated_actionobjects = migrated_actionobjects.ok() + action_objects_result = self._update_migrated_actionobjects( + context, migrated_actionobjects + ) + if action_objects_result.is_err(): + return SyftError(message=action_objects_result.err()) + + # apply metadata + metadata_result = self._update_store_metadata(context, migration_data.metadata) + if metadata_result.is_err(): + return SyftError(message=metadata_result.err()) + + return SyftSuccess(message="Migration completed successfully") diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 1f213fa2b5d..e5f37e233ac 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -12,6 +12,7 @@ from ...store.document_store import PartitionSettings from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry from ...types.uid import UID @@ -85,10 +86,21 @@ def get_by_name( @serializable() -class StoreMetadata(SyftObject): +class StoreMetadata(SyftBaseObject): __canonical_name__ = "StoreMetadata" __version__ = SYFT_OBJECT_VERSION_1 object_type: type permissions: dict[UID, set[str]] storage_permissions: dict[UID, set[UID]] + + +@serializable() +class MigrationData(SyftBaseObject): + __canonical_name__ = "MigrationData" + __version__ = SYFT_OBJECT_VERSION_1 + + store_objects: dict[type[SyftObject], list[SyftObject]] + metadata: dict[type[SyftObject], StoreMetadata] + action_objects: dict[type[SyftObject], list[SyftObject]] + blob_storage_objects: list[SyftObject] From c4dc5b52bf74ee8df2d581bc2b01dcf1d9e3ed67 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Tue, 2 Jul 2024 18:14:17 -0300 Subject: [PATCH 216/309] Check for unexpected constructor parameters in CreateDataset and CreateAsset --- packages/syft/src/syft/service/dataset/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index b41b4756590..efca396d04a 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -336,7 +336,7 @@ class CreateAsset(SyftObject): uploader: Contributor | None = None __repr_attrs__ = ["name"] - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict(validate_assignment=True, extra="forbid") def __init__(self, description: str | None = None, **data: Any) -> None: if isinstance(description, str): @@ -615,7 +615,7 @@ class CreateDataset(Dataset): created_at: DateTime | None = None # type: 
ignore[assignment] uploader: Contributor | None = None # type: ignore[assignment] - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict(validate_assignment=True, extra="forbid") def _check_asset_must_contain_mock(self) -> None: _check_asset_must_contain_mock(self.asset_list) From d3a6b6a9498e7f9425b3932131c559098a9121a4 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 3 Jul 2024 08:40:38 +0700 Subject: [PATCH 217/309] [syft/action_obj] saving action data cache immediately after trying to save to blob store - moving / remove some checks Co-authored-by: Shubham Gupta --- .../src/syft/service/action/action_object.py | 29 ++++++++++--------- .../src/syft/service/blob_storage/util.py | 5 ---- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 51d208126b9..4ef06a9bffc 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -798,22 +798,23 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: from ...types.blob_storage import BlobFile from ...types.blob_storage import CreateBlobStorageEntry - api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) - if api is None: - raise ValueError( - f"api is None. You must login to {self.syft_node_location}" - ) - if not can_upload_to_blob_storage(data, api.metadata): - return SyftWarning( - message=f"The action object {self.id} was not saved to " - f"the blob store but to memory cache since it is small." - ) - if not isinstance(data, ActionDataEmpty): + api = APIRegistry.api_for( + self.syft_node_location, self.syft_client_verify_key + ) if isinstance(data, BlobFile): if not data.uploaded: data._upload_to_blobstorage_from_api(api) else: + if api is None: + raise ValueError( + f"api is None. You must login to {self.syft_node_location}" + ) + if not can_upload_to_blob_storage(data, api.metadata): + return SyftWarning( + message=f"The action object {self.id} was not saved to " + f"the blob store but to memory cache since it is small." + ) serialized = serialize(data, to_bytes=True) size = sys.getsizeof(serialized) storage_entry = CreateBlobStorageEntry.from_obj(data, file_size=size) @@ -830,13 +831,13 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: ) if allocate_method is not None: blob_deposit_object = allocate_method(storage_entry) - if isinstance(blob_deposit_object, SyftError): return blob_deposit_object result = blob_deposit_object.write(BytesIO(serialized)) if isinstance(result, SyftError): return result + self.syft_blob_storage_entry_id = ( blob_deposit_object.blob_storage_entry_id ) @@ -846,6 +847,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: self.syft_action_data_type = type(data) self._set_reprs(data) self.syft_has_bool_attr = hasattr(data, "__bool__") + self.syft_action_data_cache = data else: logger.debug( "skipping writing action object to store, passed data was empty." 
@@ -859,10 +861,12 @@ def _save_to_blob_storage( data = self.syft_action_data if isinstance(data, SyftError): return data + if isinstance(data, ActionDataEmpty): return SyftError( message=f"cannot store empty object {self.id} to the blob storage" ) + try: result = self._save_to_blob_storage_(data) if isinstance(result, SyftError | SyftWarning): @@ -877,7 +881,6 @@ def _save_to_blob_storage( f"Failed to save action object {self.id} to the blob store. Error: {e}" ) - self.syft_action_data_cache = data return SyftWarning( message=f"The action object {self.id} was not saved to " f"the blob store but to memory cache since it is small." diff --git a/packages/syft/src/syft/service/blob_storage/util.py b/packages/syft/src/syft/service/blob_storage/util.py index 68f3250035c..e82d1de9cff 100644 --- a/packages/syft/src/syft/service/blob_storage/util.py +++ b/packages/syft/src/syft/service/blob_storage/util.py @@ -8,11 +8,6 @@ def min_size_for_blob_storage_upload(metadata: NodeMetadata | NodeMetadataJSON) -> int: - if not isinstance(metadata, (NodeMetadata | NodeMetadataJSON)): - raise ValueError( - f"argument `metadata` is type {type(metadata)}, " - f"but it should be of type NodeMetadata or NodeMetadataJSON" - ) return metadata.min_size_blob_storage_mb From bee4737fea704290468f5a7f6d46e1adc4499178 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 3 Jul 2024 13:44:07 +0800 Subject: [PATCH 218/309] Get all user code entries without using the deleted admin in test --- packages/syft/tests/syft/users/user_code_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index ff0c2da101a..6dcc6eb0204 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -75,8 +75,11 @@ def test_new_admin_can_list_user_code( res = root_client.api.services.user.delete(root_client.me.id) assert not isinstance(res, SyftError) - assert len(root_client.code.get_all()) == len(admin.code.get_all()) - assert {c.id for c in root_client.code} == {c.id for c in admin.code} + user_code_stash = worker.get_service("usercodeservice").stash + user_code = user_code_stash.get_all(user_code_stash.store.root_verify_key).ok() + + assert len(user_code) == len(admin.code.get_all()) + assert {c.id for c in user_code} == {c.id for c in admin.code} def test_user_code(worker) -> None: From 7f8517e01b6883d569622ac82aab7618cc67b2b4 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 3 Jul 2024 12:30:24 +0530 Subject: [PATCH 219/309] raise exception if exception raised in action object --- .../syft/src/syft/service/action/action_object.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 4ef06a9bffc..ea07a9d4dea 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -847,12 +847,13 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: self.syft_action_data_type = type(data) self._set_reprs(data) self.syft_has_bool_attr = hasattr(data, "__bool__") - self.syft_action_data_cache = data else: logger.debug( "skipping writing action object to store, passed data was empty." 
) + self.syft_action_data_cache = data + return None def _save_to_blob_storage( @@ -877,14 +878,7 @@ def _save_to_blob_storage( message=f"Saved action object {self.id} to the blob store" ) except Exception as e: - print( - f"Failed to save action object {self.id} to the blob store. Error: {e}" - ) - - return SyftWarning( - message=f"The action object {self.id} was not saved to " - f"the blob store but to memory cache since it is small." - ) + raise e def _clear_cache(self) -> None: self.syft_action_data_cache = self.as_empty_data() From fa7eb2a936f7f1dc8cf328b4ca720c00840522f3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 3 Jul 2024 15:11:16 +0800 Subject: [PATCH 220/309] Revert pytest mark removal --- packages/syft/tests/syft/stores/action_store_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 613cefa16d7..04d693a618a 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -1,4 +1,5 @@ # stdlib +import sys from typing import Any # third party @@ -53,6 +54,8 @@ def test_action_store_sanity(store: Any): ], ) @pytest.mark.parametrize("permission", permissions) +@pytest.mark.flaky(reruns=3, reruns_delay=3) +@pytest.mark.skipif(sys.platform == "darwin", reason="skip on mac") def test_action_store_test_permissions(store: Any, permission: Any): client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT) root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) From db834ce8981557a925a02308c4be01b66639d3e2 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 3 Jul 2024 11:58:54 +0300 Subject: [PATCH 221/309] display usercode metadata and code str differently --- .../syft/src/syft/service/code/user_code.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 5c1dd137cd0..837da080e49 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -71,6 +71,7 @@ from ...util.decorators import deprecated from ...util.markdown import CodeMarkdown from ...util.markdown import as_markdown_code +from ...util.notebook_ui.styles import FONT_CSS from ...util.util import prompt_warning_message from ..action.action_endpoint import CustomEndpointActionObject from ..action.action_object import Action @@ -915,6 +916,42 @@ def _inner_repr(self, level: int = 0) -> str: def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return as_markdown_code(self._inner_repr()) + def _repr_html_(self) -> str: + shared_with_line = "" + if len(self.output_readers) > 0 and self.output_reader_names is not None: + owners_string = " and ".join([f"*{x}*" for x in self.output_reader_names]) + shared_with_line += ( + f"

Custom Policy: " + f"outputs are *shared* with the owners of {owners_string} once computed

" + ) + + repr_str = f""" + +
+

UserCode

+

id: UID = {self.id}

+

service_func_name: str = {self.service_func_name}

+

shareholders: list = {self.input_owners}

+

status: list = {self.code_status}

+ {shared_with_line} +

code:

+

+ """ + return repr_str + + def _ipython_display_(self) -> None: + # third party + from IPython.display import HTML + from IPython.display import Markdown + + # display_html() + display(HTML(self._repr_html_()), Markdown(as_markdown_code(self.raw_code))) + @property def show_code(self) -> CodeMarkdown: return CodeMarkdown(self.raw_code) From 5718f0083b6a402e1d9bb4795e4c45434483a8ea Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 3 Jul 2024 14:37:39 +0530 Subject: [PATCH 222/309] get metadata from api_or_context method in action object --- .../src/syft/service/action/action_object.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index ea07a9d4dea..92f734c7a60 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -799,18 +799,19 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: from ...types.blob_storage import CreateBlobStorageEntry if not isinstance(data, ActionDataEmpty): - api = APIRegistry.api_for( - self.syft_node_location, self.syft_client_verify_key - ) if isinstance(data, BlobFile): if not data.uploaded: + api = APIRegistry.api_for( + self.syft_node_location, self.syft_client_verify_key + ) data._upload_to_blobstorage_from_api(api) else: - if api is None: - raise ValueError( - f"api is None. You must login to {self.syft_node_location}" - ) - if not can_upload_to_blob_storage(data, api.metadata): + get_metadata = from_api_or_context( + func_or_path="metadata.get_metadata", + syft_node_location=self.syft_node_location, + syft_client_verify_key=self.syft_client_verify_key, + ) + if not can_upload_to_blob_storage(data, get_metadata()): return SyftWarning( message=f"The action object {self.id} was not saved to " f"the blob store but to memory cache since it is small." From 49ceece4c40ae20aec0b80ca88bed8ad94480188 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 3 Jul 2024 14:43:22 +0530 Subject: [PATCH 223/309] fix linting --- packages/syft/src/syft/service/action/action_object.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 92f734c7a60..b9ffd16ebf6 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -811,7 +811,9 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: syft_node_location=self.syft_node_location, syft_client_verify_key=self.syft_client_verify_key, ) - if not can_upload_to_blob_storage(data, get_metadata()): + if get_metadata is not None and not can_upload_to_blob_storage( + data, get_metadata() + ): return SyftWarning( message=f"The action object {self.id} was not saved to " f"the blob store but to memory cache since it is small." 
From 73625c877e9def0c1793d00248901748bb2e25d2 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 3 Jul 2024 16:21:20 +0700 Subject: [PATCH 224/309] [syft/user_code] refactoring utility functions for parsing user code Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/code/user_code.py | 32 +++---------------- packages/syft/src/syft/service/code/utils.py | 28 ++++++++++++++++ 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 01114b49c0a..88b3d5109bd 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -105,9 +105,10 @@ from ..service import ServiceConfigRegistry from ..user.user import UserView from ..user.user_roles import ServiceRole -from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor from .unparse import unparse +from .utils import check_for_global_vars +from .utils import parse_code from .utils import submit_subjobs_code if TYPE_CHECKING: @@ -1327,31 +1328,6 @@ def generate_unique_func_name(context: TransformContext) -> TransformContext: return context -def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: - """ - Check that the code does not contain any global variables - """ - v = GlobalsVisitor() - try: - v.visit(code_tree) - except Exception: - raise SyftException( - "Your code contains (a) global variable(s), which is not allowed" - ) - return v - - -def _parse_code(raw_code: str) -> ast.Module | SyftWarning: - """ - Parse the code into an AST tree and return a warning if there are syntax errors - """ - try: - tree = ast.parse(raw_code) - except SyntaxError as e: - raise SyftException(f"Your code contains syntax error: {e}") - return tree - - def parse_user_code( raw_code: str, func_name: str, @@ -1360,8 +1336,8 @@ def parse_user_code( ) -> str: # parse the code, check for syntax errors and if there are global variables try: - tree: ast.Module = _parse_code(raw_code=raw_code) - _check_global(code_tree=tree) + tree: ast.Module = parse_code(raw_code=raw_code) + check_for_global_vars(code_tree=tree) except SyftException as e: raise SyftException(f"{e}") diff --git a/packages/syft/src/syft/service/code/utils.py b/packages/syft/src/syft/service/code/utils.py index fccc5314c43..a3d59cbe161 100644 --- a/packages/syft/src/syft/service/code/utils.py +++ b/packages/syft/src/syft/service/code/utils.py @@ -6,6 +6,9 @@ from IPython import get_ipython # relative +from ..response import SyftException +from ..response import SyftWarning +from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor @@ -36,3 +39,28 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore # fetch if specs["type_name"] == "SubmitUserCode": ep_client.code.submit(ipython.ev(call)) + + +def check_for_global_vars(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: + """ + Check that the code does not contain any global variables + """ + v = GlobalsVisitor() + try: + v.visit(code_tree) + except Exception: + raise SyftException( + "Your code contains (a) global variable(s), which is not allowed" + ) + return v + + +def parse_code(raw_code: str) -> ast.Module | SyftWarning: + """ + Parse the code into an AST tree and return a warning if there are syntax errors + """ + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + raise SyftException(f"Your code contains syntax error: {e}") + return tree From 
c43c4d9afe4a26c8c4226b7bada7f99ae83b13af Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 3 Jul 2024 13:49:56 +0300 Subject: [PATCH 225/309] fix asset replace --- packages/syft/src/syft/service/dataset/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 9e5bc47eb19..5924106a484 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -679,7 +679,7 @@ def add_asset( else: self.asset_list[i] = asset return SyftSuccess( - f"Asset {asset.name} has been successfully replaced." + message=f"Asset {asset.name} has been successfully replaced." ) self.asset_list.append(asset) From 57175333d629c241689c7b838a01deb6c1f41fd2 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Wed, 3 Jul 2024 17:12:52 +0530 Subject: [PATCH 226/309] Add Hot Reload Functionality to Jupyter Notebooks. Co-authored-by: Tauquir <30658453+itstauq@users.noreply.github.com> --- packages/syft/src/syft/node/server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index cf0bf4370d5..6ceaefa977c 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -22,6 +22,7 @@ # relative from ..abstract_node import NodeSideType from ..client.client import API_PATH +from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT from ..util.util import os_name from .domain import Domain @@ -148,6 +149,10 @@ def serve_node( association_request_auto_approval: bool = False, background_tasks: bool = False, ) -> tuple[Callable, Callable]: + # Enable IPython autoreload if dev_mode is enabled. + if dev_mode: + enable_autoreload() + server_process = multiprocessing.Process( target=run_uvicorn, kwargs={ From 4ee5b543f4e9abe891845dd0424563e98e925e66 Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 3 Jul 2024 15:00:03 +0300 Subject: [PATCH 227/309] repr for nested requests --- .../syft/src/syft/service/code/user_code.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 4d42dfb78d3..e7c69a2a2df 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -918,15 +918,15 @@ def _inner_repr(self, level: int = 0) -> str: def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return as_markdown_code(self._inner_repr()) - def _repr_html_(self) -> str: + def _repr_html_(self, level: int = 0) -> str: + tabs = " " * level shared_with_line = "" if len(self.output_readers) > 0 and self.output_reader_names is not None: owners_string = " and ".join([f"*{x}*" for x in self.output_reader_names]) shared_with_line += ( - f"

Custom Policy: " + f"

{tabs}Custom Policy: " f"outputs are *shared* with the owners of {owners_string} once computed

" ) - repr_str = f"""
-

UserCode

-

id: UID = {self.id}

-

service_func_name: str = {self.service_func_name}

-

shareholders: list = {self.input_owners}

-

status: list = {self.code_status}

+

{tabs}UserCode

+

{tabs}id: UID = {self.id}

+

{tabs}service_func_name: str = {self.service_func_name}

+

{tabs}shareholders: list = {self.input_owners}

+

{tabs}status: list = {self.code_status}

{shared_with_line} -

code:

+

{tabs}code:

+

+ """ + if self.nested_codes != {}: + repr_str += f""" +

{tabs}Nested Requests:

+ """ + repr_str += """ """ return repr_str - def _ipython_display_(self) -> None: + def _ipython_display_(self, level: int = 0) -> None: # third party from IPython.display import HTML from IPython.display import Markdown - # display_html() - display(HTML(self._repr_html_()), Markdown(as_markdown_code(self.raw_code))) + md = "\n".join( + [f"{' '*level}{substring}" for substring in self.raw_code.split("\n")[:-1]] + ) + display(HTML(self._repr_html_(level=level)), Markdown(as_markdown_code(md))) + if self.nested_codes is not None: + for obj, _ in self.nested_codes.values(): + code = obj.resolve + code._ipython_display_(level=level + 1) @property def show_code(self) -> CodeMarkdown: From b9d92ff9c74ed785efff40f9d17ff84f83a850ad Mon Sep 17 00:00:00 2001 From: teo Date: Wed, 3 Jul 2024 15:00:26 +0300 Subject: [PATCH 228/309] fix serialization of empty input policy --- packages/syft/src/syft/service/policy/policy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index e4800f04d6e..b2528241bb2 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -835,6 +835,7 @@ class UserInputPolicy(InputPolicy): pass +@serializable() class EmpyInputPolicy(InputPolicy): __canonical_name__ = "EmptyInputPolicy" pass From 3562d0d0aae776b727ed32134d5b2cb95f6d9fb2 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 3 Jul 2024 14:17:31 +0200 Subject: [PATCH 229/309] deployment for migration --- packages/grid/backend/grid/start.sh | 4 +- packages/grid/devspace.yaml | 7 + packages/grid/helm/examples/dev/migrated.yaml | 6 + packages/syft/src/syft/__init__.py | 1 - .../syft/src/syft/client/domain_client.py | 38 +++++- packages/syft/src/syft/client/migrations.py | 69 ---------- .../src/syft/protocol/protocol_version.json | 2 +- .../service/migration/migration_service.py | 14 +- .../migration/object_migration_state.py | 129 +++++++++++++++++- packages/syft/src/syft/util/util.py | 2 +- 10 files changed, 190 insertions(+), 82 deletions(-) create mode 100644 packages/grid/helm/examples/dev/migrated.yaml delete mode 100644 packages/syft/src/syft/client/migrations.py diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 4b3d5de4cf2..52b48537812 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -26,8 +26,8 @@ then fi export CREDENTIALS_PATH=${CREDENTIALS_PATH:-$HOME/data/creds/credentials.json} -export NODE_PRIVATE_KEY=$(python $APPDIR/grid/bootstrap.py --private_key) -export NODE_UID=$(python $APPDIR/grid/bootstrap.py --uid) +export NODE_PRIVATE_KEY=$(python $APPDIR/grid/bootstrap.py --private_key --debug) +export NODE_UID=$(python $APPDIR/grid/bootstrap.py --uid --debug) export NODE_TYPE=$NODE_TYPE echo "NODE_UID=$NODE_UID" diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 8bbf3487daf..9547f70e1b0 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -125,6 +125,13 @@ profiles: value: side: low + - name: migrated-domain + description: "Deploy a migrated domain" + patches: + - op: add + path: deployments.syft.helm.valuesFiles + value: ./helm/examples/dev/migrated.yaml + - name: domain-tunnel description: "Deploy a domain with tunneling enabled" patches: diff --git a/packages/grid/helm/examples/dev/migrated.yaml b/packages/grid/helm/examples/dev/migrated.yaml new file mode 100644 index 00000000000..133ffc83e94 --- /dev/null +++ 
b/packages/grid/helm/examples/dev/migrated.yaml @@ -0,0 +1,6 @@ +node: + env: + - name: NODE_UID + value: "21519e1e3e664b38a635dc951c293158" + - name: NODE_PRIVATE_KEY + value: "664e78682ff58cc07ae67778678bbc64d1017077f89218acfa28bd55fec2d413" diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 086e759eb30..23ac10fd52f 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -16,7 +16,6 @@ from .client.client import register from .client.domain_client import DomainClient from .client.gateway_client import GatewayClient -from .client.migrations import migrate from .client.registry import DomainRegistry from .client.registry import EnclaveRegistry from .client.registry import NetworkRegistry diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 6b1d96769a1..47e2b84d8b1 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -400,8 +400,42 @@ def output(self) -> APIModule | None: def migration(self) -> APIModule | None: return self._get_service_by_name_if_exists("migration") - def get_migration_data(self) -> MigrationData | SyftError: - return self.api.services.migration.get_migration_data() + def get_migration_data( + self, include_blobs: bool = True + ) -> MigrationData | SyftError: + res = self.api.services.migration.get_migration_data() + if isinstance(res, SyftError): + return res + + if include_blobs: + res.download_blobs() + + return res + + def load_migration_data(self, path: str | Path) -> SyftSuccess | SyftError: + migration_data = MigrationData.from_file(path) + if isinstance(migration_data, SyftError): + return migration_data + migration_data._set_obj_location_(self.id, self.verify_key) + + # if self.id != migration_data.node_uid: + # return SyftError( + # message=f"Migration data is not for this node. 
Expected {self.id}, got {migration_data.node_uid}" + # ) + + # if migration_data.root_verify_key != self.verify_key: + # return SyftError( + # message="Root verify key in migration data does not match this client's verify key" + # ) + + res = migration_data.migrate_and_upload_blobs() + if isinstance(res, SyftError): + return res + + migration_data_without_blobs = migration_data.copy_without_blobs() + return self.api.services.migration.apply_migration_data( + migration_data_without_blobs + ) def get_project( self, diff --git a/packages/syft/src/syft/client/migrations.py b/packages/syft/src/syft/client/migrations.py deleted file mode 100644 index 2deaeda2508..00000000000 --- a/packages/syft/src/syft/client/migrations.py +++ /dev/null @@ -1,69 +0,0 @@ -# stdlib -from io import BytesIO -import sys - -# relative -from ..serde.serialize import _serialize -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..types.blob_storage import BlobStorageEntry -from ..types.blob_storage import CreateBlobStorageEntry -from ..types.syft_object import Context -from ..types.syft_object import SyftObject -from .domain_client import DomainClient - - -def migrate_blob_storage_object( - from_client: DomainClient, - to_client: DomainClient, - obj: SyftObject, -) -> SyftSuccess | SyftError: - migrated_obj = obj.migrate_to(BlobStorageEntry.__version__, Context()) - uploaded_by = migrated_obj.uploaded_by - blob_retrieval = from_client.services.blob_storage.read(obj.id) - if isinstance(blob_retrieval, SyftError): - return blob_retrieval - - data = blob_retrieval.read() - serialized = _serialize(data, to_bytes=True) - size = sys.getsizeof(serialized) - blob_create = CreateBlobStorageEntry.from_blob_storage_entry(obj) - blob_create.file_size = size - - blob_deposit_object = to_client.services.blob_storage.allocate_for_user( - blob_create, uploaded_by - ) - if isinstance(blob_deposit_object, SyftError): - return blob_deposit_object - return blob_deposit_object.write(BytesIO(serialized)) - - -def migrate_blob_storage( - from_client: DomainClient, - to_client: DomainClient, - blob_storage_objects: list[SyftObject], -) -> SyftSuccess | SyftError: - for obj in blob_storage_objects: - migration_result = migrate_blob_storage_object(from_client, to_client, obj) - if isinstance(migration_result, SyftError): - return migration_result - return SyftSuccess(message="Blob storage migration successful.") - - -def migrate( - from_client: DomainClient, to_client: DomainClient -) -> SyftSuccess | SyftError: - migration_data = from_client.get_migration_data() - if isinstance(migration_data, SyftError): - return migration_data - - # Blob storage is migrated via client - blob_storage_objects = migration_data.blob_storage_objects - blob_migration_result = migrate_blob_storage( - from_client, to_client, blob_storage_objects - ) - if isinstance(blob_migration_result, SyftError): - return blob_migration_result - - # Rest of the migration data is migrated via service - return to_client.api.services.migration.apply_migration_data(migration_data) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 2865e037017..c4541e43bd3 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -437,7 +437,7 @@ "MigrationData": { "1": { "version": 1, - "hash": "492eae1c054298dc52a0bf8d5493249919d58c0201ef74c552746f89406d18cf", + "hash": 
"636a430748dfd2bc3860772795e7071db39d7306904c2d82531f884fcd79f3e9", "action": "add" } } diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 982e7d53d33..e10a81d052c 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -7,6 +7,7 @@ from result import Err from result import Ok from result import Result +from syft.service.user.user_service import UserService # relative from ...serde.serializable import serializable @@ -605,8 +606,13 @@ def get_migration_data( if store_metadata_result.is_err(): return SyftError(message=store_metadata_result.err()) store_metadata = store_metadata_result.ok() + root_verify_key = context.node.get_service_method( + UserService.admin_verify_key + )() return MigrationData( + node_uid=context.node.id, + signing_key=context.node.signing_key, store_objects=store_objects, metadata=store_metadata, action_objects=action_objects, @@ -623,7 +629,13 @@ def apply_migration_data( context: AuthedServiceContext, migration_data: MigrationData, ) -> SyftSuccess | SyftError: - # NOTE blob storage is migrated via client, it needs access to both source and destination blob storages. + # NOTE blob storage is migrated via client, + # it needs access to both source and destination blob storages. + if len(migration_data.blobs): + return SyftError( + message="Blob storage migration is not supported by this endpoint, " + "please use 'client.load_migration_data' instead." + ) # migrate + apply store objects migrated_objects_result = self._migrate_objects( diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index e5f37e233ac..37aee4fc542 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -1,7 +1,14 @@ -# stdlib - -# third party +from pathlib import Path +from io import BytesIO +import sys +from typing import Any +from typing_extensions import Self from result import Result +from syft.serde.deserialize import _deserialize +from syft.serde.serialize import _serialize +from syft.service.response import SyftError, SyftSuccess +from syft.types.blob_storage import BlobStorageEntry, CreateBlobStorageEntry +from syft.util.util import prompt_warning_message # relative from ...node.credentials import SyftVerifyKey @@ -10,7 +17,7 @@ from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_1, Context from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject @@ -96,11 +103,123 @@ class StoreMetadata(SyftBaseObject): @serializable() -class MigrationData(SyftBaseObject): +class MigrationData(SyftObject): __canonical_name__ = "MigrationData" __version__ = SYFT_OBJECT_VERSION_1 + node_uid: UID + root_verify_key: SyftVerifyKey store_objects: dict[type[SyftObject], list[SyftObject]] metadata: dict[type[SyftObject], StoreMetadata] action_objects: dict[type[SyftObject], list[SyftObject]] blob_storage_objects: list[SyftObject] + blobs: dict[UID, Any] = {} + + __repr_attrs__ = [ + "node_uid", + "root_verify_key", + "num_objects", + 
"num_action_objects", + "includes_blobs", + ] + + @property + def num_objects(self) -> int: + return sum(len(objs) for objs in self.store_objects.values()) + + @property + def num_action_objects(self) -> int: + return sum(len(objs) for objs in self.action_objects.values()) + + @property + def includes_blobs(self) -> bool: + blob_ids = [obj.id for obj in self.blob_storage_objects] + return set(self.blobs.keys()) == set(blob_ids) + + @classmethod + def from_file(self, path: str | Path) -> Self | SyftError: + path = Path(path) + if not path.exists(): + return SyftError(f"File {str(path)} does not exist.") + + with open(path, "rb") as f: + res: MigrationData = _deserialize(f.read(), from_bytes=True) + + return res + + def save(self, path: str | Path) -> SyftSuccess | SyftError: + if not self.includes_blobs: + proceed = prompt_warning_message( + "You are saving migration data without blob storage data. " + "This means that any existing blobs will be missing when you load this data." + "\nTo include blobs, call `download_blobs()` before saving.", + confirm=True, + ) + if not proceed: + return SyftError(message="Migration data not saved.") + + path = Path(path) + with open(path, "wb") as f: + f.write(_serialize(self, to_bytes=True)) + + return SyftSuccess(message=f"Migration data saved to {str(path)}.") + + def download_blobs(self) -> None | SyftError: + for obj in self.blob_storage_objects: + blob = self.download_blob(obj.id) + if isinstance(blob, SyftError): + return blob + self.blobs[obj.id] = blob + + def download_blob(self, obj_id: str) -> Any | SyftError: + api = self._get_api() + if isinstance(api, SyftError): + return api + + blob_retrieval = api.services.blob_storage.read(obj_id) + if isinstance(blob_retrieval, SyftError): + return blob_retrieval + return blob_retrieval.read() + + def migrate_and_upload_blobs(self) -> SyftSuccess | SyftError: + for obj in self.blob_storage_objects: + upload_result = self.migrate_and_upload_blob(obj) + if isinstance(upload_result, SyftError): + return upload_result + return SyftSuccess(message="All blobs uploaded successfully.") + + def migrate_and_upload_blob(self, obj: BlobStorageEntry) -> SyftSuccess | SyftError: + api = self._get_api() + if isinstance(api, SyftError): + return api + + if obj.id not in self.blobs: + return SyftError(f"Blob {obj.id} not found in migration data.") + data = self.blobs[obj.id] + + migrated_obj = obj.migrate_to(BlobStorageEntry.__version__, Context()) + serialized = _serialize(data, to_bytes=True) + size = sys.getsizeof(serialized) + blob_create = CreateBlobStorageEntry.from_blob_storage_entry(migrated_obj) + blob_create.file_size = size + blob_deposit_object = api.services.blob_storage.allocate_for_user( + blob_create, migrated_obj.uploaded_by + ) + + if isinstance(blob_deposit_object, SyftError): + return blob_deposit_object + return blob_deposit_object.write(BytesIO(serialized)) + + def copy_without_blobs(self) -> "MigrationData": + # Create a shallow copy of the MigrationData instance, removing blob-related data + # This is required for sending the MigrationData to the backend. 
+ copy_data = self.__class__( + node_uid=self.node_uid, + root_verify_key=self.root_verify_key, + store_objects=self.store_objects.copy(), + metadata=self.metadata.copy(), + action_objects=self.action_objects.copy(), + blob_storage_objects=[], + blobs={}, + ) + return copy_data diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index d8098f55e1d..774bfe4db90 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -475,7 +475,7 @@ def prompt_warning_message(message: str, confirm: bool = False) -> bool: if response == "y": return True elif response == "n": - display("Aborted !!") + print("Aborted.") return False else: print("Invalid response. Please enter Y or N.") From f23a766dcf9f71dbd1b96bea268c8ee4fdfdbe56 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 3 Jul 2024 21:57:55 +0800 Subject: [PATCH 230/309] Test worker deleted from stash --- packages/syft/tests/syft/syft_worker_deletion_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py index d23627a9a98..5370e287851 100644 --- a/packages/syft/tests/syft/syft_worker_deletion_test.py +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -83,4 +83,4 @@ def compute_mean(data): else: assert job.status == JobStatus.COMPLETED - # assert len(node.python_node.queue_manager.consumers["api_call"]) == 0 + assert len(client.worker.get_all()) == 0 From 43e1c5977125d297c972b510e832182cb21e8b48 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 3 Jul 2024 22:08:05 +0800 Subject: [PATCH 231/309] Add test for deleting an idle worker --- .../tests/syft/syft_worker_deletion_test.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py index 5370e287851..d3118bbea47 100644 --- a/packages/syft/tests/syft/syft_worker_deletion_test.py +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -41,6 +41,26 @@ def node(node_args: dict[str, Any]) -> Generator[NodeHandle, None, None]: _node.land() +@pytest.mark.parametrize("node_args", [{"n_consumers": 1}]) +@pytest.mark.parametrize("force", [True, False]) +def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: + client = node.login(email="info@openmined.org", password="changethis") + worker = client.worker.get_all()[0] + + res = client.worker.delete(worker.id, force=force) + assert not isinstance(res, SyftError) + + if force: + assert len(client.worker.get_all()) == 0 + + start = time.time() + while True: + if len(client.worker.get_all()) == 0: + break + if time.time() - start > 3: + raise TimeoutError("Worker did not get removed from stash.") + + @pytest.mark.parametrize("node_args", [{"n_consumers": 1}]) @pytest.mark.parametrize("force", [True, False]) def test_delete_worker(node: NodeHandle, force: bool) -> None: From e1d8c910d29356a4328e6932e6350e7a3c0d1e40 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 3 Jul 2024 16:35:42 +0200 Subject: [PATCH 232/309] finalize migration for local setup --- .../version-upgrades/1c-dump-node.ipynb | 160 +++++ .../1c-migrate-to-new-node.ipynb | 659 ------------------ .../syft/src/syft/client/domain_client.py | 19 +- packages/syft/src/syft/node/node.py | 5 +- .../src/syft/protocol/protocol_version.json | 4 +- .../service/migration/migration_service.py | 4 - .../migration/object_migration_state.py | 49 +- 
.../src/syft/service/settings/settings.py | 71 +- 8 files changed, 262 insertions(+), 709 deletions(-) create mode 100644 notebooks/tutorials/version-upgrades/1c-dump-node.ipynb delete mode 100644 notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb diff --git a/notebooks/tutorials/version-upgrades/1c-dump-node.ipynb b/notebooks/tutorials/version-upgrades/1c-dump-node.ipynb new file mode 100644 index 00000000000..0a09e6680ea --- /dev/null +++ b/notebooks/tutorials/version-upgrades/1c-dump-node.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# stdlib\n", + "from pathlib import Path\n", + "\n", + "# syft absolute\n", + "import syft as sy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "migration_data = client.get_migration_data(include_blobs=True)\n", + "migration_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "blob_path = Path(\"./my_migration.blob\")\n", + "yaml_path = Path(\"my_migration.yaml\")\n", + "\n", + "blob_path.unlink(missing_ok=True)\n", + "yaml_path.unlink(missing_ok=True)\n", + "\n", + "migration_data.save(blob_path, yaml_path=yaml_path)\n", + "\n", + "assert blob_path.exists()\n", + "assert yaml_path.exists()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "# Client side migrations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "node.land()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "new_node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " local_db=True,\n", + " n_consumers=2,\n", + " create_producer=True,\n", + " migrate=False,\n", + " reset=True,\n", + ")\n", + "\n", + "new_client = new_node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "new_client.load_migration_data(\"my_migration.blob\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb b/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb deleted file mode 100644 index 798d8630081..00000000000 --- a/notebooks/tutorials/version-upgrades/1c-migrate-to-new-node.ipynb +++ /dev/null @@ -1,659 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "import syft as sy\n", - "from syft.service.log.log import SyftLogV3\n", - "from syft.types.blob_storage import BlobStorageEntry\n", - "from syft.types.blob_storage import CreateBlobStorageEntry\n", - "from syft.types.syft_object import Context\n", - "from syft.types.syft_object import SyftObject" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "print(f\"syft version: {sy.__version__}\")" - ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "TODOS\n", - "- [x] action objects\n", - "- [x] maybe an example of how to migrate one object type in a custom way\n", - "- [x] check SyftObjectRegistry and compare with current implementation\n", - "- [x] run unit tests\n", - "- [ ] finalize notebooks for testing, run in CI\n", - "- [ ] other tasks defined in tickets\n", - "- [ ] also get actionobjects in get_migration_objects\n", - "- [ ] make clientside method to migrate blobstorage `migrate_blobstorage(from_client, to_client, migration_data)`\n", - "- [ ] make domainclient get_migration and apply_migration (to/from file?)\n", - " - merge with save_migration_objects_to_file / migrate_objects_from_file " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "node = sy.orchestra.launch(\n", - " name=\"test_upgradability\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "markdown", - "id": "5", - "metadata": {}, - "source": [ - "# Client side migrations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "temp_node = sy.orchestra.launch(\n", - " name=\"temp_node\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - " reset=True,\n", - ")\n", - "\n", - "temp_client = temp_node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "migration_data = client.services.migration.get_migration_data()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "from syft.client.migrations import migrate_blob_storage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "sy.migrate(from_client=client, to_client=temp_client)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "migrate_blob_storage(\n", - " from_client=client,\n", - " 
to_client=temp_client,\n", - " blob_storage_objects=migration_data.blob_storage_objects,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "migration_data.store_objects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "temp_client.api.services.migration.apply_migration_data(migration_data)" - ] - }, - { - "cell_type": "markdown", - "id": "13", - "metadata": {}, - "source": [ - "## document store objects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": {}, - "outputs": [], - "source": [ - "migration_dict = client.services.migration.get_migration_objects(get_all=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "migration_dict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "def custom_migration_function(context, obj: SyftObject, klass) -> SyftObject:\n", - " # Here, we are just doing the same, but this is where you would write your custom logic\n", - " return obj.migrate_to(klass.__version__, context)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "# this wont work in the cases where the context is actually used,\n", - "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", - "\n", - "\n", - "context = Context()\n", - "migrated_objects = []\n", - "for klass, objects in migration_dict.items():\n", - " for obj in objects:\n", - " if isinstance(obj, BlobStorageEntry):\n", - " continue\n", - " elif isinstance(obj, SyftLogV3):\n", - " migrated_obj = custom_migration_function(context, obj, klass)\n", - " else:\n", - " try:\n", - " migrated_obj = obj.migrate_to(klass.__version__, context)\n", - " except Exception:\n", - " print(obj.__version__, obj.__canonical_name__)\n", - " print(klass.__version__, klass.__canonical_name__)\n", - " raise\n", - "\n", - " migrated_objects.append(migrated_obj)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO what to do with workerpools\n", - "# TODO what to do with admin? 
@yash: can we make new node with existing verifykey?\n", - "# TODO check asset AO is not saved in blobstorage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "res = temp_client.services.migration.create_migrated_objects(migrated_objects)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "res" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(res, sy.SyftSuccess)" - ] - }, - { - "cell_type": "markdown", - "id": "22", - "metadata": {}, - "source": [ - "# Migrate blobstorage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "klass = BlobStorageEntry\n", - "blob_entries = migration_dict[klass]\n", - "obj = blob_entries[0]\n", - "obj" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "# stdlib\n", - "from io import BytesIO\n", - "import sys\n", - "\n", - "\n", - "def migrate_blob_entry_data(\n", - " old_client, new_client, obj, klass\n", - ") -> sy.SyftSuccess | sy.SyftError:\n", - " migrated_obj = obj.migrate_to(klass.__version__, Context())\n", - " uploaded_by = migrated_obj.uploaded_by\n", - " blob_retrieval = old_client.services.blob_storage.read(obj.id)\n", - " if isinstance(blob_retrieval, sy.SyftError):\n", - " return blob_retrieval\n", - "\n", - " data = blob_retrieval.read()\n", - " # TODO do we have to determine new filesize here?\n", - " serialized = sy.serialize(data, to_bytes=True)\n", - " size = sys.getsizeof(serialized)\n", - " blob_create = CreateBlobStorageEntry.from_blob_storage_entry(obj)\n", - " blob_create.file_size = size\n", - "\n", - " blob_deposit_object = new_client.services.blob_storage.allocate_for_user(\n", - " blob_create, uploaded_by\n", - " )\n", - " if isinstance(blob_deposit_object, sy.SyftError):\n", - " return blob_deposit_object\n", - " return blob_deposit_object.write(BytesIO(serialized))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "for blob_entry in blob_entries:\n", - " res = migrate_blob_entry_data(client, temp_client, blob_entry, BlobStorageEntry)\n", - " display(res)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [ - "client.services.blob_storage.get_all()" - ] - }, - { - "cell_type": "markdown", - "id": "28", - "metadata": {}, - "source": [ - "## Actions and ActionObjects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29", - "metadata": {}, - "outputs": [], - "source": [ - "migration_action_dict = client.services.migration.get_migration_actionobjects(\n", - " get_all=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": {}, - "outputs": [], - "source": [ - "ao = migration_action_dict[list(migration_action_dict.keys())[0]][0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31", - "metadata": {}, - "outputs": [], - "source": [ - "ao.syft_action_data_cache" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], 
- "source": [ - "node.python_node.action_store.data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [ - "client.jobs[0].result.id.id" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34", - "metadata": {}, - "outputs": [], - "source": [ - "node.python_node.action_store.data[sy.UID(\"106b561961c74a46afc63c5c73c24212\")].__dict__" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35", - "metadata": {}, - "outputs": [], - "source": [ - "# this wont work in the cases where the context is actually used, but since this you would need custom logic here anyway\n", - "# it doesnt matter\n", - "context = Context()\n", - "migrated_actionobjects = []\n", - "for klass, objects in migration_action_dict.items():\n", - " for obj in objects:\n", - " # custom migration logic here\n", - " migrated_actionobject = obj.migrate_to(klass.__version__, context)\n", - " migrated_actionobjects.append(migrated_actionobject)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36", - "metadata": {}, - "outputs": [], - "source": [ - "print(migrated_actionobjects)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37", - "metadata": {}, - "outputs": [], - "source": [ - "res = temp_client.services.migration.update_migrated_actionobjects(\n", - " migrated_actionobjects\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "38", - "metadata": {}, - "outputs": [], - "source": [ - "res" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(res, sy.SyftSuccess)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40", - "metadata": {}, - "outputs": [], - "source": [ - "for uid in temp_node.python_node.action_store.data:\n", - " ao = temp_client.services.action.get(uid)\n", - " ao.reload_cache()\n", - " print(ao.syft_action_data_cache)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41", - "metadata": {}, - "outputs": [], - "source": [ - "for uid in node.python_node.action_store.data:\n", - " ao = client.services.action.get(uid)\n", - " ao.reload_cache()\n", - " print(ao.syft_action_data_cache)" - ] - }, - { - "cell_type": "markdown", - "id": "42", - "metadata": {}, - "source": [ - "## Store metadata\n", - "\n", - "- Permissions\n", - "- StoragePermissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43", - "metadata": {}, - "outputs": [], - "source": [ - "store_metadata = client.services.migration.get_all_store_metadata()\n", - "store_metadata" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44", - "metadata": {}, - "outputs": [], - "source": [ - "for k, v in store_metadata.items():\n", - " if len(v.permissions):\n", - " print(\n", - " k, len(v.permissions), len(v.permissions) == len(migration_dict.get(k, []))\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45", - "metadata": {}, - "outputs": [], - "source": [ - "# Test update method with a temp node\n", - "# After update, all metadata should match between the nodes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "46", - "metadata": {}, - "outputs": [], - "source": [ - "temp_client.services.migration.update_store_metadata(store_metadata)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47", - "metadata": {}, - "outputs": 
[], - "source": [ - "for cname, real_partition in node.python_node.document_store.partitions.items():\n", - " temp_partition = temp_node.python_node.document_store.partitions[cname]\n", - "\n", - " temp_perms = dict(temp_partition.permissions.items())\n", - " real_perms = dict(real_partition.permissions.items())\n", - "\n", - " for k, temp_v in temp_perms.items():\n", - " if k not in real_perms:\n", - " continue\n", - " real_v = real_perms[k]\n", - " assert real_v.issubset(temp_v)\n", - "\n", - " temp_storage = dict(temp_partition.storage_permissions.items())\n", - " real_storage = dict(real_partition.storage_permissions.items())\n", - " for k, temp_v in temp_storage.items():\n", - " if k not in real_storage:\n", - " continue\n", - " real_v = real_storage[k]\n", - " assert real_v.issubset(temp_v)\n", - "\n", - "# Action store\n", - "real_partition = node.python_node.action_store\n", - "temp_partition = temp_node.python_node.action_store\n", - "temp_perms = dict(temp_partition.permissions.items())\n", - "real_perms = dict(real_partition.permissions.items())\n", - "\n", - "# Only look at migrated items\n", - "temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", - "for k, temp_v in temp_perms.items():\n", - " if k not in real_perms:\n", - " continue\n", - " real_v = real_perms[k]\n", - " assert real_v.issubset(temp_v)\n", - "\n", - "temp_storage = dict(temp_partition.storage_permissions.items())\n", - "real_storage = dict(real_partition.storage_permissions.items())\n", - "for k, temp_v in temp_storage.items():\n", - " if k not in real_storage:\n", - " continue\n", - " real_v = real_storage[k]\n", - " assert real_v.issubset(temp_v)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 47e2b84d8b1..30b6fad8ae0 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -418,15 +418,16 @@ def load_migration_data(self, path: str | Path) -> SyftSuccess | SyftError: return migration_data migration_data._set_obj_location_(self.id, self.verify_key) - # if self.id != migration_data.node_uid: - # return SyftError( - # message=f"Migration data is not for this node. Expected {self.id}, got {migration_data.node_uid}" - # ) - - # if migration_data.root_verify_key != self.verify_key: - # return SyftError( - # message="Root verify key in migration data does not match this client's verify key" - # ) + if self.id != migration_data.node_uid: + return SyftError( + message=f"This Migration data is not for this node. 
Expected node id {self.id}, " + f"got {migration_data.node_uid}" + ) + + if migration_data.signing_key.verify_key != self.verify_key: + return SyftError( + message="Root verify key in migration data does not match this client's verify key" + ) res = migration_data.migrate_and_upload_blobs() if isinstance(res, SyftError): diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index d46dd1b0c64..75889600500 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -25,7 +25,6 @@ from nacl.signing import SigningKey from result import Err from result import Result -from typing_extensions import Self # relative from .. import __version__ @@ -645,7 +644,7 @@ def remove_consumer_with_id(self, syft_worker_id: UID) -> None: @classmethod def named( - cls, + cls: type[Node], *, # Trasterisk name: str, processes: int = 0, @@ -663,7 +662,7 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, - ) -> Self: + ) -> Node: uid = UID.with_seed(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() key = SyftSigningKey(signing_key=SigningKey(name_hash)) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index c4541e43bd3..ff6e94eafab 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -55,7 +55,7 @@ "3": { "version": 3, "hash": "2d5f6e79f074f75b5cfc2357eac7cf635b8f083421009a513240b4dbbd5a0fc1", - "action": "remove" + "action": "add" }, "5": { "version": 5, @@ -437,7 +437,7 @@ "MigrationData": { "1": { "version": 1, - "hash": "636a430748dfd2bc3860772795e7071db39d7306904c2d82531f884fcd79f3e9", + "hash": "c5be6bb4f34b04f814e15468d5231e47540c5b7d2ea0f2770e6cd332f61173c7", "action": "add" } } diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index e10a81d052c..69eb482b248 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -7,7 +7,6 @@ from result import Err from result import Ok from result import Result -from syft.service.user.user_service import UserService # relative from ...serde.serializable import serializable @@ -606,9 +605,6 @@ def get_migration_data( if store_metadata_result.is_err(): return SyftError(message=store_metadata_result.err()) store_metadata = store_metadata_result.ok() - root_verify_key = context.node.get_service_method( - UserService.admin_verify_key - )() return MigrationData( node_uid=context.node.id, diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 37aee4fc542..79a4fb2aebd 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -1,29 +1,37 @@ -from pathlib import Path +# stdlib from io import BytesIO +from pathlib import Path import sys from typing import Any -from typing_extensions import Self + +# third party from result import Result -from syft.serde.deserialize import _deserialize -from syft.serde.serialize import _serialize -from syft.service.response import SyftError, SyftSuccess -from syft.types.blob_storage import BlobStorageEntry, CreateBlobStorageEntry -from syft.util.util import prompt_warning_message 
+from typing_extensions import Self +import yaml # relative +from ...node.credentials import SyftSigningKey from ...node.credentials import SyftVerifyKey +from ...serde.deserialize import _deserialize from ...serde.serializable import serializable +from ...serde.serialize import _serialize from ...store.document_store import BaseStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...types.syft_object import SYFT_OBJECT_VERSION_1, Context +from ...types.blob_storage import BlobStorageEntry +from ...types.blob_storage import CreateBlobStorageEntry +from ...types.syft_object import Context +from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry from ...types.uid import UID +from ...util.util import prompt_warning_message from ..action.action_permissions import ActionObjectPermission +from ..response import SyftError +from ..response import SyftSuccess @serializable() @@ -108,7 +116,7 @@ class MigrationData(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 node_uid: UID - root_verify_key: SyftVerifyKey + signing_key: SyftSigningKey store_objects: dict[type[SyftObject], list[SyftObject]] metadata: dict[type[SyftObject], StoreMetadata] action_objects: dict[type[SyftObject], list[SyftObject]] @@ -136,6 +144,19 @@ def includes_blobs(self) -> bool: blob_ids = [obj.id for obj in self.blob_storage_objects] return set(self.blobs.keys()) == set(blob_ids) + def make_migration_config(self) -> dict[str, Any]: + node_uid = self.node_uid.to_string() + node_private_key = str(self.signing_key) + migration_config = { + "node": { + "env": [ + {"name": "NODE_UID", "value": node_uid}, + {"name": "NODE_PRIVATE_KEY", "value": node_private_key}, + ] + } + } + return migration_config + @classmethod def from_file(self, path: str | Path) -> Self | SyftError: path = Path(path) @@ -147,7 +168,7 @@ def from_file(self, path: str | Path) -> Self | SyftError: return res - def save(self, path: str | Path) -> SyftSuccess | SyftError: + def save(self, path: str | Path, yaml_path: str | Path) -> SyftSuccess | SyftError: if not self.includes_blobs: proceed = prompt_warning_message( "You are saving migration data without blob storage data. " @@ -162,6 +183,11 @@ def save(self, path: str | Path) -> SyftSuccess | SyftError: with open(path, "wb") as f: f.write(_serialize(self, to_bytes=True)) + yaml_path = Path(yaml_path) + migration_config = self.make_migration_config() + with open(yaml_path, "w") as f: + yaml.dump(migration_config, f) + return SyftSuccess(message=f"Migration data saved to {str(path)}.") def download_blobs(self) -> None | SyftError: @@ -170,6 +196,7 @@ def download_blobs(self) -> None | SyftError: if isinstance(blob, SyftError): return blob self.blobs[obj.id] = blob + return None def download_blob(self, obj_id: str) -> Any | SyftError: api = self._get_api() @@ -215,7 +242,7 @@ def copy_without_blobs(self) -> "MigrationData": # This is required for sending the MigrationData to the backend. 
copy_data = self.__class__( node_uid=self.node_uid, - root_verify_key=self.root_verify_key, + signing_key=self.signing_key, store_objects=self.store_objects.copy(), metadata=self.metadata.copy(), action_objects=self.action_objects.copy(), diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index 2be1dc6b5d5..7f09ef59a46 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -16,11 +16,11 @@ from ...types.syft_metaclass import Empty from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SYFT_OBJECT_VERSION_5 from ...types.syft_object import SYFT_OBJECT_VERSION_6 from ...types.syft_object import SyftObject -from ...types.transforms import drop from ...types.transforms import make_set_default from ...types.uid import UID from ...util import options @@ -161,7 +161,55 @@ class NodeSettingsV5(SyftObject): ) -@migrate(NodeSettingsV5, NodeSettings) +@serializable() +class NodeSettingsV2(SyftObject): + __canonical_name__ = "NodeSettings" + __version__ = SYFT_OBJECT_VERSION_3 + __repr_attrs__ = [ + "name", + "organization", + "deployed_on", + "signup_enabled", + "admin_email", + ] + + id: UID + name: str = "Node" + deployed_on: str + organization: str = "OpenMined" + verify_key: SyftVerifyKey + on_board: bool = True + description: str = "Text" + node_type: NodeType = NodeType.DOMAIN + signup_enabled: bool + admin_email: str + node_side_type: NodeSideType = NodeSideType.HIGH_SIDE + show_warnings: bool + + +# @migrate(NodeSettingsV3, NodeSettingsV5) +# def upgrade_node_settings() -> list[Callable]: +# return [ +# make_set_default("association_request_auto_approval", False), +# make_set_default( +# "default_worker_pool", +# get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME), +# ), +# ] + +# @migrate(NodeSettingsV3, NodeSettings) +# def upgrade_node_settings_v3_to_v6() -> list[Callable]: +# return [ +# make_set_default("association_request_auto_approval", False), +# make_set_default( +# "default_worker_pool", +# get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME), +# ), +# make_set_default("eager_execution_enabled", False), +# ] + + +@migrate(NodeSettingsV2, NodeSettings) def upgrade_node_settings() -> list[Callable]: return [ make_set_default("association_request_auto_approval", False), @@ -171,22 +219,3 @@ def upgrade_node_settings() -> list[Callable]: ), make_set_default("eager_execution_enabled", False), ] - - -@migrate(NodeSettings, NodeSettingsV5) -def downgrade_node_settings() -> list[Callable]: - return [drop(["eager_execution_enabled"])] - - -@migrate(NodeSettingsUpdateV4, NodeSettingsUpdate) -def upgrade_node_settings_update() -> list[Callable]: - return [] - - -@migrate(NodeSettingsUpdate, NodeSettingsUpdateV4) -def downgrade_node_settings_update() -> list[Callable]: - return [ - drop(["association_request_auto_approval"]), - drop(["default_worker_pool"]), - drop(["eager_execution_enabled"]), - ] From 0479e1897956909da66fe6581bf7ec43f9516947 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 3 Jul 2024 16:40:27 +0200 Subject: [PATCH 233/309] remove debug flag --- packages/grid/backend/grid/start.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 
52b48537812..4b3d5de4cf2 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -26,8 +26,8 @@ then fi export CREDENTIALS_PATH=${CREDENTIALS_PATH:-$HOME/data/creds/credentials.json} -export NODE_PRIVATE_KEY=$(python $APPDIR/grid/bootstrap.py --private_key --debug) -export NODE_UID=$(python $APPDIR/grid/bootstrap.py --uid --debug) +export NODE_PRIVATE_KEY=$(python $APPDIR/grid/bootstrap.py --private_key) +export NODE_UID=$(python $APPDIR/grid/bootstrap.py --uid) export NODE_TYPE=$NODE_TYPE echo "NODE_UID=$NODE_UID" From b7d02e8a69354a39280a1c29b41361bdf49d6357 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 3 Jul 2024 23:15:50 +0530 Subject: [PATCH 234/309] remove flakyness by waiting for worker to be deleted in worker delete test --- packages/syft/tests/syft/syft_worker_deletion_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py index d3118bbea47..c1eb9c4028d 100644 --- a/packages/syft/tests/syft/syft_worker_deletion_test.py +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -103,4 +103,12 @@ def compute_mean(data): else: assert job.status == JobStatus.COMPLETED + start = time.time() + while True: + res = client.worker.get(syft_worker_id) + if isinstance(res, SyftError): + break + if time.time() - start > 5: + raise TimeoutError("Worker did not get removed from stash.") + assert len(client.worker.get_all()) == 0 From a79abc1cb625bcd912ba7e00409d52aca1c95773 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Wed, 3 Jul 2024 16:37:47 -0300 Subject: [PATCH 235/309] Add info message when launching node in Python mode --- packages/syft/src/syft/orchestra.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 08672657762..c6c70dbab42 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -13,6 +13,9 @@ import sys from typing import Any +# third party +from IPython.display import display + # relative from .abstract_node import NodeSideType from .abstract_node import NodeType @@ -23,6 +26,7 @@ from .node.server import serve_node from .protocol.data_protocol import stage_protocol_changes from .service.response import SyftError +from .service.response import SyftInfo from .util.util import get_random_available_port logger = logging.getLogger(__name__) @@ -302,7 +306,7 @@ def launch( ) if deployment_type_enum == DeploymentType.PYTHON: - return deploy_to_python( + node_handle = deploy_to_python( node_type_enum=node_type_enum, deployment_type_enum=deployment_type_enum, port=port, @@ -322,6 +326,13 @@ def launch( association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, ) + display( + SyftInfo( + message=f"You have created a development node at http://{host}:{node_handle.port}.\ + It is intended only for local use." 
+ ) + ) + return node_handle elif deployment_type_enum == DeploymentType.REMOTE: return deploy_to_remote( node_type_enum=node_type_enum, From 9147124a0c3c2968a2d17401371617e4be663220 Mon Sep 17 00:00:00 2001 From: Julian Cardonnet Date: Wed, 3 Jul 2024 17:05:21 -0300 Subject: [PATCH 236/309] Fix info message when launching node in python mode --- packages/syft/src/syft/orchestra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index ef460be3fd1..426b6b2d370 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -332,8 +332,8 @@ def launch( ) display( SyftInfo( - message=f"You have created a development node at http://{host}:{node_handle.port}.\ - It is intended only for local use." + message=f"You have launched a development node at http://{host}:{node_handle.port}." + + "It is intended only for local use." ) ) return node_handle From 239db5d46c1b4588a6a1b80097d1c4aba3222df1 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 10:27:23 +0800 Subject: [PATCH 237/309] Test worker deletion with more node settings dev_movde True/False and thread_workers True/False --- .../tests/syft/syft_worker_deletion_test.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/packages/syft/tests/syft/syft_worker_deletion_test.py index c1eb9c4028d..42c23a121c9 100644 --- a/packages/syft/tests/syft/syft_worker_deletion_test.py +++ b/packages/syft/tests/syft/syft_worker_deletion_test.py @@ -1,5 +1,7 @@ # stdlib from collections.abc import Generator +from collections.abc import Iterable +from itertools import product from secrets import token_hex import time from typing import Any @@ -41,7 +43,19 @@ def node(node_args: dict[str, Any]) -> Generator[NodeHandle, None, None]: _node.land() -@pytest.mark.parametrize("node_args", [{"n_consumers": 1}]) +def node_args_combinations(**kwargs: Iterable) -> Iterable[dict[str, Any]]: + args = ([(k, v) for v in vs] for k, vs in kwargs.items()) + return (dict(kvs) for kvs in product(*args)) + + +NODE_ARGS_TEST_CASES = node_args_combinations( + n_consumers=[1], + dev_mode=[True, False], + thread_workers=[True, False], +) + + +@pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: client = node.login(email="info@openmined.org", password="changethis") @@ -61,7 +75,7 @@ def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: raise TimeoutError("Worker did not get removed from stash.") -@pytest.mark.parametrize("node_args", [{"n_consumers": 1}]) +@pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_worker(node: NodeHandle, force: bool) -> None: client = node.login(email="info@openmined.org", password="changethis") From 0ff4c8f65d28d9393e9ddd3692b295956fca8aed Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 10:28:28 +0800 Subject: [PATCH 238/309] Move worker deletion test to integration tests --- .../syft => tests/integration/local}/syft_worker_deletion_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {packages/syft/tests/syft => tests/integration/local}/syft_worker_deletion_test.py (100%) diff --git a/packages/syft/tests/syft/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py 
similarity index 100% rename from packages/syft/tests/syft/syft_worker_deletion_test.py rename to tests/integration/local/syft_worker_deletion_test.py From dd0a253a968d973b71cea5993bc9486f1ae430ca Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 10:35:17 +0800 Subject: [PATCH 239/309] Reuse the fixture in multiple tests return a list instead of a generator so it persists and can be reused --- tests/integration/local/syft_worker_deletion_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/local/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py index 42c23a121c9..115009f5219 100644 --- a/tests/integration/local/syft_worker_deletion_test.py +++ b/tests/integration/local/syft_worker_deletion_test.py @@ -43,9 +43,9 @@ def node(node_args: dict[str, Any]) -> Generator[NodeHandle, None, None]: _node.land() -def node_args_combinations(**kwargs: Iterable) -> Iterable[dict[str, Any]]: +def node_args_combinations(**kwargs: Iterable) -> list[dict[str, Any]]: args = ([(k, v) for v in vs] for k, vs in kwargs.items()) - return (dict(kvs) for kvs in product(*args)) + return [dict(kvs) for kvs in product(*args)] NODE_ARGS_TEST_CASES = node_args_combinations( From 6d914aada0d63af34d77adf22685164e4fb661e3 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 10:36:18 +0800 Subject: [PATCH 240/309] Only test with dev_mode=True for now --- tests/integration/local/syft_worker_deletion_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/local/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py index 115009f5219..b159c328bf4 100644 --- a/tests/integration/local/syft_worker_deletion_test.py +++ b/tests/integration/local/syft_worker_deletion_test.py @@ -50,7 +50,7 @@ def node_args_combinations(**kwargs: Iterable) -> list[dict[str, Any]]: NODE_ARGS_TEST_CASES = node_args_combinations( n_consumers=[1], - dev_mode=[True, False], + # dev_mode=[True, False], thread_workers=[True, False], ) From 391ffc3cbcc68a9af2fec20fefab68ae56551d4c Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 11:00:43 +0800 Subject: [PATCH 241/309] Make the tests run in CI --- tests/integration/local/syft_worker_deletion_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/local/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py index b159c328bf4..4d8e971df2d 100644 --- a/tests/integration/local/syft_worker_deletion_test.py +++ b/tests/integration/local/syft_worker_deletion_test.py @@ -55,6 +55,7 @@ def node_args_combinations(**kwargs: Iterable) -> list[dict[str, Any]]: ) +@pytest.mark.local_node @pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: @@ -75,6 +76,7 @@ def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: raise TimeoutError("Worker did not get removed from stash.") +@pytest.mark.local_node @pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_worker(node: NodeHandle, force: bool) -> None: From 4ea8573f8f9fcd1822d94a3c10339aa67afcf6aa Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 11:14:15 +0800 Subject: [PATCH 242/309] Use a module level pytest mark instead of manually adding the mark to every test in the file --- tests/integration/local/syft_worker_deletion_test.py | 5 
+++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration/local/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py index 4d8e971df2d..fa94bc8b3b0 100644 --- a/tests/integration/local/syft_worker_deletion_test.py +++ b/tests/integration/local/syft_worker_deletion_test.py @@ -16,6 +16,9 @@ from syft.service.job.job_stash import JobStatus from syft.service.response import SyftError +# equivalent to adding this mark to every test in this file +pytestmark = pytest.mark.local_node + @pytest.fixture() def node_args() -> dict[str, Any]: @@ -55,7 +58,6 @@ def node_args_combinations(**kwargs: Iterable) -> list[dict[str, Any]]: ) -@pytest.mark.local_node @pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: @@ -76,7 +78,6 @@ def test_delete_idle_worker(node: NodeHandle, force: bool) -> None: raise TimeoutError("Worker did not get removed from stash.") -@pytest.mark.local_node @pytest.mark.parametrize("node_args", NODE_ARGS_TEST_CASES) @pytest.mark.parametrize("force", [True, False]) def test_delete_worker(node: NodeHandle, force: bool) -> None: From 2b27c5d69666cefeb5701130fce94d46168919c0 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 4 Jul 2024 11:14:06 +0530 Subject: [PATCH 243/309] keep SyftWorker version 2 and define migrations b/w version 2 and 3 --- .../src/syft/protocol/protocol_version.json | 7 +++ .../src/syft/service/worker/worker_pool.py | 49 ++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index df9c10d14bc..b2fbb6784b9 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -416,6 +416,13 @@ "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", "action": "add" } + }, + "SyftWorker": { + "3": { + "version": 3, + "hash": "e124f56ddf4565df2be056553eecd15de7c80bd5f5fd0d06e8ff7815bb05563a", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index d7e6c63c560..c0185aa1011 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from enum import Enum from typing import Any from typing import cast @@ -13,10 +14,13 @@ from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel from ...types.datetime import DateTime +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.syft_object import short_uid +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import UID from ...util import options from ...util.colors import SURFACE @@ -47,13 +51,42 @@ class WorkerHealth(Enum): UNHEALTHY = "❌" +@serializable() +class SyftWorkerV2(SyftObject): + __canonical_name__ = "SyftWorker" + __version__ = SYFT_OBJECT_VERSION_2 + + __attr_unique__ = ["name"] + __attr_searchable__ = ["name", "container_id"] + __repr_attrs__ = [ + "name", + "container_id", + "image", + "status", + "healthcheck", + "worker_pool_name", + "created_at", + ] + + 
id: UID + name: str + container_id: str | None = None + created_at: DateTime = DateTime.now() + healthcheck: WorkerHealth | None = None + status: WorkerStatus + image: SyftWorkerImage | None = None + worker_pool_name: str + consumer_state: ConsumerState = ConsumerState.DETACHED + job_id: UID | None = None + + @serializable() class SyftWorker(SyftObject): __canonical_name__ = "SyftWorker" __version__ = SYFT_OBJECT_VERSION_3 __attr_unique__ = ["name"] - __attr_searchable__ = ["name", "container_id"] + __attr_searchable__ = ["name", "container_id", "to_be_deleted"] __repr_attrs__ = [ "name", "container_id", @@ -315,3 +348,17 @@ def _get_worker_container_status( container_status, SyftError(message=f"Unknown container status: {container_status}"), ) + + +@migrate(SyftWorkerV2, SyftWorker) +def upgrade_syft_worker() -> list[Callable]: + return [ + make_set_default("to_be_deleted", False), + ] + + +@migrate(SyftWorker, SyftWorkerV2) +def downgrade_syft_worker() -> list[Callable]: + return [ + drop(["to_be_deleted"]), + ] From 504498baf2cd2775af1cde030ccad5d9910eca83 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 09:20:41 +0200 Subject: [PATCH 244/309] disable migration by default --- packages/syft/src/syft/orchestra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 73ebf756cb0..2879a6a571e 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -170,7 +170,7 @@ def deploy_to_python( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, - migrate: bool = True, + migrate: bool = False, ) -> NodeHandle: worker_classes = { NodeType.DOMAIN: Domain, @@ -293,7 +293,7 @@ def launch( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, - migrate: bool = True, + migrate: bool = False, ) -> NodeHandle: if dev_mode is True: thread_workers = True From 93f009c2ae744d41406c53a31e3cdcbb4619a047 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 16:04:27 +0800 Subject: [PATCH 245/309] Fix closing ZMQProducer/Consumer Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/queue/zmq_queue.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 74d64ad9966..de409dda676 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -6,6 +6,7 @@ import socketserver import sys import threading +from threading import Event import time from time import sleep from typing import Any @@ -51,7 +52,7 @@ HEARTBEAT_INTERVAL_SEC = 2 # Thread join timeout (in seconds) -THREAD_TIMEOUT_SEC = 5 +THREAD_TIMEOUT_SEC = 30 # Max duration (in ms) to wait for ZMQ poller to return ZMQ_POLLER_TIMEOUT_MSEC = 1000 @@ -163,7 +164,7 @@ def __init__( self.worker_stash = worker_stash self.queue_name = queue_name self.auth_context = context - self._stop = threading.Event() + self._stop = Event() self.post_init() @property @@ -189,24 +190,33 @@ def post_init(self) -> None: def close(self) -> None: self._stop.set() - try: - self.poll_workers.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister poller.", exc_info=e) - finally: if self.thread: self.thread.join(THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQProducer message 
sending thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) self.thread = None if self.producer_thread: self.producer_thread.join(THREAD_TIMEOUT_SEC) + if self.producer_thread.is_alive(): + logger.error( + f"ZMQProducer queue thread join timed out during closing. " + f"Queue name {self.queue_name}, " + ) self.producer_thread = None + self.poll_workers.unregister(self.socket) + except Exception as e: + logger.exception("Failed to unregister poller.", exc_info=e) + finally: self.socket.close() self.context.destroy() - self._stop.clear() + # self._stop.clear() @property def action_service(self) -> AbstractService: @@ -675,7 +685,7 @@ def __init__( self.socket = None self.verbose = verbose self.id = UID().short() - self._stop = threading.Event() + self._stop = Event() self.syft_worker_id = syft_worker_id self.worker_stash = worker_stash self.post_init() @@ -712,16 +722,22 @@ def close(self) -> None: self.disconnect_from_producer() self._stop.set() try: - self.poller.unregister(self.socket) - except Exception as e: - logger.exception("Failed to unregister worker.", exc_info=e) - finally: if self.thread is not None: self.thread.join(timeout=THREAD_TIMEOUT_SEC) + if self.thread.is_alive(): + logger.error( + f"ZMQConsumer thread join timed out during closing. " + f"SyftWorker id {self.syft_worker_id}, " + f"service name {self.service_name}." + ) self.thread = None + self.poller.unregister(self.socket) + except Exception as e: + logger.error("Failed to unregister worker.", exc_info=e) + finally: self.socket.close() self.context.destroy() - self._stop.clear() + # self._stop.clear() def send_to_producer( self, @@ -814,7 +830,8 @@ def _run(self) -> None: self.reconnect_to_producer() self.set_producer_alive() - self.send_heartbeat() + if not self._stop.is_set(): + self.send_heartbeat() except zmq.ZMQError as e: if e.errno == zmq.ETERM: From 922787f1f7a067b2483a76babc7068dd91a427c1 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 4 Jul 2024 16:05:52 +0800 Subject: [PATCH 246/309] Expand worker deletion test conditions --- .../local/syft_worker_deletion_test.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/integration/local/syft_worker_deletion_test.py b/tests/integration/local/syft_worker_deletion_test.py index fa94bc8b3b0..0094caf1b2c 100644 --- a/tests/integration/local/syft_worker_deletion_test.py +++ b/tests/integration/local/syft_worker_deletion_test.py @@ -46,14 +46,30 @@ def node(node_args: dict[str, Any]) -> Generator[NodeHandle, None, None]: _node.land() -def node_args_combinations(**kwargs: Iterable) -> list[dict[str, Any]]: +def matrix( + *, + excludes_: Iterable[dict[str, Any]] | None = None, + **kwargs: Iterable, +) -> list[dict[str, Any]]: args = ([(k, v) for v in vs] for k, vs in kwargs.items()) - return [dict(kvs) for kvs in product(*args)] + args = product(*args) + if excludes_ is None: + excludes_ = [] + excludes_ = [kv.items() for kv in excludes_] -NODE_ARGS_TEST_CASES = node_args_combinations( + args = ( + arg + for arg in args + if not any(all(kv in arg for kv in kvs) for kvs in excludes_) + ) + + return [dict(kvs) for kvs in args] + + +NODE_ARGS_TEST_CASES = matrix( n_consumers=[1], - # dev_mode=[True, False], + dev_mode=[True, False], thread_workers=[True, False], ) From 19dc0cf63e451bf483eacf32a76a2abbf61fcdd2 Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Thu, 4 Jul 2024 14:33:19 +0530 Subject: [PATCH 247/309] Flush debugger print statements --- 
packages/syft/src/syft/node/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 13ccb4a538e..4628cae6005 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -105,11 +105,12 @@ def attach_debugger() -> None: print( "\nStarting the server with the Python Debugger enabled (`debug=True`).\n" 'To attach the debugger, open the command palette in VSCode and select "Debug: Start Debugging (F5)".\n' - f"Then, enter `{debug_port}` in the port field and press Enter.\n" + f"Then, enter `{debug_port}` in the port field and press Enter.\n", + flush=True, ) - print(f"Waiting for debugger to attach on port `{debug_port}`...") + print(f"Waiting for debugger to attach on port `{debug_port}`...", flush=True) debugpy.wait_for_client() # blocks execution until a remote debugger is attached - print("Debugger attached") + print("Debugger attached", flush=True) def run_uvicorn( From 58e8dc4052f465ec2924804675bb6de3136d2975 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 12:28:49 +0200 Subject: [PATCH 248/309] serde + permissions bugfixes --- packages/syft/src/syft/serde/recursive.py | 7 +++++++ .../src/syft/service/action/action_permissions.py | 2 ++ .../syft/src/syft/service/action/action_service.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 38127fb12b9..8424becd852 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -172,6 +172,13 @@ def recursive_serde_register( for alias in alias_fqn: TYPE_BANK[alias] = serde_attributes + # TODO Refactor alias, required for typing.Any in python 3.12, + alias_canonical_name = alias + alias_version = 1 + SyftObjectRegistry.register_cls( + alias_canonical_name, alias_version, serde_attributes + ) + def chunk_bytes( field_obj: Any, diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 6131dcf5d08..1a71b75b9ae 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -34,6 +34,8 @@ def __init__( permission: ActionPermission, credentials: SyftVerifyKey | None = None, ): + if not isinstance(uid, UID): + raise ValueError(f"uid must be of type UID not {type(uid)}") if credentials is None: if permission not in COMPOUND_ACTION_PERMISSION: raise Exception(f"{permission} not in {COMPOUND_ACTION_PERMISSION}") diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index da273e2c12b..38cdd3c5b5a 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -182,11 +182,14 @@ def _set( if isinstance(action_object, TwinObject): # give read permission to the mock blob_id = action_object.mock_obj.syft_blob_storage_entry_id - permission = ActionObjectPermission(blob_id, ActionPermission.ALL_READ) - blob_storage_service: AbstractService = context.node.get_service( - BlobStorageService - ) - blob_storage_service.stash.add_permission(permission) + if blob_id is not None: + permission = ActionObjectPermission( + blob_id, ActionPermission.ALL_READ + ) + blob_storage_service: AbstractService = context.node.get_service( + BlobStorageService 
+ ) + blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private else: From c1b5c2ab6967c713b96979c684bf898e8fbd3ed7 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 13:43:23 +0200 Subject: [PATCH 249/309] migration fixes --- .../0-prepare-migration-data.ipynb | 6 +- ...-dump-node.ipynb => 1c-dump-to-file.ipynb} | 66 +++------ .../2a-migrate-from-file.ipynb | 134 ++++++++++++++++++ packages/grid/backend/grid/bootstrap.py | 6 +- packages/grid/helm/examples/dev/migrated.yaml | 4 +- .../src/syft/store/blob_storage/__init__.py | 2 +- 6 files changed, 163 insertions(+), 55 deletions(-) rename notebooks/tutorials/version-upgrades/{1c-dump-node.ipynb => 1c-dump-to-file.ipynb} (76%) create mode 100644 notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 2b5026016df..57d8fbc2ddb 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -129,7 +129,7 @@ " sy.Asset(\n", " name=\"numpy-data\",\n", " mock=np.array([10, 11, 12, 13, 14]),\n", - " data=np.array([15, 16, 17, 18, 19]),\n", + " data=np.array([[15, 16, 17, 18, 19] for _ in range(100_000)]),\n", " mock_is_real=True,\n", " )\n", " ],\n", @@ -202,7 +202,7 @@ "metadata": {}, "outputs": [], "source": [ - "res.get()" + "res.get().shape" ] }, { @@ -251,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1c-dump-node.ipynb b/notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb similarity index 76% rename from notebooks/tutorials/version-upgrades/1c-dump-node.ipynb rename to notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb index 0a09e6680ea..d72775ba24e 100644 --- a/notebooks/tutorials/version-upgrades/1c-dump-node.ipynb +++ b/notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb @@ -28,7 +28,9 @@ " n_consumers=2,\n", " create_producer=True,\n", " migrate=False,\n", - ")" + ")\n", + "\n", + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" ] }, { @@ -38,7 +40,7 @@ "metadata": {}, "outputs": [], "source": [ - "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + "# client = sy.login(email=\"info@openmined.org\", password=\"changethis\", port=8080)" ] }, { @@ -48,8 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "migration_data = client.get_migration_data(include_blobs=True)\n", - "migration_data" + "client.jobs[0].result.get()" ] }, { @@ -59,24 +60,18 @@ "metadata": {}, "outputs": [], "source": [ - "blob_path = Path(\"./my_migration.blob\")\n", - "yaml_path = Path(\"my_migration.yaml\")\n", - "\n", - "blob_path.unlink(missing_ok=True)\n", - "yaml_path.unlink(missing_ok=True)\n", - "\n", - "migration_data.save(blob_path, yaml_path=yaml_path)\n", - "\n", - "assert blob_path.exists()\n", - "assert yaml_path.exists()" + "migration_data = client.get_migration_data(include_blobs=True)\n", + "migration_data" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "5", "metadata": {}, + "outputs": [], "source": [ - "# Client side migrations" + "print(migration_data.blobs)" ] }, { @@ -86,7 +81,11 @@ "metadata": {}, "outputs": [], "source": [ - "node.land()" + 
"blob_path = Path(\"./my_migration.blob\")\n", + "yaml_path = Path(\"my_migration.yaml\")\n", + "\n", + "blob_path.unlink(missing_ok=True)\n", + "yaml_path.unlink(missing_ok=True)" ] }, { @@ -96,17 +95,10 @@ "metadata": {}, "outputs": [], "source": [ - "new_node = sy.orchestra.launch(\n", - " name=\"test_upgradability\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - " reset=True,\n", - ")\n", + "migration_data.save(blob_path, yaml_path=yaml_path)\n", "\n", - "new_client = new_node.login(email=\"info@openmined.org\", password=\"changethis\")" + "assert blob_path.exists()\n", + "assert yaml_path.exists()" ] }, { @@ -115,24 +107,6 @@ "id": "8", "metadata": {}, "outputs": [], - "source": [ - "new_client.load_migration_data(\"my_migration.blob\")" - ] - }, - { - "cell_type": "markdown", - "id": "9", - "metadata": {}, - "source": [ - "## Test" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], "source": [] } ], @@ -152,7 +126,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb new file mode 100644 index 00000000000..ce4443e600a --- /dev/null +++ b/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb @@ -0,0 +1,134 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "# stdlib\n", + "\n", + "# syft absolute\n", + "import syft as sy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "client = sy.login(email=\"info@openmined.org\", password=\"changethis\", port=8080)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "migration_data = client.get_migration_data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "from syft.service.code.user_code import UserCode\n", + "from syft.service.user.user import User\n", + "\n", + "# Check if this is a clean node\n", + "assert len(migration_data.store_objects[User]) == 1\n", + "assert UserCode not in migration_data.store_objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "client.load_migration_data(\"my_migration.blob\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "assert len(client.users.get_all()) == 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "assert len(client.jobs.get_all())\n", + "client.jobs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "assert len(client.requests.get_all())\n", + "client.requests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "new_migration_data = client.get_migration_data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [] + } + 
], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/grid/backend/grid/bootstrap.py b/packages/grid/backend/grid/bootstrap.py index 0da833a3a39..15859c053aa 100644 --- a/packages/grid/backend/grid/bootstrap.py +++ b/packages/grid/backend/grid/bootstrap.py @@ -101,11 +101,11 @@ def validate_private_key(private_key: str | bytes) -> str: def validate_uid(node_uid: str) -> str: try: uid = uuid.UUID(node_uid) - if node_uid == str(uid): - return str(uid) + if node_uid == uid.hex: + return uid.hex except Exception: pass - raise Exception(f"{NODE_PRIVATE_KEY} is invalid") + raise Exception(f"{NODE_UID} is invalid") def get_credential( diff --git a/packages/grid/helm/examples/dev/migrated.yaml b/packages/grid/helm/examples/dev/migrated.yaml index 133ffc83e94..6ee1083340e 100644 --- a/packages/grid/helm/examples/dev/migrated.yaml +++ b/packages/grid/helm/examples/dev/migrated.yaml @@ -1,6 +1,6 @@ node: env: - name: NODE_UID - value: "21519e1e3e664b38a635dc951c293158" + value: null - name: NODE_PRIVATE_KEY - value: "664e78682ff58cc07ae67778678bbc64d1017077f89218acfa28bd55fec2d413" + value: null diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index b87608e48e0..a8e8dac0130 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -275,7 +275,7 @@ def connect(self) -> BlobStorageConnection: class BlobStorageConfig(SyftBaseModel): client_type: type[BlobStorageClient] client_config: BlobStorageClientConfig - min_blob_size: int # in MB + min_blob_size: int = 0 # in MB @migrate(BlobRetrievalByURLV4, BlobRetrievalByURL) From aa103b834fc1aa6e67a07fe45d7828820510f773 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 13:52:53 +0200 Subject: [PATCH 250/309] update notebook --- .../2a-migrate-from-file.ipynb | 188 ++++++++++++++++-- 1 file changed, 174 insertions(+), 14 deletions(-) diff --git a/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb index ce4443e600a..3495c4a3d00 100644 --- a/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb @@ -13,10 +13,18 @@ "import syft as sy" ] }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "# Login" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -26,7 +34,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -36,7 +44,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -49,20 +57,36 @@ "assert UserCode not in migration_data.store_objects" ] }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "# Load migration data" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "6", "metadata": {}, "outputs": [], "source": [ 
"client.load_migration_data(\"my_migration.blob\")" ] }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "# DS login" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -72,39 +96,175 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "9", "metadata": {}, "outputs": [], "source": [ - "assert len(client.jobs.get_all())\n", - "client.jobs" + "client_ds = sy.login(email=\"ds@openmined.org\", password=\"pw\", port=8080)" ] }, { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "10", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", "metadata": {}, "outputs": [], "source": [ - "assert len(client.requests.get_all())\n", - "client.requests" + "# syft absolute\n", + "from syft.client.api import APIRegistry" ] }, { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "12", "metadata": {}, "outputs": [], "source": [ - "new_migration_data = client.get_migration_data()" + "APIRegistry.__api_registry__.keys()" ] }, { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "code = client.code.get_all()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "code.status" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "req1 = client.requests[0]\n", + "req2 = client_ds.requests[0]\n", + "assert req1.status.name == \"APPROVED\" and req2.status.name == \"APPROVED\"\n", + "assert isinstance(req1._repr_html_(), str)\n", + "assert isinstance(req2._repr_html_(), str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "jobs = client_ds.jobs\n", + "assert isinstance(jobs[0]._repr_html_(), str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "ds = client_ds.datasets\n", + "asset = ds[0].assets[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "res = client_ds.code.compute_mean(data=asset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "res.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "\n", + "assert res.shape == (100_000, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "jobs = client_ds.jobs.get_all()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "job = jobs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "job.logs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "logs = job.logs(_print=False)\n", + "assert isinstance(logs, str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", "metadata": {}, "outputs": [], "source": [] From 301cbd479601b43a6fbcd53e46e23b17fe45f3b6 Mon Sep 17 00:00:00 
2001 From: eelcovdw Date: Thu, 4 Jul 2024 14:03:09 +0200 Subject: [PATCH 251/309] min blob storage = 0 --- packages/grid/backend/grid/core/config.py | 2 +- packages/grid/default.env | 2 +- packages/syft/src/syft/node/node.py | 2 +- packages/syft/tests/syft/settings/fixtures.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 619637c06c6..f4a0ccdfbca 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,7 +155,7 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) - MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) + MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 0)) REVERSE_TUNNEL_ENABLED: bool = str_to_bool( os.getenv("REVERSE_TUNNEL_ENABLED", "false") ) diff --git a/packages/grid/default.env b/packages/grid/default.env index 0aae09f1026..af138315f60 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -55,7 +55,7 @@ CREATE_PRODUCER=False N_CONSUMERS=1 INMEMORY_WORKERS=True ASSOCIATION_REQUEST_AUTO_APPROVAL=False -MIN_SIZE_BLOB_STORAGE_MB=16 +MIN_SIZE_BLOB_STORAGE_MB=0 # New Service Flag USE_NEW_SERVICE=False diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 9701be0f023..bda4373748f 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -492,7 +492,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ) config_ = OnDiskBlobStorageConfig( client_config=client_config, - min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16), + min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 0), ) else: config_ = config diff --git a/packages/syft/tests/syft/settings/fixtures.py b/packages/syft/tests/syft/settings/fixtures.py index 65cbf18deca..80b0e2054f6 100644 --- a/packages/syft/tests/syft/settings/fixtures.py +++ b/packages/syft/tests/syft/settings/fixtures.py @@ -66,7 +66,7 @@ def metadata_json(faker) -> NodeMetadataJSON: node_side_type=NodeSideType.LOW_SIDE.value, show_warnings=False, node_type=NodeType.DOMAIN.value, - min_size_blob_storage_mb=16, + min_size_blob_storage_mb=0, ) From f4fc8838333bd03abf6996d2a4e25b0605041745 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 14:12:49 +0200 Subject: [PATCH 252/309] disable blob/local storage tests --- .../syft/blob_storage/blob_storage_test.py | 145 +++++++++--------- 1 file changed, 70 insertions(+), 75 deletions(-) diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 1889004a47e..2d2247fb4a1 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -3,15 +3,10 @@ import random # third party -import numpy as np import pytest # syft absolute import syft as sy -from syft import ActionObject -from syft.client.domain_client import DomainClient -from syft.service.blob_storage.util import can_upload_to_blob_storage -from syft.service.blob_storage.util import min_size_for_blob_storage_upload from syft.service.context import AuthedServiceContext from syft.service.response import SyftSuccess from syft.service.user.user import UserCreate @@ -106,73 +101,73 @@ def test_blob_storage_delete(authed_context, blob_storage): 
blob_storage.read(authed_context, blob_deposit.blob_storage_entry_id) -def test_action_obj_send_save_to_blob_storage(worker): - # this small object should not be saved to blob storage - data_small: np.ndarray = np.array([1, 2, 3]) - action_obj = ActionObject.from_obj(data_small) - assert action_obj.dtype == data_small.dtype - root_client: DomainClient = worker.root_client - action_obj.send(root_client) - assert action_obj.syft_blob_storage_entry_id is None - - # big object that should be saved to blob storage - assert min_size_for_blob_storage_upload(root_client.api.metadata) == 16 - num_elements = 50 * 1024 * 1024 - data_big = np.random.randint(0, 100, size=num_elements) # 4 bytes per int32 - action_obj_2 = ActionObject.from_obj(data_big) - assert can_upload_to_blob_storage(action_obj_2, root_client.api.metadata) - action_obj_2.send(root_client) - assert isinstance(action_obj_2.syft_blob_storage_entry_id, sy.UID) - # get back the object from blob storage to check if it is the same - root_authed_ctx = AuthedServiceContext( - node=worker, credentials=root_client.verify_key - ) - blob_storage = worker.get_service("BlobStorageService") - syft_retrieved_data = blob_storage.read( - root_authed_ctx, action_obj_2.syft_blob_storage_entry_id - ) - assert isinstance(syft_retrieved_data, SyftObjectRetrieval) - assert all(syft_retrieved_data.read() == data_big) - - -def test_upload_dataset_save_to_blob_storage(worker): - root_client: DomainClient = worker.root_client - root_authed_ctx = AuthedServiceContext( - node=worker, credentials=root_client.verify_key - ) - dataset = sy.Dataset( - name="small_dataset", - asset_list=[ - sy.Asset( - name="small_dataset", - data=np.array([1, 2, 3]), - mock=np.array([1, 1, 1]), - ) - ], - ) - root_client.upload_dataset(dataset) - blob_storage = worker.get_service("BlobStorageService") - assert len(blob_storage.get_all_blob_storage_entries(context=root_authed_ctx)) == 0 - - num_elements = 50 * 1024 * 1024 - data_big = np.random.randint(0, 100, size=num_elements) - dataset_big = sy.Dataset( - name="big_dataset", - asset_list=[ - sy.Asset( - name="big_dataset", - data=data_big, - mock=np.array([1, 1, 1]), - ) - ], - ) - root_client.upload_dataset(dataset_big) - # the private data should be saved to the blob storage - blob_entries: list = blob_storage.get_all_blob_storage_entries( - context=root_authed_ctx - ) - assert len(blob_entries) == 1 - data_big_retrieved: SyftObjectRetrieval = blob_storage.read( - context=root_authed_ctx, uid=blob_entries[0].id - ) - assert all(data_big_retrieved.read() == data_big) +# def test_action_obj_send_save_to_blob_storage(worker): +# # this small object should not be saved to blob storage +# data_small: np.ndarray = np.array([1, 2, 3]) +# action_obj = ActionObject.from_obj(data_small) +# assert action_obj.dtype == data_small.dtype +# root_client: DomainClient = worker.root_client +# action_obj.send(root_client) +# assert action_obj.syft_blob_storage_entry_id is None + +# # big object that should be saved to blob storage +# assert min_size_for_blob_storage_upload(root_client.api.metadata) == 16 +# num_elements = 50 * 1024 * 1024 +# data_big = np.random.randint(0, 100, size=num_elements) # 4 bytes per int32 +# action_obj_2 = ActionObject.from_obj(data_big) +# assert can_upload_to_blob_storage(action_obj_2, root_client.api.metadata) +# action_obj_2.send(root_client) +# assert isinstance(action_obj_2.syft_blob_storage_entry_id, sy.UID) +# # get back the object from blob storage to check if it is the same +# root_authed_ctx = 
AuthedServiceContext( +# node=worker, credentials=root_client.verify_key +# ) +# blob_storage = worker.get_service("BlobStorageService") +# syft_retrieved_data = blob_storage.read( +# root_authed_ctx, action_obj_2.syft_blob_storage_entry_id +# ) +# assert isinstance(syft_retrieved_data, SyftObjectRetrieval) +# assert all(syft_retrieved_data.read() == data_big) + + +# def test_upload_dataset_save_to_blob_storage(worker): +# root_client: DomainClient = worker.root_client +# root_authed_ctx = AuthedServiceContext( +# node=worker, credentials=root_client.verify_key +# ) +# dataset = sy.Dataset( +# name="small_dataset", +# asset_list=[ +# sy.Asset( +# name="small_dataset", +# data=np.array([1, 2, 3]), +# mock=np.array([1, 1, 1]), +# ) +# ], +# ) +# root_client.upload_dataset(dataset) +# blob_storage = worker.get_service("BlobStorageService") +# assert len(blob_storage.get_all_blob_storage_entries(context=root_authed_ctx)) == 0 + +# num_elements = 50 * 1024 * 1024 +# data_big = np.random.randint(0, 100, size=num_elements) +# dataset_big = sy.Dataset( +# name="big_dataset", +# asset_list=[ +# sy.Asset( +# name="big_dataset", +# data=data_big, +# mock=np.array([1, 1, 1]), +# ) +# ], +# ) +# root_client.upload_dataset(dataset_big) +# # the private data should be saved to the blob storage +# blob_entries: list = blob_storage.get_all_blob_storage_entries( +# context=root_authed_ctx +# ) +# assert len(blob_entries) == 1 +# data_big_retrieved: SyftObjectRetrieval = blob_storage.read( +# context=root_authed_ctx, uid=blob_entries[0].id +# ) +# assert all(data_big_retrieved.read() == data_big) From e8b45afaf47320fd6dd1eca2d92336a9897cb733 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 15:29:26 +0200 Subject: [PATCH 253/309] Revert "disable blob/local storage tests" This reverts commit f4fc8838333bd03abf6996d2a4e25b0605041745. 
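
Re-enabling these tests restores coverage of the blob-storage size cut-off:
payloads under MIN_SIZE_BLOB_STORAGE_MB (16 MB by default) stay in the
in-memory action store, larger ones get a blob storage entry. A rough sketch
of what the restored tests assert (node name and array sizes are illustrative):

    # Sketch only; mirrors the shape of the re-enabled tests.
    import numpy as np
    import syft as sy
    from syft import ActionObject

    node = sy.orchestra.launch(name="blob-threshold-test", reset=True)
    root_client = node.login(email="info@openmined.org", password="changethis")

    small = ActionObject.from_obj(np.array([1, 2, 3]))  # far below 16 MB
    small.send(root_client)
    assert small.syft_blob_storage_entry_id is None  # kept in memory

    big = ActionObject.from_obj(np.random.randint(0, 100, size=50 * 1024 * 1024))
    big.send(root_client)  # hundreds of MB, well above the threshold
    assert isinstance(big.syft_blob_storage_entry_id, sy.UID)  # stored as a blob
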
--- .../syft/blob_storage/blob_storage_test.py | 145 +++++++++--------- 1 file changed, 75 insertions(+), 70 deletions(-) diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 2d2247fb4a1..1889004a47e 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -3,10 +3,15 @@ import random # third party +import numpy as np import pytest # syft absolute import syft as sy +from syft import ActionObject +from syft.client.domain_client import DomainClient +from syft.service.blob_storage.util import can_upload_to_blob_storage +from syft.service.blob_storage.util import min_size_for_blob_storage_upload from syft.service.context import AuthedServiceContext from syft.service.response import SyftSuccess from syft.service.user.user import UserCreate @@ -101,73 +106,73 @@ def test_blob_storage_delete(authed_context, blob_storage): blob_storage.read(authed_context, blob_deposit.blob_storage_entry_id) -# def test_action_obj_send_save_to_blob_storage(worker): -# # this small object should not be saved to blob storage -# data_small: np.ndarray = np.array([1, 2, 3]) -# action_obj = ActionObject.from_obj(data_small) -# assert action_obj.dtype == data_small.dtype -# root_client: DomainClient = worker.root_client -# action_obj.send(root_client) -# assert action_obj.syft_blob_storage_entry_id is None - -# # big object that should be saved to blob storage -# assert min_size_for_blob_storage_upload(root_client.api.metadata) == 16 -# num_elements = 50 * 1024 * 1024 -# data_big = np.random.randint(0, 100, size=num_elements) # 4 bytes per int32 -# action_obj_2 = ActionObject.from_obj(data_big) -# assert can_upload_to_blob_storage(action_obj_2, root_client.api.metadata) -# action_obj_2.send(root_client) -# assert isinstance(action_obj_2.syft_blob_storage_entry_id, sy.UID) -# # get back the object from blob storage to check if it is the same -# root_authed_ctx = AuthedServiceContext( -# node=worker, credentials=root_client.verify_key -# ) -# blob_storage = worker.get_service("BlobStorageService") -# syft_retrieved_data = blob_storage.read( -# root_authed_ctx, action_obj_2.syft_blob_storage_entry_id -# ) -# assert isinstance(syft_retrieved_data, SyftObjectRetrieval) -# assert all(syft_retrieved_data.read() == data_big) - - -# def test_upload_dataset_save_to_blob_storage(worker): -# root_client: DomainClient = worker.root_client -# root_authed_ctx = AuthedServiceContext( -# node=worker, credentials=root_client.verify_key -# ) -# dataset = sy.Dataset( -# name="small_dataset", -# asset_list=[ -# sy.Asset( -# name="small_dataset", -# data=np.array([1, 2, 3]), -# mock=np.array([1, 1, 1]), -# ) -# ], -# ) -# root_client.upload_dataset(dataset) -# blob_storage = worker.get_service("BlobStorageService") -# assert len(blob_storage.get_all_blob_storage_entries(context=root_authed_ctx)) == 0 - -# num_elements = 50 * 1024 * 1024 -# data_big = np.random.randint(0, 100, size=num_elements) -# dataset_big = sy.Dataset( -# name="big_dataset", -# asset_list=[ -# sy.Asset( -# name="big_dataset", -# data=data_big, -# mock=np.array([1, 1, 1]), -# ) -# ], -# ) -# root_client.upload_dataset(dataset_big) -# # the private data should be saved to the blob storage -# blob_entries: list = blob_storage.get_all_blob_storage_entries( -# context=root_authed_ctx -# ) -# assert len(blob_entries) == 1 -# data_big_retrieved: SyftObjectRetrieval = blob_storage.read( -# 
context=root_authed_ctx, uid=blob_entries[0].id -# ) -# assert all(data_big_retrieved.read() == data_big) +def test_action_obj_send_save_to_blob_storage(worker): + # this small object should not be saved to blob storage + data_small: np.ndarray = np.array([1, 2, 3]) + action_obj = ActionObject.from_obj(data_small) + assert action_obj.dtype == data_small.dtype + root_client: DomainClient = worker.root_client + action_obj.send(root_client) + assert action_obj.syft_blob_storage_entry_id is None + + # big object that should be saved to blob storage + assert min_size_for_blob_storage_upload(root_client.api.metadata) == 16 + num_elements = 50 * 1024 * 1024 + data_big = np.random.randint(0, 100, size=num_elements) # 4 bytes per int32 + action_obj_2 = ActionObject.from_obj(data_big) + assert can_upload_to_blob_storage(action_obj_2, root_client.api.metadata) + action_obj_2.send(root_client) + assert isinstance(action_obj_2.syft_blob_storage_entry_id, sy.UID) + # get back the object from blob storage to check if it is the same + root_authed_ctx = AuthedServiceContext( + node=worker, credentials=root_client.verify_key + ) + blob_storage = worker.get_service("BlobStorageService") + syft_retrieved_data = blob_storage.read( + root_authed_ctx, action_obj_2.syft_blob_storage_entry_id + ) + assert isinstance(syft_retrieved_data, SyftObjectRetrieval) + assert all(syft_retrieved_data.read() == data_big) + + +def test_upload_dataset_save_to_blob_storage(worker): + root_client: DomainClient = worker.root_client + root_authed_ctx = AuthedServiceContext( + node=worker, credentials=root_client.verify_key + ) + dataset = sy.Dataset( + name="small_dataset", + asset_list=[ + sy.Asset( + name="small_dataset", + data=np.array([1, 2, 3]), + mock=np.array([1, 1, 1]), + ) + ], + ) + root_client.upload_dataset(dataset) + blob_storage = worker.get_service("BlobStorageService") + assert len(blob_storage.get_all_blob_storage_entries(context=root_authed_ctx)) == 0 + + num_elements = 50 * 1024 * 1024 + data_big = np.random.randint(0, 100, size=num_elements) + dataset_big = sy.Dataset( + name="big_dataset", + asset_list=[ + sy.Asset( + name="big_dataset", + data=data_big, + mock=np.array([1, 1, 1]), + ) + ], + ) + root_client.upload_dataset(dataset_big) + # the private data should be saved to the blob storage + blob_entries: list = blob_storage.get_all_blob_storage_entries( + context=root_authed_ctx + ) + assert len(blob_entries) == 1 + data_big_retrieved: SyftObjectRetrieval = blob_storage.read( + context=root_authed_ctx, uid=blob_entries[0].id + ) + assert all(data_big_retrieved.read() == data_big) From 18b9a2116a90e39367a19b5bda232eb9670674c0 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 4 Jul 2024 15:29:40 +0200 Subject: [PATCH 254/309] Revert "min blob storage = 0" This reverts commit 301cbd479601b43a6fbcd53e46e23b17fe45f3b6. 
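
This puts the default back to 16 MB in default.env, the grid config and the
in-memory node. For local experiments the threshold can still be lowered via
the same environment variable that config.py and Node.init_blob_storage() read;
a sketch, assuming an in-process dev node where os.environ is visible to it:

    # Sketch: opt back into sending everything to blob storage on a dev node.
    import os

    os.environ["MIN_SIZE_BLOB_STORAGE_MB"] = "0"  # must be set before launch

    import syft as sy

    node = sy.orchestra.launch(name="blob-dev", reset=True)
    client = node.login(email="info@openmined.org", password="changethis")
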
--- packages/grid/backend/grid/core/config.py | 2 +- packages/grid/default.env | 2 +- packages/syft/src/syft/node/node.py | 2 +- packages/syft/tests/syft/settings/fixtures.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index f4a0ccdfbca..619637c06c6 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -155,7 +155,7 @@ def get_emails_enabled(self) -> Self: ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") ) - MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 0)) + MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16)) REVERSE_TUNNEL_ENABLED: bool = str_to_bool( os.getenv("REVERSE_TUNNEL_ENABLED", "false") ) diff --git a/packages/grid/default.env b/packages/grid/default.env index af138315f60..0aae09f1026 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -55,7 +55,7 @@ CREATE_PRODUCER=False N_CONSUMERS=1 INMEMORY_WORKERS=True ASSOCIATION_REQUEST_AUTO_APPROVAL=False -MIN_SIZE_BLOB_STORAGE_MB=0 +MIN_SIZE_BLOB_STORAGE_MB=16 # New Service Flag USE_NEW_SERVICE=False diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index bda4373748f..9701be0f023 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -492,7 +492,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ) config_ = OnDiskBlobStorageConfig( client_config=client_config, - min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 0), + min_blob_size=os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 16), ) else: config_ = config diff --git a/packages/syft/tests/syft/settings/fixtures.py b/packages/syft/tests/syft/settings/fixtures.py index 80b0e2054f6..65cbf18deca 100644 --- a/packages/syft/tests/syft/settings/fixtures.py +++ b/packages/syft/tests/syft/settings/fixtures.py @@ -66,7 +66,7 @@ def metadata_json(faker) -> NodeMetadataJSON: node_side_type=NodeSideType.LOW_SIDE.value, show_warnings=False, node_type=NodeType.DOMAIN.value, - min_size_blob_storage_mb=0, + min_size_blob_storage_mb=16, ) From 0960f89534bf8f090471931e5572ff9421438609 Mon Sep 17 00:00:00 2001 From: teo Date: Thu, 4 Jul 2024 16:43:26 +0300 Subject: [PATCH 255/309] add assets to repr --- .../syft/src/syft/service/code/user_code.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index e7c69a2a2df..de5f477237e 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -26,6 +26,8 @@ from typing import final # third party +from IPython.display import HTML +from IPython.display import Markdown from IPython.display import display from pydantic import ValidationError from pydantic import field_validator @@ -918,7 +920,7 @@ def _inner_repr(self, level: int = 0) -> str: def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return as_markdown_code(self._inner_repr()) - def _repr_html_(self, level: int = 0) -> str: + def _ipython_display_(self, level: int = 0) -> None: tabs = " " * level shared_with_line = "" if len(self.output_readers) > 0 and self.output_reader_names is not None: @@ -927,6 +929,19 @@ def _repr_html_(self, level: int = 0) -> str: f"

{tabs}Custom Policy: "
                f"outputs are *shared* with the owners of {owners_string} once computed"
             )
+        constants_str = ""
+        args = [
+            x
+            for _dict in self.input_policy_init_kwargs.values()  # type: ignore
+            for x in _dict.values()
+        ]
+        constants = [x for x in args if isinstance(x, Constant)]
+        constants_str = "\n    ".join([f"{x.kw}: {x.val}" for x in constants])
+        # indent all lines except the first one
+        asset_str = "
".join(
+            [f"  {line}" for line in self._asset_json.split("\n")]
+        ).lstrip()
+        repr_str = f"""

    {self.name}
-    {description_text}
+    Summary: {self.summary}
+    {"A more detailed description is available by calling dataset.description" if self.description is not None and self.description.text else ""}
    {uploaded_by_line}
    Created on: {self.created_at}
URL: @@ -605,13 +636,27 @@ class DatasetPageView(SyftObject): @serializable() -class CreateDataset(Dataset): +class CreateDatasetV2(DatasetV2): # version __canonical_name__ = "CreateDataset" __version__ = SYFT_OBJECT_VERSION_2 asset_list: list[CreateAsset] = [] - __repr_attrs__ = ["name", "url"] + __repr_attrs__ = ["name", "summary", "url"] + + id: UID | None = None # type: ignore[assignment] + created_at: DateTime | None = None # type: ignore[assignment] + uploader: Contributor | None = None # type: ignore[assignment] + + +@serializable() +class CreateDataset(Dataset): + # version + __canonical_name__ = "CreateDataset" + __version__ = SYFT_OBJECT_VERSION_3 + asset_list: list[CreateAsset] = [] + + __repr_attrs__ = ["name", "summary", "url"] id: UID | None = None # type: ignore[assignment] created_at: DateTime | None = None # type: ignore[assignment] @@ -633,6 +678,9 @@ def __assets_must_contain_mock( def set_description(self, description: str) -> None: self.description = MarkdownDescription(text=description) + def set_summary(self, summary: str) -> None: + self.summary = summary + def add_citation(self, citation: str) -> None: self.citation = citation From 372ca0cd9839aeb67991f8758f8e9aa406144736 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 3 Jul 2024 18:22:34 -0400 Subject: [PATCH 260/309] slight text cleanup and add migrations --- .../syft/src/syft/service/dataset/dataset.py | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 4d31827e602..157550e4184 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -23,11 +23,14 @@ from ...store.document_store import PartitionKey from ...types.datetime import DateTime from ...types.dicttuple import DictTuple +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext +from ...types.transforms import drop from ...types.transforms import generate_id +from ...types.transforms import make_set_default from ...types.transforms import transform from ...types.transforms import validate_url from ...types.uid import UID @@ -537,6 +540,13 @@ def _repr_html_(self) -> Any: if self.uploader else "" ) + if self.description is not None and self.description.text: + description_info_message = ( + "

A more detailed description is available by calling \
+                dataset.description."
+            )
+        else:
+            description_info_message = ""
         return f"""
-        {self.name}
-        Summary: {self.summary}
+        {self.name}
+        Summary
+        {self.summary}
+        {description_info_message}
+        Dataset Details
         {uploaded_by_line}
         Created on: {self.created_at}
         URL: {self.url}
         Contributors: To see full details call dataset.contributors.
+        Assets
{self.assets._repr_html_()} """ From 1e75adc61cf9b9c028bccd7733f84436b0ec7977 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Thu, 4 Jul 2024 19:03:17 -0400 Subject: [PATCH 266/309] update notebooks and make summary optional in repr --- notebooks/api/0.8/00-load-data.ipynb | 3 ++- .../tutorials/data-owner/01-uploading-private-data.ipynb | 6 ++++++ packages/syft/src/syft/service/dataset/dataset.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/notebooks/api/0.8/00-load-data.ipynb b/notebooks/api/0.8/00-load-data.ipynb index c61fab17f41..ad98a5ac361 100644 --- a/notebooks/api/0.8/00-load-data.ipynb +++ b/notebooks/api/0.8/00-load-data.ipynb @@ -348,7 +348,8 @@ }, "outputs": [], "source": [ - "dataset.set_description(\"Canada Trade Data\")" + "dataset.set_description(\"Canada Trade Data Markdown Description\")\n", + "dataset.set_summary(\"Canada Trade Data Short Summary\")" ] }, { diff --git a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb index 02ed5576cb0..3a1863c0bdd 100644 --- a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb +++ b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb @@ -124,8 +124,14 @@ "metadata": {}, "outputs": [], "source": [ + "dataset_markdown_description = \"\"\"\n", + "### Contents\n", + "Numpy arrays of length 3 with integers ranging from 1 - 3.\n", + "\"\"\"\n", "dataset = sy.Dataset(\n", " name=\"my dataset\",\n", + " summary=\"Contains private and mock versions of data\",\n", + " description=dataset_markdown_description,\n", " asset_list=[\n", " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", " ],\n", diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 950153d6a7d..c6d49799c6e 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -572,7 +572,7 @@ def _repr_html_(self) -> Any:

         {self.name}
         Summary
-        {self.summary}
+        {f"{self.summary}" if self.summary else ""}
         {description_info_message}
         Dataset Details
{uploaded_by_line} From 54ba6b3327d5343a66e63b60f525640fe6d07192 Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 5 Jul 2024 10:56:13 +0700 Subject: [PATCH 267/309] [syft/blob_storage] only add blob persmission for an action object when its `syft_blob_storage_entry_id` is not None --- .../syft/src/syft/service/request/request.py | 25 +++++++++++-------- .../src/syft/service/sync/sync_service.py | 14 +++++------ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 8c5687ac55e..65a64872249 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -158,24 +158,29 @@ def _run( uid_blob = action_obj.private.syft_blob_storage_entry_id else: uid_blob = action_obj.syft_blob_storage_entry_id - requesting_permission_blob_obj = ActionObjectPermission( - uid=uid_blob, - credentials=context.requesting_user_credentials, - permission=self.apply_permission_type, - ) + if uid_blob is not None: + requesting_permission_blob_obj = ActionObjectPermission( + uid=uid_blob, + credentials=context.requesting_user_credentials, + permission=self.apply_permission_type, + ) if apply: logger.debug( "ADDING PERMISSION", requesting_permission_action_obj, id_action ) action_store.add_permission(requesting_permission_action_obj) - blob_storage_service.stash.add_permission( - requesting_permission_blob_obj - ) + if uid_blob is not None: + blob_storage_service.stash.add_permission( + requesting_permission_blob_obj + ) else: if action_store.has_permission(requesting_permission_action_obj): action_store.remove_permission(requesting_permission_action_obj) - if blob_storage_service.stash.has_permission( - requesting_permission_blob_obj + if ( + uid_blob is not None + and blob_storage_service.stash.has_permission( + requesting_permission_blob_obj + ) ): blob_storage_service.stash.remove_permission( requesting_permission_blob_obj diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 62885742c5b..d0f1f900358 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -72,13 +72,13 @@ def add_actionobject_read_permissions( for permission in new_permissions: if permission.permission == ActionPermission.READ: store_to.add_permission(permission) - - permission_blob = ActionObjectPermission( - uid=blob_id, - permission=permission.permission, - credentials=permission.credentials, - ) - store_to_blob.add_permission(permission_blob) + if blob_id is not None: + permission_blob = ActionObjectPermission( + uid=blob_id, + permission=permission.permission, + credentials=permission.credentials, + ) + store_to_blob.add_permission(permission_blob) def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: if hasattr(x, "__dict__") and isinstance(x, SyftObject): From 787af9bbfaaec512e8ceae67883e1ed5077ec3aa Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 5 Jul 2024 11:14:16 +0700 Subject: [PATCH 268/309] [syft/action_service] only add blob permission if the entry's id is not None --- .../syft/src/syft/service/action/action_service.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 38cdd3c5b5a..ac0ada54dda 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ 
b/packages/syft/src/syft/service/action/action_service.py @@ -565,14 +565,21 @@ def store_permission( def blob_permission( x: SyftVerifyKey | None = None, - ) -> ActionObjectPermission: - return ActionObjectPermission(result_blob_id, read_permission, x) + ) -> ActionObjectPermission | None: + if result_blob_id: + return ActionObjectPermission(result_blob_id, read_permission, x) + else: + return None if len(output_readers) > 0: store_permissions = [store_permission(x) for x in output_readers] self.store.add_permissions(store_permissions) - blob_permissions = [blob_permission(x) for x in output_readers] + blob_permissions = [ + blob_permission(x) + for x in output_readers + if blob_permission(x) is not None + ] blob_storage_service.stash.add_permissions(blob_permissions) return set_result From 3399a82f3e025e3a2ec7cd8787ee5f2f63c1274a Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 5 Jul 2024 11:38:50 +0530 Subject: [PATCH 269/309] skip checking/adding blob storage permissions if data is saved to store --- .../src/syft/service/action/action_service.py | 8 +++--- .../syft/src/syft/service/request/request.py | 25 +++++++++++++------ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index da273e2c12b..3d81fbbd81b 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -186,7 +186,8 @@ def _set( blob_storage_service: AbstractService = context.node.get_service( BlobStorageService ) - blob_storage_service.stash.add_permission(permission) + if not skip_save_to_blob_store: + blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private else: @@ -569,8 +570,9 @@ def blob_permission( store_permissions = [store_permission(x) for x in output_readers] self.store.add_permissions(store_permissions) - blob_permissions = [blob_permission(x) for x in output_readers] - blob_storage_service.stash.add_permissions(blob_permissions) + if not skip_save_to_blob_store: + blob_permissions = [blob_permission(x) for x in output_readers] + blob_storage_service.stash.add_permissions(blob_permissions) return set_result diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 8c5687ac55e..bb01d0e8a4c 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -158,24 +158,35 @@ def _run( uid_blob = action_obj.private.syft_blob_storage_entry_id else: uid_blob = action_obj.syft_blob_storage_entry_id - requesting_permission_blob_obj = ActionObjectPermission( - uid=uid_blob, - credentials=context.requesting_user_credentials, - permission=self.apply_permission_type, + requesting_permission_blob_obj = ( + ActionObjectPermission( + uid=uid_blob, + credentials=context.requesting_user_credentials, + permission=self.apply_permission_type, + ) + if uid_blob + else None ) if apply: logger.debug( "ADDING PERMISSION", requesting_permission_action_obj, id_action ) action_store.add_permission(requesting_permission_action_obj) - blob_storage_service.stash.add_permission( - requesting_permission_blob_obj + ( + blob_storage_service.stash.add_permission( + requesting_permission_blob_obj + ) + if requesting_permission_blob_obj + else None ) else: if action_store.has_permission(requesting_permission_action_obj): 
action_store.remove_permission(requesting_permission_action_obj) - if blob_storage_service.stash.has_permission( + if ( requesting_permission_blob_obj + and blob_storage_service.stash.has_permission( + requesting_permission_blob_obj + ) ): blob_storage_service.stash.remove_permission( requesting_permission_blob_obj From 247787599b78c4e1b0ed69a3b11fa76340899c3b Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 5 Jul 2024 13:56:10 +0700 Subject: [PATCH 270/309] [syft/sync_service] stop saving blob permission if the action object is small (no blob entry exists) --- .../src/syft/service/sync/sync_service.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 62885742c5b..d69889d9381 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -64,21 +64,24 @@ def add_actionobject_read_permissions( action_object: ActionObject, new_permissions: list[ActionObjectPermission], ) -> None: - blob_id = action_object.syft_blob_storage_entry_id - store_to = context.node.get_service("actionservice").store # type: ignore - store_to_blob = context.node.get_service("blobstorageservice").stash.partition # type: ignore - for permission in new_permissions: if permission.permission == ActionPermission.READ: store_to.add_permission(permission) - permission_blob = ActionObjectPermission( - uid=blob_id, - permission=permission.permission, - credentials=permission.credentials, - ) - store_to_blob.add_permission(permission_blob) + blob_id = action_object.syft_blob_storage_entry_id + if blob_id: + store_to_blob = context.node.get_service( + "blobstorageservice" + ).stash.partition # type: ignore + for permission in new_permissions: + if permission.permission == ActionPermission.READ: + permission_blob = ActionObjectPermission( + uid=blob_id, + permission=permission.permission, + credentials=permission.credentials, + ) + store_to_blob.add_permission(permission_blob) def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: if hasattr(x, "__dict__") and isinstance(x, SyftObject): From c6bef1d15628dcdf04ea299aafd0cb150241dc2b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 5 Jul 2024 10:42:47 +0200 Subject: [PATCH 271/309] revert blob id None check --- .../syft/src/syft/service/action/action_service.py | 13 +++++-------- tox.ini | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 38cdd3c5b5a..da273e2c12b 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -182,14 +182,11 @@ def _set( if isinstance(action_object, TwinObject): # give read permission to the mock blob_id = action_object.mock_obj.syft_blob_storage_entry_id - if blob_id is not None: - permission = ActionObjectPermission( - blob_id, ActionPermission.ALL_READ - ) - blob_storage_service: AbstractService = context.node.get_service( - BlobStorageService - ) - blob_storage_service.stash.add_permission(permission) + permission = ActionObjectPermission(blob_id, ActionPermission.ALL_READ) + blob_storage_service: AbstractService = context.node.get_service( + BlobStorageService + ) + blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private else: diff --git a/tox.ini 
b/tox.ini index bad928ce7e8..1ba9d01ab16 100644 --- a/tox.ini +++ b/tox.ini @@ -1064,7 +1064,7 @@ commands = pytest --disable-warnings [testenv:migration.prepare] -description = Migration Test +description = Prepare Migration Data deps = syft nbmake From 91ea0bfddcb2ab9077cf0588c513587f6c06c515 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 5 Jul 2024 15:19:24 +0530 Subject: [PATCH 272/309] define an attribute to indicate save to blob store at action object level - define migrations for action object and its subclasses - deprecate use of skip_save_to_blob_store --- .../syft/src/syft/client/domain_client.py | 7 +- .../src/syft/protocol/protocol_version.json | 56 +++++++++++ .../src/syft/service/action/action_object.py | 96 ++++++++++++++++--- .../src/syft/service/action/action_service.py | 60 ++++++------ .../syft/src/syft/service/action/numpy.py | 73 +++++++++++++- .../syft/src/syft/service/action/pandas.py | 50 +++++++++- .../syft/src/syft/service/dataset/dataset.py | 4 - packages/syft/src/syft/types/blob_storage.py | 26 ++++- packages/syft/src/syft/types/twin_object.py | 4 - 9 files changed, 313 insertions(+), 63 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 9273ae15e5e..4d370783587 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -148,13 +148,8 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False response = self.api.services.action.set( - twin, - ignore_detached_objs=contains_empty, - skip_save_to_blob_store=skip_save_to_blob_store, + twin, ignore_detached_objs=contains_empty ) if isinstance(response, SyftError): tqdm.write(f"Failed to upload asset: {asset.name}") diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 44b028c30ab..d2ab6739c4c 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -430,6 +430,62 @@ "hash": "3117e16cbe4dbc344ab90fbbd36ba90dfb518e66f0fb07644bbe7864dcdeb309", "action": "add" } + }, + "ActionObject": { + "4": { + "version": 4, + "hash": "a4dd2949af0f516d0f640d28e0fdfa026ba8d55bb29eaa7844c926e467606892", + "action": "add" + } + }, + "AnyActionObject": { + "4": { + "version": 4, + "hash": "809bd7ffab211133a9be87e058facecf870a79cb2d4027616f5244323de27091", + "action": "add" + } + }, + "BlobFileOBject": { + "3": { + "version": 3, + "hash": "27901fcd545ad0607dbfcbfa0141ee03b0f0f4bee8d23f2d661a4b22011bfd37", + "action": "add" + } + }, + "NumpyArrayObject": { + "4": { + "version": 4, + "hash": "19e2ff3da78038d2164f86d1f9b0d1facc6008483be60d2852458e90202bb96b", + "action": "add" + } + }, + "NumpyScalarObject": { + "4": { + "version": 4, + "hash": "5101d00dd92ac4391cae77629eb48aa25401cc8c5ebb28a8a969cd5eba35fb67", + "action": "add" + } + }, + "NumpyBoolObject": { + "4": { + "version": 4, + "hash": "764cd93792c4dfe27b8952fde853626592fe58e1a341b5350b23f38ce474583f", + "action": "add" + } + }, + "PandasDataframeObject": { + "4": { + "version": 4, + "hash": "b70f4bb32ba9f3f5ea89552649bf882d927cf9085fb573cc6d4841b32d653f84", + "action": "add" + } + }, + "PandasSeriesObject": { + "4": { + "version": 4, + "hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86", + "action": "add" + } } } } diff 
--git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index b9ffd16ebf6..e878eb1bad2 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -43,11 +43,15 @@ from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel from ...types.datetime import DateTime +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import LineageID from ...types.uid import UID from ...util.util import prompt_warning_message @@ -527,13 +531,7 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: print(r.message) if isinstance(r, SyftWarning): logger.debug(r.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False - arg = api.services.action.set( - arg, - skip_save_to_blob_store=skip_save_to_blob_store, - ) + arg = api.services.action.set(arg) return arg arg_list = [process_arg(arg) for arg in args] if args else [] @@ -675,7 +673,7 @@ def truncate_str(string: str, length: int = 100) -> str: @serializable(without=["syft_pre_hooks__", "syft_post_hooks__"]) -class ActionObject(SyncableSyftObject): +class ActionObjectV3(SyncableSyftObject): """Action object for remote execution.""" __canonical_name__ = "ActionObject" @@ -710,6 +708,45 @@ class ActionObject(SyncableSyftObject): syft_created_at: DateTime | None = None syft_resolved: bool = True syft_action_data_node_id: UID | None = None + + +@serializable(without=["syft_pre_hooks__", "syft_post_hooks__"]) +class ActionObject(SyncableSyftObject): + """Action object for remote execution.""" + + __canonical_name__ = "ActionObject" + __version__ = SYFT_OBJECT_VERSION_4 + __private_sync_attr_mocks__: ClassVar[dict[str, Any]] = { + "syft_action_data_cache": None, + "syft_blob_storage_entry_id": None, + } + + __attr_searchable__: list[str] = [] # type: ignore[misc] + syft_action_data_cache: Any | None = None + syft_blob_storage_entry_id: UID | None = None + syft_pointer_type: ClassVar[type[ActionObjectPointer]] + + # Help with calculating history hash for code verification + syft_parent_hashes: int | list[int] | None = None + syft_parent_op: str | None = None + syft_parent_args: Any | None = None + syft_parent_kwargs: Any | None = None + syft_history_hash: int | None = None + syft_internal_type: ClassVar[type[Any]] + syft_node_uid: UID | None = None + syft_pre_hooks__: dict[str, list] = {} + syft_post_hooks__: dict[str, list] = {} + syft_twin_type: TwinMode = TwinMode.NONE + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_action_data_type: type | None = None + syft_action_data_repr_: str | None = None + syft_action_data_str_: str | None = None + syft_has_bool_attr: bool | None = None + syft_resolve_data: bool | None = None + syft_created_at: DateTime | None = None + syft_resolved: bool = True + syft_action_data_node_id: UID | None = None + syft_action_saved_to_blob_store: bool = True # syft_dont_wrap_attrs = ["shape"] def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: @@ -814,6 +851,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | 
SyftWarning | None: if get_metadata is not None and not can_upload_to_blob_storage( data, get_metadata() ): + self.syft_action_saved_to_blob_store = False return SyftWarning( message=f"The action object {self.id} was not saved to " f"the blob store but to memory cache since it is small." @@ -1247,13 +1285,9 @@ def _send( if isinstance(blob_storage_res, SyftWarning): logger.debug(blob_storage_res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False res = api.services.action.set( self, add_storage_permission=add_storage_permission, - skip_save_to_blob_store=skip_save_to_blob_store, ) if isinstance(res, ActionObject): self.syft_created_at = res.syft_created_at @@ -2189,7 +2223,7 @@ def __rrshift__(self, other: Any) -> Any: @serializable() -class AnyActionObject(ActionObject): +class AnyActionObjectV3(ActionObjectV3): """ This is a catch-all class for all objects that are not defined in the `action_types` dictionary. @@ -2203,6 +2237,22 @@ class AnyActionObject(ActionObject): syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"] syft_action_data_str_: str = "" + +@serializable() +class AnyActionObject(ActionObject): + """ + This is a catch-all class for all objects that are not + defined in the `action_types` dictionary. + """ + + __canonical_name__ = "AnyActionObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type[Any]] = NoneType # type: ignore + # syft_passthrough_attrs: List[str] = [] + syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"] + syft_action_data_str_: str = "" + def __float__(self) -> float: return float(self.syft_action_data) @@ -2238,3 +2288,23 @@ def has_action_data_empty(args: Any, kwargs: Any) -> bool: if is_action_data_empty(a): return True return False + + +@migrate(ActionObjectV3, ActionObject) +def upgrade_action_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(ActionObject, ActionObjectV3) +def downgrade_action_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(AnyActionObjectV3, AnyActionObject) +def upgrade_anyaction_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(AnyActionObject, AnyActionObjectV3) +def downgrade_anyaction_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 3d81fbbd81b..91a527569bb 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -75,15 +75,8 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: return blob_store_result if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False - np_pointer = self._set( - context, - np_obj, - skip_save_to_blob_store=skip_save_to_blob_store, - ) + np_pointer = self._set(context, np_obj) return np_pointer @service_method( @@ -97,7 +90,6 @@ def set( action_object: ActionObject | TwinObject, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_save_to_blob_store: bool = False, ) -> ActionObject | SyftError: res = self._set( context, @@ -105,7 +97,6 @@ def set( has_result_read_permission=True, add_storage_permission=add_storage_permission, 
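
The upgrade/downgrade pair registered above for ActionObject and AnyActionObject is the template that every other ActionObject subclass in this patch repeats: the previous schema is kept as a frozen V3 copy under the old version number, the live class takes SYFT_OBJECT_VERSION_4 plus the new syft_action_saved_to_blob_store field, and two @migrate functions map between them. A minimal sketch of the pattern follows, assuming the same imports as action_object.py above; MyObject and MyObjectV3 are illustrative names, not classes from this patch.

# hypothetical example of the version-bump-plus-migration pattern used in this patch
@serializable()
class MyObjectV3(SyftObject):
    __canonical_name__ = "MyObject"      # same canonical name as the live class
    __version__ = SYFT_OBJECT_VERSION_3  # frozen at the previous schema version


@serializable()
class MyObject(SyftObject):
    __canonical_name__ = "MyObject"
    __version__ = SYFT_OBJECT_VERSION_4
    syft_action_saved_to_blob_store: bool = True  # the field introduced by this patch


@migrate(MyObjectV3, MyObject)
def upgrade_my_object() -> list[Callable]:
    # stored V3 objects predate the flag; assume their data reached blob storage
    return [make_set_default("syft_action_saved_to_blob_store", True)]


@migrate(MyObject, MyObjectV3)
def downgrade_my_object() -> list[Callable]:
    return [drop("syft_action_saved_to_blob_store")]
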
ignore_detached_objs=ignore_detached_objs, - skip_save_to_blob_store=skip_save_to_blob_store, ) if res.is_err(): return SyftError(message=res.value) @@ -123,14 +114,22 @@ def is_detached_obj( if ( isinstance(action_object, TwinObject) and ( - action_object.mock_obj.syft_blob_storage_entry_id is None - or action_object.private_obj.syft_blob_storage_entry_id is None + ( + action_object.mock_obj.syft_action_saved_to_blob_store + and action_object.mock_obj.syft_blob_storage_entry_id is None + ) + or ( + action_object.private_obj.syft_action_saved_to_blob_store + and action_object.private_obj.syft_blob_storage_entry_id is None + ) ) and not ignore_detached_obj ): return True if isinstance(action_object, ActionObject) and ( - action_object.syft_blob_storage_entry_id is None and not ignore_detached_obj + action_object.syft_action_saved_to_blob_store + and action_object.syft_blob_storage_entry_id is None + and not ignore_detached_obj ): return True return False @@ -142,12 +141,8 @@ def _set( has_result_read_permission: bool = False, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_save_to_blob_store: bool = False, ) -> Result[ActionObject, str]: - if ( - self.is_detached_obj(action_object, ignore_detached_objs) - and not skip_save_to_blob_store - ): + if self.is_detached_obj(action_object, ignore_detached_objs): return Err( "You uploaded an ActionObject that is not yet in the blob storage" ) @@ -156,14 +151,26 @@ def _set( if isinstance(action_object, ActionObject): action_object.syft_created_at = DateTime.now() - if not skip_save_to_blob_store: + ( action_object._clear_cache() + if action_object.syft_action_saved_to_blob_store + else None + ) else: # TwinObject action_object.private_obj.syft_created_at = DateTime.now() # type: ignore[unreachable] action_object.mock_obj.syft_created_at = DateTime.now() - if not skip_save_to_blob_store: + + # Clear cache if data is saved to blob storage + ( action_object.private_obj._clear_cache() + if action_object.private_obj.syft_action_saved_to_blob_store + else None + ) + ( action_object.mock_obj._clear_cache() + if action_object.mock_obj.syft_action_saved_to_blob_store + else None + ) # If either context or argument is True, has_result_read_permission is True has_result_read_permission = ( @@ -186,7 +193,8 @@ def _set( blob_storage_service: AbstractService = context.node.get_service( BlobStorageService ) - if not skip_save_to_blob_store: + # if mock is saved to blob store, then add READ permission + if action_object.mock_obj.syft_action_saved_to_blob_store: blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private @@ -528,9 +536,6 @@ def set_result_to_store( return Err(blob_store_result.message) if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False # IMPORTANT: DO THIS ONLY AFTER ._save_to_blob_storage if isinstance(result_action_object, TwinObject): @@ -546,7 +551,6 @@ def set_result_to_store( context, result_action_object, has_result_read_permission=True, - skip_save_to_blob_store=skip_save_to_blob_store, ) if set_result.is_err(): @@ -570,7 +574,7 @@ def blob_permission( store_permissions = [store_permission(x) for x in output_readers] self.store.add_permissions(store_permissions) - if not skip_save_to_blob_store: + if result_blob_id is not None: blob_permissions = [blob_permission(x) for x in output_readers] 
blob_storage_service.stash.add_permissions(blob_permissions) @@ -816,13 +820,9 @@ def execute( } if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False set_result = self._set( context, result_action_object, - skip_save_to_blob_store=skip_save_to_blob_store, ) if set_result.is_err(): return Err( diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py index da8c8aecc05..f73f65f2cd7 100644 --- a/packages/syft/src/syft/service/action/numpy.py +++ b/packages/syft/src/syft/service/action/numpy.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from typing import Any from typing import ClassVar @@ -8,9 +9,14 @@ # relative from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.transforms import drop +from ...types.transforms import make_set_default from .action_object import ActionObject from .action_object import ActionObjectPointer +from .action_object import ActionObjectV3 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @@ -41,7 +47,7 @@ def numpy_like_eq(left: Any, right: Any) -> bool: # 🔵 TODO 7: Map TPActionObjects and their 3rd Party types like numpy type to these # classes for bi-directional lookup. @serializable() -class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class NumpyArrayObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyArrayObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -50,6 +56,17 @@ class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + +@serializable() +class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyArrayObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type[Any]] = np.ndarray + syft_pointer_type: ClassVar[type[ActionObjectPointer]] = NumpyArrayObjectPointer + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + # def __eq__(self, other: Any) -> bool: # # 🟡 TODO 8: move __eq__ to a Data / Serdeable type interface on ActionObject # if isinstance(other, NumpyArrayObject): @@ -84,7 +101,7 @@ def __array_ufunc__( @serializable() -class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class NumpyScalarObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyScalarObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -92,12 +109,22 @@ class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + +@serializable() +class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyScalarObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = np.number + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + def __float__(self) -> float: return float(self.syft_action_data) @serializable() -class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class 
NumpyBoolObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyBoolObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -106,6 +133,16 @@ class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] +@serializable() +class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyBoolObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = np.bool_ + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + + np_array = np.array([1, 2, 3]) action_types[type(np_array)] = NumpyArrayObject @@ -135,3 +172,33 @@ class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): for scalar_type in SUPPORTED_INT_TYPES + SUPPORTED_FLOAT_TYPES: # type: ignore action_types[scalar_type] = NumpyScalarObject + + +@migrate(NumpyArrayObjectV3, NumpyArrayObject) +def upgrade_numpyarray_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyArrayObject, NumpyArrayObjectV3) +def downgrade_numpyarray_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(NumpyBoolObjectV3, NumpyBoolObject) +def upgrade_numpybool_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyBoolObject, NumpyBoolObjectV3) +def downgrade_numpybool_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(NumpyScalarObjectV3, NumpyScalarObject) +def upgrade_numpyscalar_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyScalarObject, NumpyScalarObjectV3) +def downgrade_numpyscalar_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py index d16dec119b0..3238b4f53d6 100644 --- a/packages/syft/src/syft/service/action/pandas.py +++ b/packages/syft/src/syft/service/action/pandas.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from typing import Any from typing import ClassVar @@ -8,14 +9,19 @@ # relative from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.transforms import drop +from ...types.transforms import make_set_default from .action_object import ActionObject +from .action_object import ActionObjectV3 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @serializable() -class PandasDataFrameObject(ActionObject): +class PandasDataFrameObjectV3(ActionObjectV3): __canonical_name__ = "PandasDataframeObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -24,6 +30,17 @@ class PandasDataFrameObject(ActionObject): # this is added for instance checks for dataframes # syft_dont_wrap_attrs = ["shape"] + +@serializable() +class PandasDataFrameObject(ActionObject): + __canonical_name__ = "PandasDataframeObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = DataFrame + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + # this is added for instance checks for dataframes + # syft_dont_wrap_attrs = ["shape"] + def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: return 
self.__dataframe__(*args, **kwargs) @@ -46,13 +63,22 @@ def __bool__(self) -> bool: @serializable() -class PandasSeriesObject(ActionObject): +class PandasSeriesObjectV3(ActionObjectV3): __canonical_name__ = "PandasSeriesObject" __version__ = SYFT_OBJECT_VERSION_3 syft_internal_type = Series syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + +@serializable() +class PandasSeriesObject(ActionObject): + __canonical_name__ = "PandasSeriesObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type = Series + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + # name: Optional[str] = None # syft_dont_wrap_attrs = ["shape"] @@ -73,3 +99,23 @@ def syft_is_property(self, obj: Any, method: str) -> bool: action_types[DataFrame] = PandasDataFrameObject action_types[Series] = PandasSeriesObject + + +@migrate(PandasSeriesObjectV3, PandasSeriesObject) +def upgrade_pandasseries_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(PandasSeriesObject, PandasSeriesObjectV3) +def downgrade_pandasseries_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(PandasDataFrameObjectV3, PandasDataFrameObject) +def upgrade_pandasdataframe_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(PandasDataFrameObject, PandasDataFrameObjectV3) +def downgrade_pandasdataframe_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 5ab32347136..54a84878379 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -743,9 +743,6 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: raise ValueError(res.message) if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False # TODO, upload to blob storage here if context.node is None: raise ValueError( @@ -755,7 +752,6 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: result = action_service._set( context=context.to_node_context(), action_object=twin, - skip_save_to_blob_store=skip_save_to_blob_store, ) if result.is_err(): raise RuntimeError(f"Failed to create and store twin. 
Error: {result}") diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index a92134f26f7..016931b0901 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -27,15 +27,19 @@ from ..serde.serializable import serializable from ..service.action.action_object import ActionObject from ..service.action.action_object import ActionObjectPointer +from ..service.action.action_object import ActionObjectV3 from ..service.action.action_object import BASE_PASSTHROUGH_ATTRS from ..service.action.action_types import action_types from ..service.response import SyftError from ..service.response import SyftException from ..service.service import from_api_or_context from ..types.grid_url import GridURL +from ..types.transforms import drop from ..types.transforms import keep +from ..types.transforms import make_set_default from ..types.transforms import transform from .datetime import DateTime +from .syft_migration import migrate from .syft_object import SYFT_OBJECT_VERSION_2 from .syft_object import SYFT_OBJECT_VERSION_3 from .syft_object import SYFT_OBJECT_VERSION_4 @@ -192,7 +196,7 @@ class BlobFileObjectPointer(ActionObjectPointer): @serializable() -class BlobFileObject(ActionObject): +class BlobFileObjectV3(ActionObjectV3): __canonical_name__ = "BlobFileOBject" __version__ = SYFT_OBJECT_VERSION_2 @@ -201,6 +205,16 @@ class BlobFileObject(ActionObject): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS +@serializable() +class BlobFileObject(ActionObject): + __canonical_name__ = "BlobFileOBject" + __version__ = SYFT_OBJECT_VERSION_3 + + syft_internal_type: ClassVar[type[Any]] = BlobFile + syft_pointer_type: ClassVar[type[ActionObjectPointer]] = BlobFileObjectPointer + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + + @serializable() class SecureFilePathLocation(SyftObject): __canonical_name__ = "SecureFilePathLocation" @@ -370,3 +384,13 @@ def storage_entry_to_metadata() -> list[Callable]: action_types[BlobFile] = BlobFileObject + + +@migrate(BlobFileObjectV3, BlobFileObject) +def upgrade_blobfile_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(BlobFileObject, BlobFileObjectV3) +def downgrade_blobfile_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index eae86e9cb5b..f2d2f020ef9 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -109,12 +109,8 @@ def send(self, client: SyftClient, add_storage_permission: bool = True) -> Any: blob_store_result = self._save_to_blob_storage() if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False res = client.api.services.action.set( self, add_storage_permission=add_storage_permission, - skip_save_to_blob_store=skip_save_to_blob_store, ) return res From 4e6d57210d0377b9542423a49545da72b7af1804 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 5 Jul 2024 13:44:45 +0200 Subject: [PATCH 273/309] add tox local + k8s tests --- .gitignore | 3 + ...le.ipynb => 1-dump-database-to-file.ipynb} | 25 +- .../1a-connect-and-migrate.ipynb | 88 ----- .../1b-connect-and-migrate-via-api.ipynb | 373 ------------------ ...m-file.ipynb => 2-migrate-from-file.ipynb} | 91 +++-- .../2-post-migration-tests.ipynb | 
272 ------------- packages/grid/devspace.yaml | 2 +- packages/grid/helm/examples/dev/migrated.yaml | 6 - .../src/syft/service/action/action_service.py | 15 +- tox.ini | 99 ++++- 10 files changed, 182 insertions(+), 792 deletions(-) rename notebooks/tutorials/version-upgrades/{1c-dump-to-file.ipynb => 1-dump-database-to-file.ipynb} (80%) delete mode 100644 notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb delete mode 100644 notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb rename notebooks/tutorials/version-upgrades/{2a-migrate-from-file.ipynb => 2-migrate-from-file.ipynb} (85%) delete mode 100644 notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb delete mode 100644 packages/grid/helm/examples/dev/migrated.yaml diff --git a/.gitignore b/.gitignore index 369de6643c7..52b24d684cb 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,6 @@ notebooks/helm/scenario_data.jsonl # tox syft.build.helm generated file out.* .git-blame-ignore-revs + +# migration data +packages/grid/helm/examples/dev/migration.yaml \ No newline at end of file diff --git a/notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb similarity index 80% rename from notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb rename to notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb index d72775ba24e..e5ecd841e8f 100644 --- a/notebooks/tutorials/version-upgrades/1c-dump-to-file.ipynb +++ b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb @@ -50,7 +50,8 @@ "metadata": {}, "outputs": [], "source": [ - "client.jobs[0].result.get()" + "# Check if this node has data on it\n", + "assert len(client.users.get_all()) == 2" ] }, { @@ -71,7 +72,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(migration_data.blobs)" + "assert migration_data.includes_blobs\n", + "assert migration_data.num_action_objects > 0" ] }, { @@ -81,8 +83,8 @@ "metadata": {}, "outputs": [], "source": [ - "blob_path = Path(\"./my_migration.blob\")\n", - "yaml_path = Path(\"my_migration.yaml\")\n", + "blob_path = Path(\"./migration.blob\")\n", + "yaml_path = Path(\"migration.yaml\")\n", "\n", "blob_path.unlink(missing_ok=True)\n", "yaml_path.unlink(missing_ok=True)" @@ -98,7 +100,9 @@ "migration_data.save(blob_path, yaml_path=yaml_path)\n", "\n", "assert blob_path.exists()\n", - "assert yaml_path.exists()" + "assert yaml_path.exists()\n", + "\n", + "print(f\"Saved migration data to {str(blob_path.resolve())}\")" ] }, { @@ -107,6 +111,17 @@ "id": "8", "metadata": {}, "outputs": [], + "source": [ + "if node.node_type.value == \"python\":\n", + " node.land()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb b/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb deleted file mode 100644 index fed82bb72bd..00000000000 --- a/notebooks/tutorials/version-upgrades/1a-connect-and-migrate.ipynb +++ /dev/null @@ -1,88 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "# third party\n", - "import numpy as np\n", - "\n", - "# syft absolute\n", - "import syft as sy\n", - "from syft.service.job.job_stash import Job" - ] - }, - { - "cell_type": "markdown", - "id": "1", - "metadata": {}, - "source": [ - "# Serialization tests" - ] - }, - { - "cell_type": "code", 
- "execution_count": null, - "id": "2", - "metadata": {}, - "outputs": [], - "source": [ - "assert sy.serialize(np.array([1, 2, 3])).canonicalName == \"numpy.ndarray\"\n", - "assert sy.serialize(bool).canonicalName == \"builtins.type\"\n", - "assert (\n", - " sy.serialize(Job).canonicalName\n", - " == \"pydantic._internal._model_construction.ModelMetaclass\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "node = sy.orchestra.launch(\n", - " name=\"test_upgradability\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb b/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb deleted file mode 100644 index 7c23086fac9..00000000000 --- a/notebooks/tutorials/version-upgrades/1b-connect-and-migrate-via-api.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "import syft as sy\n", - "from syft.service.log.log import SyftLogV3\n", - "from syft.types.syft_object import Context\n", - "from syft.types.syft_object import SyftObject" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "print(f\"syft version: {sy.__version__}\")" - ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "TODOS\n", - "- [x] action objects\n", - "- [x] maybe an example of how to migrate one object type in a custom way\n", - "- [x] check SyftObjectRegistry and compare with current implementation\n", - "- [x] run unit tests\n", - "- [ ] finalize notebooks for testing, run in CI\n", - "- [ ] other tasks defined in tickets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "node = sy.orchestra.launch(\n", - " name=\"test_upgradability\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "markdown", - "id": "5", - "metadata": {}, - "source": [ - "# Client side migrations" - ] - }, - { - "cell_type": "markdown", - "id": "6", - "metadata": {}, - "source": [ - "## document store objects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "migration_dict = client.services.migration.get_migration_objects(get_all=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": 
{}, - "outputs": [], - "source": [ - "migration_dict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "def custom_migration_function(context, obj: SyftObject, klass) -> SyftObject:\n", - " # Here, we are just doing the same, but this is where you would write your custom logic\n", - " return obj.migrate_to(klass.__version__, context)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "# this wont work in the cases where the context is actually used,\n", - "# but since this would need custom logic here anyway you write workarounds for that (manually querying required state)\n", - "\n", - "\n", - "context = Context()\n", - "migrated_objects = []\n", - "for klass, objects in migration_dict.items():\n", - " for obj in objects:\n", - " if isinstance(obj, SyftLogV3):\n", - " migrated_obj = custom_migration_function(context, obj, klass)\n", - " else:\n", - " migrated_obj = obj.migrate_to(klass.__version__, context)\n", - " migrated_objects.append(migrated_obj)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "migrated_objects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "res = client.services.migration.update_migrated_objects(migrated_objects)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(res, sy.SyftSuccess)" - ] - }, - { - "cell_type": "markdown", - "id": "14", - "metadata": {}, - "source": [ - "## Actions and ActionObjects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "migration_action_dict = client.services.migration.get_migration_actionobjects()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "# this wont work in the cases where the context is actually used, but since this you would need custom logic here anyway\n", - "# it doesnt matter\n", - "context = Context()\n", - "migrated_actionobjects = []\n", - "for klass, objects in migration_action_dict.items():\n", - " for obj in objects:\n", - " # custom migration logic here\n", - " migrated_actionobject = obj.migrate_to(klass.__version__, context)\n", - " migrated_actionobjects.append(migrated_actionobject)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "migrated_actionobjects" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "res = client.services.migration.update_migrated_objects(migrated_actionobjects)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(res, sy.SyftSuccess)" - ] - }, - { - "cell_type": "markdown", - "id": "20", - "metadata": {}, - "source": [ - "## Store metadata\n", - "\n", - "- Permissions\n", - "- StoragePermissions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "store_metadata = client.services.migration.get_all_store_metadata()\n", - "store_metadata" - ] - }, - { - "cell_type": "code", - "execution_count": null, 
- "id": "22", - "metadata": {}, - "outputs": [], - "source": [ - "for k, v in store_metadata.items():\n", - " if len(v.permissions):\n", - " print(\n", - " k, len(v.permissions), len(v.permissions) == len(migration_dict.get(k, []))\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "# Test update method with a temp node\n", - "# After update, all metadata should match between the nodes\n", - "\n", - "temp_node = sy.orchestra.launch(\n", - " name=\"temp_node\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - " reset=True,\n", - ")\n", - "\n", - "temp_client = temp_node.login(email=\"info@openmined.org\", password=\"changethis\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "temp_client.services.migration.update_store_metadata(store_metadata)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "for cname, real_partition in node.python_node.document_store.partitions.items():\n", - " temp_partition = temp_node.python_node.document_store.partitions[cname]\n", - "\n", - " temp_perms = dict(temp_partition.permissions.items())\n", - " real_perms = dict(real_partition.permissions.items())\n", - "\n", - " # Only look at migrated items\n", - " temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", - " assert temp_perms == real_perms\n", - "\n", - " temp_storage = dict(temp_partition.storage_permissions.items())\n", - " real_storage = dict(real_partition.storage_permissions.items())\n", - " temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", - "\n", - " assert temp_storage == real_storage\n", - "\n", - "# Action store\n", - "real_partition = node.python_node.action_store\n", - "temp_partition = temp_node.python_node.action_store\n", - "temp_perms = dict(temp_partition.permissions.items())\n", - "real_perms = dict(real_partition.permissions.items())\n", - "\n", - "# Only look at migrated items\n", - "temp_perms = {k: v for k, v in temp_perms.items() if k in real_perms}\n", - "assert temp_perms == real_perms\n", - "\n", - "temp_storage = dict(temp_partition.storage_permissions.items())\n", - "real_storage = dict(real_partition.storage_permissions.items())\n", - "temp_storage = {k: v for k, v in temp_storage.items() if k in real_storage}\n", - "\n", - "assert temp_storage == real_storage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb similarity index 85% rename from notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb rename to notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb index 3495c4a3d00..29690a49f07 100644 --- 
a/notebooks/tutorials/version-upgrades/2a-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb @@ -8,9 +8,12 @@ "outputs": [], "source": [ "# stdlib\n", + "from pathlib import Path\n", "\n", "# syft absolute\n", - "import syft as sy" + "import syft as sy\n", + "from syft.service.code.user_code import UserCode\n", + "from syft.service.user.user import User" ] }, { @@ -28,7 +31,14 @@ "metadata": {}, "outputs": [], "source": [ - "client = sy.login(email=\"info@openmined.org\", password=\"changethis\", port=8080)" + "node = sy.orchestra.launch(\n", + " name=\"test_upgradability\",\n", + " dev_mode=True,\n", + " reset=True,\n", + " port=\"auto\",\n", + ")\n", + "\n", + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" ] }, { @@ -38,28 +48,16 @@ "metadata": {}, "outputs": [], "source": [ - "migration_data = client.get_migration_data()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "from syft.service.code.user_code import UserCode\n", - "from syft.service.user.user import User\n", + "# Check if this is a new node\n", + "migration_data = client.get_migration_data()\n", "\n", - "# Check if this is a clean node\n", "assert len(migration_data.store_objects[User]) == 1\n", "assert UserCode not in migration_data.store_objects" ] }, { "cell_type": "markdown", - "id": "5", + "id": "4", "metadata": {}, "source": [ "# Load migration data" @@ -68,25 +66,29 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "5", "metadata": {}, "outputs": [], "source": [ - "client.load_migration_data(\"my_migration.blob\")" + "blob_path = Path(\"./migration.blob\")\n", + "print(f\"Loading migration data from {str(blob_path.resolve())}\")\n", + "\n", + "res = client.load_migration_data(blob_path)\n", + "assert isinstance(res, sy.SyftSuccess)" ] }, { "cell_type": "markdown", - "id": "7", + "id": "6", "metadata": {}, "source": [ - "# DS login" + "# Post migration tests" ] }, { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -96,17 +98,17 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "8", "metadata": {}, "outputs": [], "source": [ - "client_ds = sy.login(email=\"ds@openmined.org\", password=\"pw\", port=8080)" + "client_ds = node.login(email=\"ds@openmined.org\", password=\"pw\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "9", "metadata": {}, "outputs": [], "source": [] @@ -114,7 +116,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -125,7 +127,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -135,7 +137,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +147,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +157,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "16", "metadata": {}, "outputs": [], 
"source": [ @@ -191,7 +193,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -201,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -211,7 +213,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -223,7 +225,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -233,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -243,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -253,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -261,6 +263,17 @@ "assert isinstance(logs, str)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "if node.node_type.value == \"python\":\n", + " node.land()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb b/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb deleted file mode 100644 index 010c021b6db..00000000000 --- a/notebooks/tutorials/version-upgrades/2-post-migration-tests.ipynb +++ /dev/null @@ -1,272 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "import syft as sy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "node = sy.orchestra.launch(\n", - " name=\"temp_node\",\n", - " dev_mode=True,\n", - " local_db=True,\n", - " n_consumers=2,\n", - " create_producer=True,\n", - " migrate=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "# Post migration tests" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "client = node.login(email=\"info@openmined.org\", password=\"changethis\")\n", - "client_ds = node.login(email=\"ds@openmined.org\", password=\"pw\")" - ] - }, - { - "cell_type": "markdown", - "id": "4", - "metadata": {}, - "source": [ - "- [x] log in\n", - "- [x] get request / datasets / \n", - "- [x] check request is approved\n", - "- [x] run function\n", - "- [ ] run new function\n", - "- [x] repr (request, code, job)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "client.verify_key" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "# syft absolute\n", - "from syft.client.api import APIRegistry" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "APIRegistry.__api_registry__.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "node.python_node.id" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "code = client.code.get_all()[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "code.status_link.node_uid = node.python_node.id" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "code.status" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "req1 = client.requests[0]\n", - "req2 = client_ds.requests[0]\n", - "assert req1.status.name == \"APPROVED\" and req2.status.name == \"APPROVED\"\n", - "assert isinstance(req1._repr_html_(), str)\n", - "assert isinstance(req2._repr_html_(), str)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": {}, - "outputs": [], - "source": [ - "jobs = client_ds.jobs\n", - "assert isinstance(jobs[0]._repr_html_(), str)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "ds = client_ds.datasets\n", - "asset = ds[0].assets[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "res = client_ds.code.compute_mean(data=asset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "# third party\n", - "import numpy as np\n", - "\n", - "assert all(res == np.array([15, 16, 17, 18, 19]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "jobs = client_ds.jobs.get_all()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "job = jobs[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "job.logs()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "logs = job.logs(_print=False)\n", - "assert isinstance(logs, str)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 9547f70e1b0..4644e2d72e1 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -130,7 +130,7 @@ profiles: patches: - op: add path: deployments.syft.helm.valuesFiles - value: ./helm/examples/dev/migrated.yaml + value: ./helm/examples/dev/migration.yaml - name: domain-tunnel description: "Deploy a domain with tunneling enabled" diff --git a/packages/grid/helm/examples/dev/migrated.yaml b/packages/grid/helm/examples/dev/migrated.yaml deleted file mode 100644 index 6ee1083340e..00000000000 --- a/packages/grid/helm/examples/dev/migrated.yaml +++ /dev/null @@ -1,6 +0,0 @@ -node: - env: - - name: NODE_UID - value: 
null - - name: NODE_PRIVATE_KEY - value: null diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index aceff1f7f11..0bca010b3ee 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -181,12 +181,15 @@ def _set( if result.is_ok(): if isinstance(action_object, TwinObject): # give read permission to the mock - blob_id = action_object.mock_obj.syft_blob_storage_entry_id - permission = ActionObjectPermission(blob_id, ActionPermission.ALL_READ) - blob_storage_service: AbstractService = context.node.get_service( - BlobStorageService - ) - blob_storage_service.stash.add_permission(permission) + if not skip_save_to_blob_store: + blob_id = action_object.mock_obj.syft_blob_storage_entry_id + permission = ActionObjectPermission( + blob_id, ActionPermission.ALL_READ + ) + blob_storage_service: AbstractService = context.node.get_service( + BlobStorageService + ) + blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private else: diff --git a/tox.ini b/tox.ini index 1ba9d01ab16..f78570e60c7 100644 --- a/tox.ini +++ b/tox.ini @@ -1075,7 +1075,7 @@ allowlist_externals = ; setenv commands = bash -c 'python -c "import syft as sy; print(sy.__version__)"' - pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/0-prepare-migration-data.ipynb -vvvv + pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb -vvvv [testenv:migration.test] description = Migration Test @@ -1091,4 +1091,99 @@ allowlist_externals = commands = bash -c 'python -c "import syft as sy; print(sy.__version__)"' tox -e migration.prepare - pytest -x --nbmake --nbmake-timeout=1000 notebooks/experimental/migration/1-connect-and-migrate.ipynb -vvvv + pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb -vvvv + bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.blob' + bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.yaml' + pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb -vvvv +commands_post = + bash -c 'rm -f notebooks/tutorials/version-upgrades/migration.blob' + bash -c 'rm -f notebooks/tutorials/version-upgrades/migration.yaml' + +[testenv:migration.k8s.test] +description = Migration Test on K8s +setenv = + GITHUB_CI = {env:GITHUB_CI:false} + SYFT_BASE_IMAGE_REGISTRY = {env:SYFT_BASE_IMAGE_REGISTRY:k3d-registry.localhost:5800} + DOMAIN_CLUSTER_NAME = {env:DOMAIN_CLUSTER_NAME:test-domain-1} + NODE_PORT = {env:NODE_PORT:8080} +deps = + {[testenv:syft]deps} + nbmake +changedir = {toxinidir} +passenv=HOME, USER +allowlist_externals = + bash + tox + pytest + devspace + kubectl + grep + sleep + k3d + echo +commands = + ;NOTE migration data dump is created on a local (non-k8s) environment and then uploaded to a k8s cluster + ;This is only needed because our previous deployed version does not have the migration feature + bash -c 'python -c "import syft as sy; print(sy.__version__)"' + tox -e migration.prepare + pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb -vvvv + bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.blob' + bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.yaml' + + # Make migration.yaml available for devspace migration profile and 
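
The notebooks that the migration.test environment above drives with pytest --nbmake boil down to a short dump-and-restore round trip. A condensed sketch, using only calls that appear in the notebook cells earlier in this patch; node stands for the orchestra handle the notebooks create with sy.orchestra.launch.

# condensed from 1-dump-database-to-file.ipynb and 2-migrate-from-file.ipynb
# stdlib
from pathlib import Path

# syft absolute
import syft as sy

# on the old node: dump the node's data (including blob entries) to disk
client = node.login(email="info@openmined.org", password="changethis")
migration_data = client.get_migration_data()
migration_data.save(Path("./migration.blob"), yaml_path=Path("migration.yaml"))

# on a freshly launched node: restore everything from the dump
client = node.login(email="info@openmined.org", password="changethis")
res = client.load_migration_data(Path("./migration.blob"))
assert isinstance(res, sy.SyftSuccess)
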
launch cluster for migration + bash -c 'cp notebooks/tutorials/version-upgrades/migration.yaml packages/grid/helm/examples/dev/migration.yaml' + + # Start the cluster + # set env variable for orchestra deployment type + bash -c "export ORCHESTRA_DEPLOYMENT_TYPE=remote" + bash -c "echo Running with GITHUB_CI=$GITHUB_CI; date" + python -c 'import syft as sy; sy.stage_protocol_changes()' + k3d version + + # Deleting Old Cluster + bash -c "k3d cluster delete ${DOMAIN_CLUSTER_NAME} || true" + + # Deleting registry & volumes + bash -c "k3d registry delete k3d-registry.localhost || true" + bash -c "docker volume rm k3d-${DOMAIN_CLUSTER_NAME}-images --force || true" + + # Create registry + tox -e dev.k8s.registry + + # Creating test-domain-1 cluster on port NODE_PORT + # NOTE set DEVSPACE_PROFILE=migrated-domain will start the cluster with variables from migration.yaml + bash -c '\ + export CLUSTER_NAME=${DOMAIN_CLUSTER_NAME} \ + CLUSTER_HTTP_PORT=${NODE_PORT} \ + DEVSPACE_PROFILE=migrated-domain && \ + tox -e dev.k8s.start && \ + tox -e dev.k8s.deploy' + + # free up build cache after build of images + bash -c 'if [[ "$GITHUB_CI" != "false" ]]; then \ + docker image prune --all --force; \ + docker builder prune --all --force; \ + fi' + + sleep 30 + + ; # wait for test-domain-1 + ; bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DOMAIN_CLUSTER_NAME} --namespace syft + ; bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DOMAIN_CLUSTER_NAME} --namespace syft + ; bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DOMAIN_CLUSTER_NAME} --namespace syft + ; bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DOMAIN_CLUSTER_NAME} --namespace syft + ; bash packages/grid/scripts/wait_for.sh service frontend --context k3d-{env:DOMAIN_CLUSTER_NAME} --namespace syft + ; bash -c '(kubectl logs service/frontend --context k3d-${DOMAIN_CLUSTER_NAME} --namespace syft -f &) | grep -q -E "Network:\s+https?://[a-zA-Z0-9.-]+:[0-9]+/" || true' + + ; # Checking logs generated & startup of test-domain 1 + ; bash -c '(kubectl logs service/backend --context k3d-${DOMAIN_CLUSTER_NAME} --namespace syft -f &) | grep -q "Application startup complete" || true' + + # Run Notebook tests + pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb -vvvv + + # deleting clusters created + bash -c "CLUSTER_NAME=${DOMAIN_CLUSTER_NAME} tox -e dev.k8s.destroy || true" + +commands_post = + bash -c 'rm -f notebooks/tutorials/version-upgrades/migration.blob' + bash -c 'rm -f notebooks/tutorials/version-upgrades/migration.yaml' \ No newline at end of file From df623a95c42ac3cd9fbfc030114b2148244284a9 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 5 Jul 2024 13:53:15 +0200 Subject: [PATCH 274/309] comments --- tox.ini | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tox.ini b/tox.ini index f78570e60c7..3d7fddbdd1d 100644 --- a/tox.ini +++ b/tox.ini @@ -1122,15 +1122,19 @@ allowlist_externals = k3d echo commands = - ;NOTE migration data dump is created on a local (non-k8s) environment and then uploaded to a k8s cluster - ;This is only needed because our previous deployed version does not have the migration feature + # - create migration data on local 0.8.6 node + # - dump the data to a file on a local dev node + # - migrate the data to a new cluster + # NOTE This is only needed because our previous deployed version + # does not have the migration feature, once it does we can 
create + # the migration data on a k8s deployment as well. bash -c 'python -c "import syft as sy; print(sy.__version__)"' tox -e migration.prepare pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb -vvvv bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.blob' bash -c 'ls -l notebooks/tutorials/version-upgrades/migration.yaml' - # Make migration.yaml available for devspace migration profile and launch cluster for migration + # Make migration.yaml available for devspace migration bash -c 'cp notebooks/tutorials/version-upgrades/migration.yaml packages/grid/helm/examples/dev/migration.yaml' # Start the cluster From 138877e1d53a92ebe3db8c239b37459f6b815cb7 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 5 Jul 2024 14:07:37 +0200 Subject: [PATCH 275/309] revert --- packages/syft/src/syft/service/action/action_permissions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 1a71b75b9ae..6131dcf5d08 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -34,8 +34,6 @@ def __init__( permission: ActionPermission, credentials: SyftVerifyKey | None = None, ): - if not isinstance(uid, UID): - raise ValueError(f"uid must be of type UID not {type(uid)}") if credentials is None: if permission not in COMPOUND_ACTION_PERMISSION: raise Exception(f"{permission} not in {COMPOUND_ACTION_PERMISSION}") From 0c9a405e3e8e4cdcff7ff19065bef4c5d121b13f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 5 Jul 2024 14:25:55 +0200 Subject: [PATCH 276/309] bootstrap with formatted uid --- packages/grid/backend/grid/bootstrap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/grid/backend/grid/bootstrap.py b/packages/grid/backend/grid/bootstrap.py index 15859c053aa..0e1878e2ee1 100644 --- a/packages/grid/backend/grid/bootstrap.py +++ b/packages/grid/backend/grid/bootstrap.py @@ -101,8 +101,8 @@ def validate_private_key(private_key: str | bytes) -> str: def validate_uid(node_uid: str) -> str: try: uid = uuid.UUID(node_uid) - if node_uid == uid.hex: - return uid.hex + if node_uid == uid.hex or node_uid == str(uid): + return str(uid) except Exception: pass raise Exception(f"{NODE_UID} is invalid") From 8e1c7c0848bf8322634c0d9e6370f6cd0a5cd809 Mon Sep 17 00:00:00 2001 From: alfred-openmined-bot <145415986+alfred-openmined-bot@users.noreply.github.com> Date: Fri, 5 Jul 2024 15:23:33 +0000 Subject: [PATCH 277/309] bump protocol and remove notebooks --- .../syft/src/syft/protocol/protocol_version.json | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index d2ab6739c4c..21e62327e28 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -486,6 +486,20 @@ "hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86", "action": "add" } + }, + "Dataset": { + "3": { + "version": 3, + "hash": "12a24de2ec144fe54eb873767131abed362827e08cd47a5c4d128acee2f2db91", + "action": "add" + } + }, + "CreateDataset": { + "3": { + "version": 3, + "hash": "c96a925ab43440b9c39ab5eb28d7fc4a4c8bf7934474b7c432dc7b12067c7766", + "action": "add" + } } } } From b2d18705af99358a36ebf65ed7211e12393ccddc Mon Sep 17 00:00:00 2001 From: 
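
The validate_uid change in PATCH 276 above makes bootstrap accept NODE_UID in either the 32-character hex form or the hyphenated canonical form, and normalises it to the hyphenated string. A quick stdlib illustration; the UUID value is an arbitrary example.

# illustrative only; any valid UUID behaves the same way
import uuid

u = uuid.UUID("2f9a0e44c5f84a2bbf0f1d0e5c4b3a21")
print(u.hex)   # 2f9a0e44c5f84a2bbf0f1d0e5c4b3a21  (the form accepted before)
print(str(u))  # 2f9a0e44-c5f8-4a2b-bf0f-1d0e5c4b3a21  (now also accepted, and what validate_uid returns)
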
eelcovdw Date: Sat, 6 Jul 2024 13:33:17 +0200 Subject: [PATCH 278/309] revert load_user_code change --- packages/syft/src/syft/node/node.py | 3 +-- packages/syft/src/syft/service/code/user_code_service.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 9701be0f023..0e1f284c547 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -821,8 +821,7 @@ def post_init(self) -> None: if "usercodeservice" in self.service_path_map: user_code_service = self.get_service(UserCodeService) - # TODO this does not work with un-migrated UserCode - # user_code_service.load_user_code(context=context) + user_code_service.load_user_code(context=context) def reload_user_code() -> None: user_code_service.load_user_code(context=context) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 41aba58575b..81c0a9a0918 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -358,6 +358,8 @@ def load_user_code(self, context: AuthedServiceContext) -> None: result = self.stash.get_all(credentials=context.credentials) if result.is_ok(): user_code_items = result.ok() + # Filter out UserCode items that are not updated to the latest versio + user_code_items = [x for x in user_code_items if isinstance(x, UserCode)] load_approved_policy_code(user_code_items=user_code_items, context=context) def is_execution_allowed( From eb07fc9a615a28fd13ad9d12f68eb2750eba56f7 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 8 Jul 2024 12:31:10 +0800 Subject: [PATCH 279/309] Raise AttributeError in __getattr__ instead of bare Exception --- packages/syft/src/syft/service/response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 2908454096e..57c02deb195 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -35,7 +35,7 @@ def __getattr__(self, name: str) -> Any: ] or name.startswith("_repr"): return super().__getattr__(name) display(self) - raise Exception( + raise AttributeError( f"You have tried accessing `{name}` on a {type(self).__name__} with message: {self.message}" ) From e9a7b828cb207b71792004df7e4b1601f90ac948 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 8 Jul 2024 07:12:26 +0200 Subject: [PATCH 280/309] fix has_result_read_permission --- .../syft/src/syft/service/action/action_service.py | 5 +++++ .../syft/src/syft/service/code/user_code_service.py | 11 ++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 06c44071938..0777c33bd8b 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -364,6 +364,11 @@ def get_mock( def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> bool: return self.store.has_storage_permission(uid) + def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool: + return self.store.has_permissions( + [ActionObjectREAD(uid=uid, credentials=context.credentials)] + ) + # not a public service endpoint def _user_code_execute( self, diff --git 
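
Stepping back to the response.py hunk in PATCH 279 above: raising AttributeError instead of a bare Exception from SyftError.__getattr__ matters for code that probes attributes defensively, because hasattr and three-argument getattr only swallow AttributeError. A small illustration; err stands for any SyftError instance.

# err is assumed to be a SyftError returned by some API call
hasattr(err, "nonexistent_attr")        # now returns False; a bare Exception would propagate
getattr(err, "nonexistent_attr", None)  # now returns None; a bare Exception would propagate
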
a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 41aba58575b..1203d46f4c2 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -20,6 +20,7 @@ from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission +from ..action.action_service import ActionService from ..context import AuthedServiceContext from ..output.output_service import ExecutionOutput from ..policy.policy import OutputPolicy @@ -579,7 +580,7 @@ def _call( "which is currently not supported. Run your function with `blocking=False` to run" " as a job on your worker pool" ) - action_service = context.node.get_service("actionservice") + action_service: ActionService = context.node.get_service("actionservice") result_action_object: Result[ActionObject | TwinObject, str] = ( action_service._user_code_execute( context, code, kwarg2id, result_id=result_id @@ -619,14 +620,10 @@ def _call( # res = self.update_code_state(context, code) # print(res) - has_result_read_permission = context.extra_kwargs.get( - "has_result_read_permission", False + has_result_read_permission = action_service.has_read_permission( + context, result.id ) - # TODO: Just to fix the issue with the current implementation - if context.role == ServiceRole.ADMIN: - has_result_read_permission = True - if isinstance(result, TwinObject): if has_result_read_permission: return Ok(result.private) From 4a85d41c59738419ef3ea07950bd28dcae1055a6 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 8 Jul 2024 07:18:03 +0200 Subject: [PATCH 281/309] lint --- packages/syft/src/syft/service/code/user_code_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 1203d46f4c2..6853a9637aa 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -580,7 +580,7 @@ def _call( "which is currently not supported. 
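The effect of routing has_result_read_permission through the action store, as in the change above, is that the caller's credentials decide whether the private result or only the mock is returned. A reduced sketch of that gating, with simplified stand-ins for TwinObject and the permission store (not the real Syft classes):

# Sketch of permission-gated result access; TwinObject and PermissionStore are
# simplified stand-ins for the real Syft types.
from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class TwinObject:
    id: UUID
    private: object
    mock: object


@dataclass
class PermissionStore:
    # object id -> verify keys holding READ permission on that object
    read_grants: dict[UUID, set[str]] = field(default_factory=dict)

    def has_read_permission(self, uid: UUID, credentials: str) -> bool:
        return credentials in self.read_grants.get(uid, set())


def resolve_result(result: TwinObject, store: PermissionStore, credentials: str) -> object:
    # Private data only for callers that hold READ on the result id; everyone
    # else gets the mock representation.
    if store.has_read_permission(result.id, credentials):
        return result.private
    return result.mock


if __name__ == "__main__":
    obj = TwinObject(id=uuid4(), private=[1, 2, 3], mock=["*", "*", "*"])
    store = PermissionStore(read_grants={obj.id: {"data-owner-key"}})
    assert resolve_result(obj, store, "data-owner-key") == [1, 2, 3]
    assert resolve_result(obj, store, "data-scientist-key") == ["*", "*", "*"]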
Run your function with `blocking=False` to run" " as a job on your worker pool" ) - action_service: ActionService = context.node.get_service("actionservice") + action_service: ActionService = context.node.get_service("actionservice") # type: ignore result_action_object: Result[ActionObject | TwinObject, str] = ( action_service._user_code_execute( context, code, kwarg2id, result_id=result_id From 0c41f2d163c34b7bd0272c8f05449e0d9dbdab8a Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 8 Jul 2024 07:31:17 +0200 Subject: [PATCH 282/309] revert back to previous change --- packages/syft/src/syft/service/action/action_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 0777c33bd8b..917b67074a4 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -549,7 +549,7 @@ def set_result_to_store( result_blob_id = result_action_object.syft_blob_storage_entry_id # type: ignore[unreachable] # pass permission information to the action store as extra kwargs - context.extra_kwargs = {"has_result_read_permission": True} + # context.extra_kwargs = {"has_result_read_permission": True} # Since this just meta data about the result, they always have access to it. set_result = self._set( From c860a850bfed9f82099385c49bad95fda63ce468 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 11:16:03 +0200 Subject: [PATCH 283/309] fix custom policies --- packages/syft/src/syft/service/policy/policy.py | 14 ++++++++++---- .../syft/src/syft/types/syft_object_registry.py | 5 +++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index b2528241bb2..588ef4cef66 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -38,6 +38,7 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.syft_object_registry import SyftObjectRegistry from ...types.transforms import TransformContext from ...types.transforms import generate_id from ...types.transforms import transform @@ -57,6 +58,8 @@ from ..response import SyftError from ..response import SyftSuccess +DEFAULT_USER_POLICY_VERSION = 1 + PolicyUserVerifyKeyPartitionKey = PartitionKey( key="user_verify_key", type_=SyftVerifyKey ) @@ -1155,10 +1158,13 @@ def execute_policy_code(user_policy: UserPolicy) -> Any: sys.stdout = stdout sys.stderr = stderr - class_name = f"{user_policy.unique_name}" - if class_name in user_policy.__object_version_registry__.keys(): - policy_class = user_policy.__object_version_registry__[class_name] - else: + class_name = user_policy.unique_name + + try: + policy_class = SyftObjectRegistry.get_serde_class( + class_name, version=DEFAULT_USER_POLICY_VERSION + ) + except Exception: exec(user_policy.byte_code) # nosec policy_class = eval(user_policy.unique_name) # nosec diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index 0d92aba8e68..424c055fe0d 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -56,6 +56,11 @@ def get_canonical_name(cls, obj: Any) -> str: def get_serde_properties(cls, 
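The custom-policy fix above swaps a registry lookup in for the old __object_version_registry__ and keeps executing the stored byte code only as a fallback. A rough sketch of that lookup-or-load pattern, using a plain dict registry and illustrative names rather than SyftObjectRegistry itself:

# Lookup-or-load sketch for user policy classes: try a (name, version) registry first,
# fall back to executing the stored source (illustrative names, not SyftObjectRegistry).
CLASS_REGISTRY: dict[tuple[str, int], type] = {}


def register(name: str, version: int, cls: type) -> None:
    CLASS_REGISTRY[(name, version)] = cls


def resolve_policy_class(name: str, version: int, source: str) -> type:
    try:
        return CLASS_REGISTRY[(name, version)]
    except KeyError:
        namespace: dict = {}
        exec(source, namespace)  # nosec - policy source is assumed vetted upstream
        cls = namespace[name]
        register(name, version, cls)
        return cls


if __name__ == "__main__":
    src = "class CustomExactMatchPolicy:\n    version = 1\n"
    cls = resolve_policy_class("CustomExactMatchPolicy", 1, src)
    assert cls.__name__ == "CustomExactMatchPolicy"
    # The second resolution hits the registry instead of re-executing the source.
    assert resolve_policy_class("CustomExactMatchPolicy", 1, "") is cls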
canonical_name: str, version: int) -> tuple: return cls.__object_serialization_registry__[canonical_name][version] + @classmethod + def get_serde_class(cls, canonical_name: str, version: int) -> type["SyftObject"]: + serde_properties = cls.get_serde_properties(canonical_name, version) + return serde_properties[7] + @classmethod def get_serde_properties_bw_compatible( cls, fqn: str, canonical_name: str, version: int From 8c95f482e042781aabb88da4a7bf36436c7c1605 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 11:27:45 +0200 Subject: [PATCH 284/309] add test to CI --- .github/workflows/pr-tests-stack.yml | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/.github/workflows/pr-tests-stack.yml b/.github/workflows/pr-tests-stack.yml index eae0b555c3e..d94230d155e 100644 --- a/.github/workflows/pr-tests-stack.yml +++ b/.github/workflows/pr-tests-stack.yml @@ -415,3 +415,64 @@ jobs: k3d cluster delete test-gateway-1 || true k3d cluster delete test-domain-1 || true k3d registry delete k3d-registry.localhost || true + + pr-tests-migrations: + strategy: + max-parallel: 99 + matrix: + os: [ubuntu-latest] + python-version: ["3.12"] + + runs-on: ${{ matrix.os }} + steps: + - name: "clean .git/config" + if: matrix.os == 'windows-latest' + continue-on-error: true + shell: bash + run: | + echo "deleting ${GITHUB_WORKSPACE}/.git/config" + rm ${GITHUB_WORKSPACE}/.git/config + + - uses: actions/checkout@v4 + + - name: Check for file changes + uses: dorny/paths-filter@v3 + id: changes + with: + base: ${{ github.ref }} + token: ${{ github.token }} + filters: .github/file-filters.yml + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + if: steps.changes.outputs.syft == 'true' + with: + python-version: ${{ matrix.python-version }} + + - name: Install pip packages + if: steps.changes.outputs.syft == 'true' + run: | + python -m pip install --upgrade pip + pip install uv==0.2.17 tox tox-uv==1.9.0 + uv --version + + - name: Get uv cache dir + id: pip-cache + if: steps.changes.outputs.syft == 'true' + shell: bash + run: | + echo "dir=$(uv cache dir)" >> $GITHUB_OUTPUT + + - name: Load github cache + uses: actions/cache@v4 + if: steps.changes.outputs.syft == 'true' + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-uv-py${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} + restore-keys: | + ${{ runner.os }}-uv-py${{ matrix.python-version }}- + + - name: Run migration tests + if: steps.changes.outputs.syft == 'true' + run: | + tox -e migration.test From 22d641b235c0bf8d39be6fc4224837feb82bda2f Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Fri, 5 Jul 2024 11:29:59 -0400 Subject: [PATCH 285/309] fix not removing table / escaped js/css from non-template html --- packages/syft/src/syft/util/patch_ipython.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/util/patch_ipython.py b/packages/syft/src/syft/util/patch_ipython.py index 4910eb77d52..b5f16f512a9 100644 --- a/packages/syft/src/syft/util/patch_ipython.py +++ b/packages/syft/src/syft/util/patch_ipython.py @@ -83,7 +83,7 @@ def display_sanitized_html(obj: SyftObject | DictTuple) -> str | None: template = "\n".join(matching_table + matching_jobs) sanitized_str = escaped_template.sub("", html_str) sanitized_str = escaped_js_css.sub("", sanitized_str) - sanitized_str = jobs_pattern.sub("", html_str) + sanitized_str = jobs_pattern.sub("", sanitized_str) sanitized_str = sanitize_html(sanitized_str) return f"{css_reinsert} 
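The patch_ipython fix above is the usual chained-substitution slip: each re.sub has to consume the previous step's output, otherwise earlier removals are silently dropped. A small reproduction with placeholder patterns (Syft's real patterns differ):

# Chained substitutions must thread the intermediate result through every step.
import re

TEMPLATE = re.compile(r"<template>.*?</template>", re.DOTALL)
SCRIPT = re.compile(r"<script>.*?</script>", re.DOTALL)

html = "<template>t</template><script>s</script><p>content</p>"

# Buggy: the second sub starts from the original string, so the template survives.
buggy = TEMPLATE.sub("", html)
buggy = SCRIPT.sub("", html)          # should have been `buggy`, not `html`
assert buggy == "<template>t</template><p>content</p>"

# Fixed: every step operates on the previous step's output.
clean = TEMPLATE.sub("", html)
clean = SCRIPT.sub("", clean)
assert clean == "<p>content</p>"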
{sanitized_str} {template}" return None From f292cdaa2e8b90b68e904dc07ac9cce07477ab1e Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 13:22:37 +0200 Subject: [PATCH 286/309] fix actionobject migration --- .../2-migrate-from-file.ipynb | 24 +++++-------------- .../service/migration/migration_service.py | 10 ++++++-- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb index 29690a49f07..24329301c61 100644 --- a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb @@ -207,25 +207,13 @@ "metadata": {}, "outputs": [], "source": [ - "res.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "# third party\n", - "\n", "assert res.shape == (100_000, 5)" ] }, { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +223,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -245,7 +233,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -255,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -266,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -277,7 +265,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "24", "metadata": {}, "outputs": [], "source": [] diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 69eb482b248..a8061ceb4e0 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -496,7 +496,6 @@ def migrate_data( if objects_update_update_result.is_err(): return SyftError(message=objects_update_update_result.value) - # now action objects migration_actionobjects_result = self._get_migration_actionobjects(context) if migration_actionobjects_result.is_err(): @@ -536,6 +535,9 @@ def _get_migration_actionobjects( # Track all object types from action store action_object_types = [Action, ActionObject] action_object_types.extend(ActionObject.__subclasses__()) + klass_by_canonical_name = { + klass.__canonical_name__: klass for klass in action_object_types + } action_object_pending_migration = self._find_klasses_pending_for_migration( context=context, object_types=action_object_types @@ -551,7 +553,11 @@ def _get_migration_actionobjects( for obj in action_store_objects: if get_all or type(obj) in action_object_pending_migration: - result_dict[type(obj)].append(obj) + klass = klass_by_canonical_name.get( + obj.__canonical_name__, + type(obj), + ) + result_dict[klass].append(obj) return Ok(dict(result_dict)) @service_method( From 11e7072496bc76c19e05dcc5f112ce14cb5b4d0b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 13:29:13 +0200 Subject: [PATCH 287/309] add dev install in migration.prepare --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 3d7fddbdd1d..ba9b48cb09a 100644 --- a/tox.ini +++ b/tox.ini @@ -1066,7 +1066,7 @@ commands = 
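The action-object migration fix above groups stored objects by canonical name, so instances deserialized as an older class are queued under the latest registered class. A condensed sketch of that grouping step, assuming each object exposes __canonical_name__ (the classes below are placeholders, not the real Syft action objects):

# Sketch: queue stored objects for migration under the latest class registered
# for their canonical name.
from collections import defaultdict


class ActionObjectV1:
    __canonical_name__ = "ActionObject"


class ActionObjectV2:
    __canonical_name__ = "ActionObject"


def group_for_migration(objects: list, latest_classes: list[type]) -> dict[type, list]:
    # latest_classes: one entry per canonical name, pointing at the newest definition
    klass_by_canonical_name = {k.__canonical_name__: k for k in latest_classes}
    grouped: dict[type, list] = defaultdict(list)
    for obj in objects:
        klass = klass_by_canonical_name.get(obj.__canonical_name__, type(obj))
        grouped[klass].append(obj)
    return dict(grouped)


if __name__ == "__main__":
    stored = [ActionObjectV1(), ActionObjectV1()]
    grouped = group_for_migration(stored, latest_classes=[ActionObjectV2])
    # Both old-style objects end up queued for migration to the latest class.
    assert list(grouped.keys()) == [ActionObjectV2]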
[testenv:migration.prepare] description = Prepare Migration Data deps = - syft + syft[dev] nbmake allowlist_externals = bash From 0dbce33a977d353d8e0e2564874c388adc8d64a7 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 13:39:51 +0200 Subject: [PATCH 288/309] add pip show --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index ba9b48cb09a..288af7818f9 100644 --- a/tox.ini +++ b/tox.ini @@ -1074,6 +1074,7 @@ allowlist_externals = ; changedir ; setenv commands = + bash -c 'python -m pip show syft' bash -c 'python -c "import syft as sy; print(sy.__version__)"' pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb -vvvv From 6ebd203b6a94e321373757ecb42568aa49575537 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 13:42:09 +0200 Subject: [PATCH 289/309] specify syft version for prepare --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 288af7818f9..b08ae974555 100644 --- a/tox.ini +++ b/tox.ini @@ -1066,7 +1066,7 @@ commands = [testenv:migration.prepare] description = Prepare Migration Data deps = - syft[dev] + syft==0.8.6 nbmake allowlist_externals = bash From 5f8f125ed1465cb9758d886093ae3602a835b9c3 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 13:53:31 +0200 Subject: [PATCH 290/309] try to fix pycapnp --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index b08ae974555..bf9d9802cd4 100644 --- a/tox.ini +++ b/tox.ini @@ -1066,6 +1066,7 @@ commands = [testenv:migration.prepare] description = Prepare Migration Data deps = + pycapnp==2.0.0 syft==0.8.6 nbmake allowlist_externals = From 11038b4179d0773dd2c390abdc3397eb8b82c4e8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 14:12:51 +0200 Subject: [PATCH 291/309] add pre flag --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index bf9d9802cd4..f1b7246967c 100644 --- a/tox.ini +++ b/tox.ini @@ -1065,8 +1065,8 @@ commands = [testenv:migration.prepare] description = Prepare Migration Data +pip_pre = True deps = - pycapnp==2.0.0 syft==0.8.6 nbmake allowlist_externals = From e3b92349db2bf7e92b75bafcff46f945b42fe683 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 14:15:44 +0200 Subject: [PATCH 292/309] add pre flag --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index f1b7246967c..6e0164a5c31 100644 --- a/tox.ini +++ b/tox.ini @@ -1075,7 +1075,6 @@ allowlist_externals = ; changedir ; setenv commands = - bash -c 'python -m pip show syft' bash -c 'python -c "import syft as sy; print(sy.__version__)"' pytest -x --nbmake --nbmake-timeout=1000 notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb -vvvv From c1754afd7619917eb10a2f768a15e8ef52ae1ff8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 14:19:40 +0200 Subject: [PATCH 293/309] remove version assert --- .../version-upgrades/0-prepare-migration-data.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 57d8fbc2ddb..37be2b11305 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -45,9 +45,9 @@ "print(\n", " f\"latest deployed version: {latest_deployed_version}, installed version: {sy.__version__}\"\n", 
")\n", - "assert (\n", - " latest_deployed_version == sy.__version__\n", - "), f\"{latest_deployed_version} does not match installed version {sy.__version__}\"" + "# assert (\n", + "# latest_deployed_version == sy.__version__\n", + "# ), f\"{latest_deployed_version} does not match installed version {sy.__version__}\"" ] }, { From 8444ba39b0bd3a417792fb2a1eae1751dc98ce89 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 14:39:21 +0200 Subject: [PATCH 294/309] mypy --- .../syft/src/syft/service/migration/migration_service.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index a8061ceb4e0..d3e1d23d686 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -535,7 +535,7 @@ def _get_migration_actionobjects( # Track all object types from action store action_object_types = [Action, ActionObject] action_object_types.extend(ActionObject.__subclasses__()) - klass_by_canonical_name = { + klass_by_canonical_name: dict[str, type[SyftObject]] = { klass.__canonical_name__: klass for klass in action_object_types } @@ -553,11 +553,8 @@ def _get_migration_actionobjects( for obj in action_store_objects: if get_all or type(obj) in action_object_pending_migration: - klass = klass_by_canonical_name.get( - obj.__canonical_name__, - type(obj), - ) - result_dict[klass].append(obj) + klass = klass_by_canonical_name.get(obj.__canonical_name__, type(obj)) + result_dict[klass].append(obj) # type: ignore return Ok(dict(result_dict)) @service_method( From 0a1026a62a1b50bfec985236d29f178a34171064 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 16:05:49 +0200 Subject: [PATCH 295/309] fix protocol --- .../syft/src/syft/protocol/data_protocol.py | 14 +- .../src/syft/protocol/protocol_version.json | 469 +++++++++--------- .../syft/src/syft/service/job/job_stash.py | 1 - 3 files changed, 229 insertions(+), 255 deletions(-) diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 9a81261e5f3..4cf71ab8ac7 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -183,13 +183,13 @@ def build_state(self, stop_key: str | None = None) -> dict: hash_str = object_metadata["hash"] state_versions = state_dict[canonical_name] state_version_hashes = [val[0] for val in state_versions.values()] - # if action == "add" and ( - # str(version) in state_versions.keys() - # or hash_str in state_version_hashes - # ): - # raise Exception( - # f"Can't add {object_metadata} already in state {versions}" - # ) + if action == "add" and ( + str(version) in state_versions.keys() + or hash_str in state_version_hashes + ): + raise Exception( + f"Can't add {object_metadata} already in state {versions}" + ) if action == "remove" and ( str(version) not in state_versions.keys() and hash_str not in state_version_hashes diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 95e2f064f61..4458efbe66e 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -27,6 +27,13 @@ "action": "add" } }, + "SyftWorker": { + "3": { + "version": 3, + "hash": "e124f56ddf4565df2be056553eecd15de7c80bd5f5fd0d06e8ff7815bb05563a", + "action": "add" + 
} + }, "HTMLObject": { "1": { "version": 1, @@ -52,11 +59,6 @@ } }, "NodeSettings": { - "3": { - "version": 3, - "hash": "2d5f6e79f074f75b5cfc2357eac7cf635b8f083421009a513240b4dbbd5a0fc1", - "action": "add" - }, "5": { "version": 5, "hash": "cde18eb23fdffcfba47bc0e85efdbba1d59f1f5d6baa9c9690e1af14b35eb74e", @@ -68,6 +70,51 @@ "action": "add" } }, + "HTTPConnection": { + "2": { + "version": 2, + "hash": "68409295f8916ceb22a8cf4abf89f5e4bcff0d75dc37e16ede37250ada28df59", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "b61b30d10e2d25726c708ef34b69c7b730d41b16b315e7062f3d487e943143d5", + "action": "add" + } + }, + "PythonConnection": { + "2": { + "version": 2, + "hash": "eb479c671fc112b2acbedb88bc5624dfdc9592856c04c22c66410f6c863e1708", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", + "action": "add" + } + }, + "ActionObject": { + "4": { + "version": 4, + "hash": "a4dd2949af0f516d0f640d28e0fdfa026ba8d55bb29eaa7844c926e467606892", + "action": "add" + } + }, + "AnyActionObject": { + "4": { + "version": 4, + "hash": "809bd7ffab211133a9be87e058facecf870a79cb2d4027616f5244323de27091", + "action": "add" + } + }, + "BlobFileOBject": { + "3": { + "version": 3, + "hash": "27901fcd545ad0607dbfcbfa0141ee03b0f0f4bee8d23f2d661a4b22011bfd37", + "action": "add" + } + }, "BlobRetrievalByURL": { "5": { "version": 5, @@ -75,6 +122,30 @@ "action": "add" } }, + "HTTPNodeRoute": { + "2": { + "version": 2, + "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "d26cb313e92b1fbe36995c8ed4103a9168ea6e589b2803ed9a91c23f14bf0c96", + "action": "add" + } + }, + "PythonNodeRoute": { + "2": { + "version": 2, + "hash": "3eca5767ae4a8fbe67744509e58c6d9fb78f38fa0a0f7fcf5960ab4250acc1f0", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", + "action": "add" + } + }, "EnclaveMetadata": { "2": { "version": 2, @@ -94,12 +165,28 @@ "action": "add" } }, - "JobItem": { + "Dataset": { + "3": { + "version": 3, + "hash": "12a24de2ec144fe54eb873767131abed362827e08cd47a5c4d128acee2f2db91", + "action": "add" + } + }, + "CreateDataset": { + "3": { + "version": 3, + "hash": "c96a925ab43440b9c39ab5eb28d7fc4a4c8bf7934474b7c432dc7b12067c7766", + "action": "add" + } + }, + "SyftLog": { "4": { "version": 4, - "hash": "6a7cc7c2bb4dd234c1508b0af4d3b403cd3b7b427578a775bf80dc36891923ed", + "hash": "ad6ef18ccd87fced669f3824d27ab423aaf52574b0cd4f720687aeaba77524e5", "action": "add" - }, + } + }, + "JobItem": { "6": { "version": 6, "hash": "865a2ed791b8abd20d76e9a6bfae7ae7dad51b5ebfd8ff728aab25af93fa5570", @@ -107,193 +194,168 @@ } }, "ExecutionOutput": { - "1": { - "version": 1, - "hash": "c2337099eba14767ead75fcc1b1fa265c1898461ede0b5e7758a0e8d11d1757d", - "action": "add" - }, "2": { "version": 2, "hash": "854fe9df5bcbb5c7e5b7c467bac423cd98c32f93d6876fea7b8eb6c08f6596da", "action": "add" } }, - "CreateCustomImageChange": { - "3": { - "version": 3, - "hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995", - "action": "add" - } - }, - "TwinAPIContextView": { + "PolicyRule": { "1": { "version": 1, - "hash": "7d368102d0b473009af3b8c46e0ea6d35893c5ebb172b373ad7d52553c12c9fa", + "hash": "31a982b94654ce27ad27a6622c6fa26dfe3f759a7824ac21d104390f10a5aa82", "action": "add" } }, - "CustomAPIView": { + "CreatePolicyRule": { "1": { "version": 1, - "hash": 
"0b9afdd554f0b353c07256e2522342be1302b395d649f1cbabc555e5baecb150", + "hash": "9b82e36c63e10c5b7b76b3b8ec1da1d2dfdce39f2cce98603a418ec221621874", "action": "add" } }, - "CustomApiEndpoint": { + "CreatePolicyRuleConstant": { "1": { "version": 1, - "hash": "13617f3dce60fa692421e0d9deda7ffd365ec02d4a062c18510b48457b6eba02", + "hash": "9e821ddd383b6472f95dad2b56ebaefad225ff49c96b89b4ce35dc99c422ba76", "action": "add" } }, - "PrivateAPIEndpoint": { + "Matches": { "1": { "version": 1, - "hash": "004ec19753263440e2896b4e35d7a6305322934512f473f37d54043af5726fe6", + "hash": "d1e875a6332a481458e83db364dfdf92bd34a87093d9762dfe8e136e5088bc4e", "action": "add" } }, - "PublicAPIEndpoint": { + "PreFill": { "1": { "version": 1, - "hash": "5589b6bdd045ee9c45987dae78fd5a1124530a6c493e2328b304d9273b75177f", + "hash": "22c38b8ad68409493810362e6c48822d3e2919760f26eba2d1de3f2ad394e314", "action": "add" } }, - "UpdateTwinAPIEndpoint": { + "UserOwned": { "1": { "version": 1, - "hash": "6d8effd404f15d4378b1ff3382e0622b9e5a637d9db342d43cfec00fe29c649a", + "hash": "b5cbb44d742fa51b9adf2a48bb56d9ff5ca82a25f8568a2505961bd906d9d084", "action": "add" } }, - "CreateTwinAPIEndpoint": { + "MixedInputPolicy": { "1": { "version": 1, - "hash": "55e0a7b0ac428a6abb771ffcb925ee79cdd752a4b83058aa4b71fbef2a9fee63", + "hash": "0e84e4c91e378717e1a4703574b07e3b1e6a3e5707401b4e0cc8d30088a506b9", "action": "add" } }, - "TwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "e538734d20be3b477e188eb91f66600c2e654bb32e34806ef24329e48238bf18", + "EmptyInputPolicy": { + "2": { + "version": 2, + "hash": "3117e16cbe4dbc344ab90fbbd36ba90dfb518e66f0fb07644bbe7864dcdeb309", "action": "add" } }, - "SyftLog": { - "3": { - "version": 3, - "hash": "8964d48238672e0e5d5db6b932cda4ee8eb77581949ab3f7a38a05b1efec13b7", - "action": "add" - }, - "4": { - "version": 4, - "hash": "ad6ef18ccd87fced669f3824d27ab423aaf52574b0cd4f720687aeaba77524e5", + "UserCode": { + "5": { + "version": 5, + "hash": "c2409f51bf920cce557d288c40b6964ec4df3d8c23e33c5d5668addc30368632", "action": "add" } }, - "SyncState": { - "1": { - "version": 1, - "hash": "a0616775ec8ef0629e2d91e0df9cc4237ea3674727eda1ce367f1897ee35767d", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "925f1b8ccd4b9d542700a111f9c4bdd28bfa55978d805ddb2fb3c108cc940d19", - "action": "add" - }, - "3": { - "version": 3, - "hash": "1b5fd28919cb496f8073a64a57736d477ace1ed969962b1b049cecf766f2661c", + "SubmitUserCode": { + "5": { + "version": 5, + "hash": "3135727b8f0ca7689d47c04e45a2bd6a7693f17c043f76fd2243135196c27232", "action": "add" } }, - "NodePeer": { + "CodeHistory": { "3": { "version": 3, - "hash": "ec0e39fc77ddb542558519d6a1f7c55f41cc037b6312792333792a04feea57e6", + "hash": "1b9bd1d3d096abab5617c2ff597b4c80751f686d16482a2cff4efd8741b84d53", "action": "add" } }, - "AssociationRequestChange": { + "StoreMetadata": { "1": { "version": 1, - "hash": "508550c43e9e3f30243813c23eb6eec20918923d7ba09498cddbcd7e8bfa4539", + "hash": "4a0522eaf28414dd53adcb7d5edb81b4a5b5bbe2e805cb78aa91329c3d6c32a8", "action": "add" } }, - "APIEndpointQueueItem": { + "MigrationData": { "1": { "version": 1, - "hash": "d31b2edfb767401c810584baccd27e4f566181c3ef7706618a82eb25ae20ff6d", + "hash": "c5be6bb4f34b04f814e15468d5231e47540c5b7d2ea0f2770e6cd332f61173c7", "action": "add" } }, - "NodeMetadataUpdate": { - "2": { - "version": 2, - "hash": "520ae8ffc0c057ffa827cb7b267a19fb6b92e3cf3c0a3666ac34e271b6dd0aed", + "SeaweedFSBlobDeposit": { + "3": { + "version": 3, + "hash": 
"05e61e6328b085b738e5d41c0781d87852d44d218894cb3008f5be46e337f6d8", "action": "remove" + }, + "4": { + "version": 4, + "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", + "action": "add" } }, - "SyncStateItem": { - "1": { - "version": 1, - "hash": "4dbfa0813f5a3f7be0b36249ff2d67e395ad7c9e138c5a122fc7342b8dcc4b92", - "action": "remove" + "NumpyArrayObject": { + "4": { + "version": 4, + "hash": "19e2ff3da78038d2164f86d1f9b0d1facc6008483be60d2852458e90202bb96b", + "action": "add" } }, - "VeilidConnection": { - "1": { - "version": 1, - "hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b", - "action": "remove" + "NumpyScalarObject": { + "4": { + "version": 4, + "hash": "5101d00dd92ac4391cae77629eb48aa25401cc8c5ebb28a8a969cd5eba35fb67", + "action": "add" } }, - "CreateCustomWorkerPoolChange": { - "3": { - "version": 3, - "hash": "e982f2ebcdc6fe23a65a014109e33ba7c487bb7ca5623723cf5ec7642f86828c", + "NumpyBoolObject": { + "4": { + "version": 4, + "hash": "764cd93792c4dfe27b8952fde853626592fe58e1a341b5350b23f38ce474583f", "action": "add" } }, - "NodePeerUpdate": { - "1": { - "version": 1, - "hash": "9e7cd39f6a9f90e8c595452865525e0989df1688236acfd1a665ed047ba47de9", + "PandasDataframeObject": { + "4": { + "version": 4, + "hash": "b70f4bb32ba9f3f5ea89552649bf882d927cf9085fb573cc6d4841b32d653f84", "action": "add" } }, - "JobInfo": { - "2": { - "version": 2, - "hash": "17a3986f1d55549a5ec2cadca6791bcafd84a92f442c220524d7665185064908", + "PandasSeriesObject": { + "4": { + "version": 4, + "hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86", "action": "add" } }, - "HTTPConnection": { - "2": { - "version": 2, - "hash": "68409295f8916ceb22a8cf4abf89f5e4bcff0d75dc37e16ede37250ada28df59", - "action": "remove" - }, + "CreateCustomImageChange": { "3": { "version": 3, - "hash": "b61b30d10e2d25726c708ef34b69c7b730d41b16b315e7062f3d487e943143d5", + "hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995", "action": "add" } }, - "UserCode": { - "4": { - "version": 4, - "hash": "0a7181cd5f76800b6566175ffa7276d0cf38c4ddc5110114430147dfc8bfdb2a", + "CreateCustomWorkerPoolChange": { + "3": { + "version": 3, + "hash": "e982f2ebcdc6fe23a65a014109e33ba7c487bb7ca5623723cf5ec7642f86828c", "action": "add" - }, - "5": { - "version": 5, - "hash": "c2409f51bf920cce557d288c40b6964ec4df3d8c23e33c5d5668addc30368632", + } + }, + "Request": { + "3": { + "version": 3, + "hash": "ba9ebb04cc3e8b3ae3302fd42a67e47261a0a330bae5f189d8f4819cf2804711", "action": "add" } }, @@ -304,220 +366,133 @@ "action": "add" } }, - "PolicyRule": { + "TwinAPIContextView": { "1": { "version": 1, - "hash": "31a982b94654ce27ad27a6622c6fa26dfe3f759a7824ac21d104390f10a5aa82", + "hash": "7d368102d0b473009af3b8c46e0ea6d35893c5ebb172b373ad7d52553c12c9fa", "action": "add" } }, - "CreatePolicyRule": { + "CustomAPIView": { "1": { "version": 1, - "hash": "9b82e36c63e10c5b7b76b3b8ec1da1d2dfdce39f2cce98603a418ec221621874", + "hash": "0b9afdd554f0b353c07256e2522342be1302b395d649f1cbabc555e5baecb150", "action": "add" } }, - "CreatePolicyRuleConstant": { + "CustomApiEndpoint": { "1": { "version": 1, - "hash": "9e821ddd383b6472f95dad2b56ebaefad225ff49c96b89b4ce35dc99c422ba76", + "hash": "13617f3dce60fa692421e0d9deda7ffd365ec02d4a062c18510b48457b6eba02", "action": "add" } }, - "Matches": { + "PrivateAPIEndpoint": { "1": { "version": 1, - "hash": "d1e875a6332a481458e83db364dfdf92bd34a87093d9762dfe8e136e5088bc4e", + "hash": 
"004ec19753263440e2896b4e35d7a6305322934512f473f37d54043af5726fe6", "action": "add" } }, - "PreFill": { + "PublicAPIEndpoint": { "1": { "version": 1, - "hash": "22c38b8ad68409493810362e6c48822d3e2919760f26eba2d1de3f2ad394e314", + "hash": "5589b6bdd045ee9c45987dae78fd5a1124530a6c493e2328b304d9273b75177f", "action": "add" } }, - "UserOwned": { + "UpdateTwinAPIEndpoint": { "1": { "version": 1, - "hash": "b5cbb44d742fa51b9adf2a48bb56d9ff5ca82a25f8568a2505961bd906d9d084", + "hash": "6d8effd404f15d4378b1ff3382e0622b9e5a637d9db342d43cfec00fe29c649a", "action": "add" } }, - "MixedInputPolicy": { + "CreateTwinAPIEndpoint": { "1": { "version": 1, - "hash": "0e84e4c91e378717e1a4703574b07e3b1e6a3e5707401b4e0cc8d30088a506b9", - "action": "add" - } - }, - "Request": { - "3": { - "version": 3, - "hash": "ba9ebb04cc3e8b3ae3302fd42a67e47261a0a330bae5f189d8f4819cf2804711", - "action": "add" - } - }, - "SubmitUserCode": { - "5": { - "version": 5, - "hash": "3135727b8f0ca7689d47c04e45a2bd6a7693f17c043f76fd2243135196c27232", + "hash": "55e0a7b0ac428a6abb771ffcb925ee79cdd752a4b83058aa4b71fbef2a9fee63", "action": "add" } }, - "CodeHistory": { - "3": { - "version": 3, - "hash": "1b9bd1d3d096abab5617c2ff597b4c80751f686d16482a2cff4efd8741b84d53", + "TwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "e538734d20be3b477e188eb91f66600c2e654bb32e34806ef24329e48238bf18", "action": "add" } }, - "PythonConnection": { - "2": { - "version": 2, - "hash": "eb479c671fc112b2acbedb88bc5624dfdc9592856c04c22c66410f6c863e1708", + "SyncState": { + "1": { + "version": 1, + "hash": "a0616775ec8ef0629e2d91e0df9cc4237ea3674727eda1ce367f1897ee35767d", "action": "remove" }, - "3": { - "version": 3, - "hash": "1084c85a59c0436592530b5fe9afc2394088c8d16faef2b19fdb9fb83ff0f0e2", - "action": "add" - } - }, - "HTTPNodeRoute": { "2": { "version": 2, - "hash": "2134ea812f7c6ea41522727ae087245c4b1195ffbad554db638070861cd9eb1c", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "d26cb313e92b1fbe36995c8ed4103a9168ea6e589b2803ed9a91c23f14bf0c96", + "hash": "925f1b8ccd4b9d542700a111f9c4bdd28bfa55978d805ddb2fb3c108cc940d19", "action": "add" - } - }, - "PythonNodeRoute": { - "2": { - "version": 2, - "hash": "3eca5767ae4a8fbe67744509e58c6d9fb78f38fa0a0f7fcf5960ab4250acc1f0", - "action": "remove" }, "3": { "version": 3, - "hash": "1bc413ec7c1d498ec945878e21e00affd9bd6d53b564b1e10e52feb09f177d04", + "hash": "1b5fd28919cb496f8073a64a57736d477ace1ed969962b1b049cecf766f2661c", "action": "add" } }, - "SeaweedFSBlobDeposit": { + "NodePeer": { "3": { "version": 3, - "hash": "05e61e6328b085b738e5d41c0781d87852d44d218894cb3008f5be46e337f6d8", - "action": "remove" - }, - "4": { - "version": 4, - "hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881", + "hash": "ec0e39fc77ddb542558519d6a1f7c55f41cc037b6312792333792a04feea57e6", "action": "add" } }, - "SyftWorker": { - "3": { - "version": 3, - "hash": "e124f56ddf4565df2be056553eecd15de7c80bd5f5fd0d06e8ff7815bb05563a", + "NodePeerUpdate": { + "1": { + "version": 1, + "hash": "9e7cd39f6a9f90e8c595452865525e0989df1688236acfd1a665ed047ba47de9", "action": "add" } }, - "StoreMetadata": { + "AssociationRequestChange": { "1": { "version": 1, - "hash": "4a0522eaf28414dd53adcb7d5edb81b4a5b5bbe2e805cb78aa91329c3d6c32a8", + "hash": "508550c43e9e3f30243813c23eb6eec20918923d7ba09498cddbcd7e8bfa4539", "action": "add" } }, - "MigrationData": { + "APIEndpointQueueItem": { "1": { "version": 1, - "hash": "c5be6bb4f34b04f814e15468d5231e47540c5b7d2ea0f2770e6cd332f61173c7", + "hash": 
"d31b2edfb767401c810584baccd27e4f566181c3ef7706618a82eb25ae20ff6d", "action": "add" } }, - "EmptyInputPolicy": { + "NodeMetadataUpdate": { "2": { "version": 2, - "hash": "3117e16cbe4dbc344ab90fbbd36ba90dfb518e66f0fb07644bbe7864dcdeb309", - "action": "add" - } - }, - "ActionObject": { - "4": { - "version": 4, - "hash": "a4dd2949af0f516d0f640d28e0fdfa026ba8d55bb29eaa7844c926e467606892", - "action": "add" - } - }, - "AnyActionObject": { - "4": { - "version": 4, - "hash": "809bd7ffab211133a9be87e058facecf870a79cb2d4027616f5244323de27091", - "action": "add" - } - }, - "BlobFileOBject": { - "3": { - "version": 3, - "hash": "27901fcd545ad0607dbfcbfa0141ee03b0f0f4bee8d23f2d661a4b22011bfd37", - "action": "add" - } - }, - "NumpyArrayObject": { - "4": { - "version": 4, - "hash": "19e2ff3da78038d2164f86d1f9b0d1facc6008483be60d2852458e90202bb96b", - "action": "add" - } - }, - "NumpyScalarObject": { - "4": { - "version": 4, - "hash": "5101d00dd92ac4391cae77629eb48aa25401cc8c5ebb28a8a969cd5eba35fb67", - "action": "add" - } - }, - "NumpyBoolObject": { - "4": { - "version": 4, - "hash": "764cd93792c4dfe27b8952fde853626592fe58e1a341b5350b23f38ce474583f", - "action": "add" - } - }, - "PandasDataframeObject": { - "4": { - "version": 4, - "hash": "b70f4bb32ba9f3f5ea89552649bf882d927cf9085fb573cc6d4841b32d653f84", - "action": "add" + "hash": "520ae8ffc0c057ffa827cb7b267a19fb6b92e3cf3c0a3666ac34e271b6dd0aed", + "action": "remove" } }, - "PandasSeriesObject": { - "4": { - "version": 4, - "hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86", - "action": "add" + "JobInfo": { + "2": { + "version": 2, + "hash": "89dbd4a810586b49498be1f5299b565a19871487e14a120433b0a4cf607b6dee", + "action": "remove" } }, - "Dataset": { - "3": { - "version": 3, - "hash": "12a24de2ec144fe54eb873767131abed362827e08cd47a5c4d128acee2f2db91", - "action": "add" + "SyncStateItem": { + "1": { + "version": 1, + "hash": "4dbfa0813f5a3f7be0b36249ff2d67e395ad7c9e138c5a122fc7342b8dcc4b92", + "action": "remove" } }, - "CreateDataset": { - "3": { - "version": 3, - "hash": "c96a925ab43440b9c39ab5eb28d7fc4a4c8bf7934474b7c432dc7b12067c7766", - "action": "add" + "VeilidConnection": { + "1": { + "version": 1, + "hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b", + "action": "remove" } } } diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 72105624ac8..7edd4dd0a04 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -780,7 +780,6 @@ def downgrade_job() -> list[Callable]: return [drop("requested_by")] -@serializable() class JobInfo(SyftObject): __canonical_name__ = "JobInfo" __version__ = SYFT_OBJECT_VERSION_2 From cd944dfd40a3612ed00f6d3373879e85b0e2bb02 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 8 Jul 2024 16:46:16 +0200 Subject: [PATCH 296/309] ignore version upgrades for notebook tests --- tox.ini | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6e0164a5c31..8049e3dbb0f 100644 --- a/tox.ini +++ b/tox.ini @@ -281,7 +281,12 @@ commands = bash -c "echo Running with ORCHESTRA_DEPLOYMENT_TYPE=$ORCHESTRA_DEPLOYMENT_TYPE DEV_MODE=$DEV_MODE TEST_NOTEBOOK_PATHS=$TEST_NOTEBOOK_PATHS; ENABLE_SIGNUP=$ENABLE_SIGNUP; date" bash -c "for subfolder in $(echo ${TEST_NOTEBOOK_PATHS} | tr ',' ' '); do \ if [[ $subfolder == *tutorials* ]]; then \ - pytest -x --nbmake "$subfolder" -p no:randomly --ignore=tutorials/model-training -n $(python -c 
'import multiprocessing; print(multiprocessing.cpu_count())') -vvvv && \ + pytest -x --nbmake "$subfolder" \ + -p no:randomly \ + --ignore=tutorials/model-training \ + --ignore=tutorials/version-upgrades \ + -n $(python -c 'import multiprocessing; print(multiprocessing.cpu_count())') \ + -vvvv && \ pytest -x --nbmake tutorials/model-training -p no:randomly -vvvv; \ else \ pytest -x --nbmake "$subfolder" -p no:randomly -k 'not 11-container-images-k8s.ipynb' -vvvv; \ From b0ab7dab37526a82f9641174415f16404ba373aa Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 9 Jul 2024 10:26:38 +0200 Subject: [PATCH 297/309] remove prints --- packages/syft/src/syft/service/action/action_service.py | 7 ++----- packages/syft/src/syft/store/kv_document_store.py | 9 --------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index fcd51277a1c..c4c54c85a24 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -574,11 +574,8 @@ def store_permission( def blob_permission( x: SyftVerifyKey | None = None, - ) -> ActionObjectPermission | None: - if result_blob_id: - return ActionObjectPermission(result_blob_id, read_permission, x) - else: - return None + ) -> ActionObjectPermission: + return ActionObjectPermission(result_blob_id, read_permission, x) if len(output_readers) > 0: store_permissions = [store_permission(x) for x in output_readers] diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 847a5ca0da9..83ca343b7c3 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -495,11 +495,6 @@ def _update( return Err(f"Failed to update obj {obj}, you have no permission") except Exception as e: - # third party - # stdlib - import traceback - - print(traceback.format_exc()) return Err(f"Failed to update obj {obj} with error: {e}") def _get_all_from_store( @@ -704,10 +699,6 @@ def _migrate_data( try: migrated_value = value.migrate_to(to_klass.__version__, context) except Exception: - # stdlib - import traceback - - print(traceback.format_exc()) return Err( f"Failed to migrate data to {to_klass} for qk {to_klass.__version__}: {key}" ) From e87a0dd87fe1119b2ee8c427c1831412325c6035 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 9 Jul 2024 10:50:27 +0200 Subject: [PATCH 298/309] add deprecated warning --- packages/syft/src/syft/util/schema.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/syft/src/syft/util/schema.py b/packages/syft/src/syft/util/schema.py index b2a573d2952..7565e051814 100644 --- a/packages/syft/src/syft/util/schema.py +++ b/packages/syft/src/syft/util/schema.py @@ -8,6 +8,9 @@ # syft absolute import syft as sy +# relative +from .decorators import deprecated + RELATIVE_PATH_TO_FRONTEND = "/../../../../grid/frontend/" SCHEMA_FOLDER = "schema" @@ -210,6 +213,9 @@ def resolve_references(json_mappings: dict[str, dict]) -> dict[str, dict]: return json_mappings +@deprecated( + reason="generate_json_schemas is outdated, #1603 for more info", +) def generate_json_schemas(output_path: str | None = None) -> None: # TODO: should we also replace this with the SyftObjectRegistry? 
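The @deprecated decorator applied above comes from syft.util.decorators; its implementation is not shown in this patch, but a plausible minimal equivalent (an assumption, not the real code) wraps the function and emits a DeprecationWarning before delegating:

# A minimal "deprecated" decorator sketch: warn, then call through to the original.
import functools
import warnings
from collections.abc import Callable


def deprecated(reason: str = "") -> Callable:
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecated(reason="generate_json_schemas is outdated, #1603 for more info")
def generate_json_schemas(output_path: str | None = None) -> None:
    ...


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        generate_json_schemas()
    assert caught and issubclass(caught[0].category, DeprecationWarning)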
json_mappings = process_type_bank(sy.serde.recursive.TYPE_BANK) From 38542852e3cf865e8421ee1e3c7cd5df832bcac5 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 10 Jul 2024 12:20:12 +0200 Subject: [PATCH 299/309] add serviceregistry --- packages/syft/src/syft/node/node.py | 164 ++++++++---------- .../syft/src/syft/node/service_registry.py | 116 +++++++++++++ .../src/syft/service/action/action_service.py | 2 + packages/syft/src/syft/service/service.py | 4 + 4 files changed, 193 insertions(+), 93 deletions(-) create mode 100644 packages/syft/src/syft/node/service_registry.py diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index eb98ce801c4..b4e9383b69d 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -42,48 +42,30 @@ from ..protocol.data_protocol import get_data_protocol from ..service.action.action_object import Action from ..service.action.action_object import ActionObject -from ..service.action.action_service import ActionService from ..service.action.action_store import ActionStore from ..service.action.action_store import DictActionStore from ..service.action.action_store import MongoActionStore from ..service.action.action_store import SQLiteActionStore -from ..service.api.api_service import APIService -from ..service.attestation.attestation_service import AttestationService from ..service.blob_storage.service import BlobStorageService -from ..service.code.status_service import UserCodeStatusService from ..service.code.user_code_service import UserCodeService from ..service.code.user_code_stash import UserCodeStash -from ..service.code_history.code_history_service import CodeHistoryService from ..service.context import AuthedServiceContext from ..service.context import NodeServiceContext from ..service.context import UnauthedServiceContext from ..service.context import UserLoginCredentials -from ..service.data_subject.data_subject_member_service import DataSubjectMemberService -from ..service.data_subject.data_subject_service import DataSubjectService -from ..service.dataset.dataset_service import DatasetService -from ..service.enclave.enclave_service import EnclaveService -from ..service.job.job_service import JobService from ..service.job.job_stash import Job from ..service.job.job_stash import JobStash from ..service.job.job_stash import JobStatus from ..service.job.job_stash import JobType -from ..service.log.log_service import LogService -from ..service.metadata.metadata_service import MetadataService from ..service.metadata.node_metadata import NodeMetadata -from ..service.migration.migration_service import MigrationService from ..service.network.network_service import NetworkService from ..service.network.utils import PeerHealthCheckTask -from ..service.notification.notification_service import NotificationService from ..service.notifier.notifier_service import NotifierService -from ..service.output.output_service import OutputService -from ..service.policy.policy_service import PolicyService -from ..service.project.project_service import ProjectService from ..service.queue.base_queue import AbstractMessageHandler from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer from ..service.queue.queue import APICallMessageHandler from ..service.queue.queue import QueueManager -from ..service.queue.queue_service import QueueService from ..service.queue.queue_stash import APIEndpointQueueItem from ..service.queue.queue_stash import ActionQueueItem from 
..service.queue.queue_stash import QueueItem @@ -91,23 +73,19 @@ from ..service.queue.zmq_queue import QueueConfig from ..service.queue.zmq_queue import ZMQClientConfig from ..service.queue.zmq_queue import ZMQQueueConfig -from ..service.request.request_service import RequestService from ..service.response import SyftError from ..service.service import AbstractService from ..service.service import ServiceConfigRegistry from ..service.service import UserServiceConfigRegistry from ..service.settings.settings import NodeSettings from ..service.settings.settings import NodeSettingsUpdate -from ..service.settings.settings_service import SettingsService from ..service.settings.settings_stash import SettingsStash -from ..service.sync.sync_service import SyncService from ..service.user.user import User from ..service.user.user import UserCreate from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..service.user.user_stash import UserStash -from ..service.worker.image_registry_service import SyftImageRegistryService from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME from ..service.worker.utils import create_default_image @@ -115,7 +93,6 @@ from ..service.worker.worker_pool import WorkerPool from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_pool_stash import SyftWorkerPoolStash -from ..service.worker.worker_service import WorkerService from ..service.worker.worker_stash import WorkerStash from ..store.blob_storage import BlobStorageConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig @@ -144,6 +121,7 @@ from ..util.util import thread_ident from .credentials import SyftSigningKey from .credentials import SyftVerifyKey +from .service_registry import ServiceRegistry from .worker_settings import WorkerSettings logger = logging.getLogger(__name__) @@ -416,7 +394,7 @@ def __init__( ) # construct services only after init stores - self._construct_services() + self.services: ServiceRegistry = ServiceRegistry(node=self) create_admin_new( # nosec B106 name=root_username, @@ -732,7 +710,7 @@ def _find_klasses_pending_for_migration( credentials=self.verify_key, role=ServiceRole.ADMIN, ) - migration_state_service = self.get_service(MigrationService) + migration_state_service = self.services.migration klasses_to_be_migrated = [] @@ -888,65 +866,74 @@ def worker_stash(self) -> WorkerStash: return self.get_service("workerservice").stash def _construct_services(self) -> None: - service_path_map: dict[str, AbstractService] = {} - initialized_services: list[AbstractService] = [] - - # A dict of service and init kwargs. 
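With the change above, Node stops assembling services by hand and instead reaches them as typed attributes on a ServiceRegistry (for example self.services.migration). The registry introduced later in this patch derives the set of services from class-level type annotations; a condensed sketch of that idea, with toy services standing in for the real ones:

# Sketch: derive a node's services from the registry class's type annotations
# (toy AbstractService/UserService/DatasetService, not the real Syft classes).
import typing


class AbstractService:
    def __init__(self, store: dict) -> None:
        self.store = store


class UserService(AbstractService): ...


class DatasetService(AbstractService): ...


class Registry:
    user: UserService
    dataset: DatasetService

    def __init__(self, store: dict) -> None:
        self.services: list[AbstractService] = []
        self.service_path_map: dict[str, AbstractService] = {}
        # get_type_hints picks up the annotated service attributes declared above
        for name, svc_cls in typing.get_type_hints(type(self)).items():
            if isinstance(svc_cls, type) and issubclass(svc_cls, AbstractService):
                svc = svc_cls(store=store)
                setattr(self, name, svc)
                self.services.append(svc)
                self.service_path_map[svc_cls.__name__.lower()] = svc


if __name__ == "__main__":
    registry = Registry(store={})
    assert isinstance(registry.user, UserService)
    assert registry.service_path_map["datasetservice"] is registry.dataset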
- # - "svc" expects a callable (class or function) - # - The callable must return AbstractService or None - # - "store" expects a store type - # - By default all services get the document store - # - Pass a custom "store" to override this - default_services: list[dict] = [ - {"svc": ActionService, "store": self.action_store}, - {"svc": UserService}, - {"svc": AttestationService}, - {"svc": WorkerService}, - {"svc": SettingsService}, - {"svc": DatasetService}, - {"svc": UserCodeService}, - {"svc": LogService}, - {"svc": RequestService}, - {"svc": QueueService}, - {"svc": JobService}, - {"svc": APIService}, - {"svc": DataSubjectService}, - {"svc": NetworkService}, - {"svc": PolicyService}, - {"svc": NotifierService}, - {"svc": NotificationService}, - {"svc": DataSubjectMemberService}, - {"svc": ProjectService}, - {"svc": EnclaveService}, - {"svc": CodeHistoryService}, - {"svc": MetadataService}, - {"svc": BlobStorageService}, - {"svc": MigrationService}, - {"svc": SyftWorkerImageService}, - {"svc": SyftWorkerPoolService}, - {"svc": SyftImageRegistryService}, - {"svc": SyncService}, - {"svc": OutputService}, - {"svc": UserCodeStatusService}, # this is lazy - ] - - for svc_kwargs in default_services: - ServiceCls = svc_kwargs.pop("svc") - svc_kwargs.setdefault("store", self.document_store) - - svc_instance = ServiceCls(**svc_kwargs) - if not svc_instance: - continue - elif not isinstance(svc_instance, AbstractService): - raise ValueError( - f"Service {ServiceCls.__name__} must be an instance of AbstractService" - ) + self.services = ServiceRegistry(node=self) - service_path_map[ServiceCls.__name__.lower()] = svc_instance - initialized_services.append(ServiceCls) + @property + def service_path_map(self) -> dict[str, AbstractService]: + return self.services.service_path_map - self.services = initialized_services - self.service_path_map = service_path_map + @property + def initialized_services(self) -> list[AbstractService]: + return self.services.services + # service_path_map: dict[str, AbstractService] = {} + # initialized_services: list[AbstractService] = [] + + # # A dict of service and init kwargs. 
+ # # - "svc" expects a callable (class or function) + # # - The callable must return AbstractService or None + # # - "store" expects a store type + # # - By default all services get the document store + # # - Pass a custom "store" to override this + # default_services: list[dict] = [ + # {"svc": ActionService, "store": self.action_store}, + # {"svc": UserService}, + # {"svc": AttestationService}, + # {"svc": WorkerService}, + # {"svc": SettingsService}, + # {"svc": DatasetService}, + # {"svc": UserCodeService}, + # {"svc": LogService}, + # {"svc": RequestService}, + # {"svc": QueueService}, + # {"svc": JobService}, + # {"svc": APIService}, + # {"svc": DataSubjectService}, + # {"svc": NetworkService}, + # {"svc": PolicyService}, + # {"svc": NotifierService}, + # {"svc": NotificationService}, + # {"svc": DataSubjectMemberService}, + # {"svc": ProjectService}, + # {"svc": EnclaveService}, + # {"svc": CodeHistoryService}, + # {"svc": MetadataService}, + # {"svc": BlobStorageService}, + # {"svc": MigrationService}, + # {"svc": SyftWorkerImageService}, + # {"svc": SyftWorkerPoolService}, + # {"svc": SyftImageRegistryService}, + # {"svc": SyncService}, + # {"svc": OutputService}, + # {"svc": UserCodeStatusService}, # this is lazy + # ] + + # for svc_kwargs in default_services: + # ServiceCls = svc_kwargs.pop("svc") + # svc_kwargs.setdefault("store", self.document_store) + + # svc_instance = ServiceCls(**svc_kwargs) + # if not svc_instance: + # continue + # elif not isinstance(svc_instance, AbstractService): + # raise ValueError( + # f"Service {ServiceCls.__name__} must be an instance of AbstractService" + # ) + + # service_path_map[ServiceCls.__name__.lower()] = svc_instance + # initialized_services.append(ServiceCls) + + # self.services = initialized_services + # self.service_path_map = service_path_map def get_service_method(self, path_or_func: str | Callable) -> Callable: if callable(path_or_func): @@ -954,21 +941,12 @@ def get_service_method(self, path_or_func: str | Callable) -> Callable: return self._get_service_method_from_path(path_or_func) def get_service(self, path_or_func: str | Callable) -> AbstractService: - if callable(path_or_func): - path_or_func = path_or_func.__qualname__ - return self._get_service_from_path(path_or_func) - - def _get_service_from_path(self, path: str) -> AbstractService: - path_list = path.split(".") - if len(path_list) > 1: - _ = path_list.pop() - service_name = path_list.pop() - return self.service_path_map[service_name.lower()] + return self.services.get_service(path_or_func) def _get_service_method_from_path(self, path: str) -> Callable: path_list = path.split(".") method_name = path_list.pop() - service_obj = self._get_service_from_path(path=path) + service_obj = self.services._get_service_from_path(path=path) return getattr(service_obj, method_name) diff --git a/packages/syft/src/syft/node/service_registry.py b/packages/syft/src/syft/node/service_registry.py new file mode 100644 index 00000000000..eb0ac666334 --- /dev/null +++ b/packages/syft/src/syft/node/service_registry.py @@ -0,0 +1,116 @@ +# stdlib +from collections.abc import Callable +import typing + +# relative +from ..service.action.action_service import ActionService +from ..service.action.action_store import ActionStore +from ..service.api.api_service import APIService +from ..service.attestation.attestation_service import AttestationService +from ..service.blob_storage.service import BlobStorageService +from ..service.code.status_service import UserCodeStatusService +from 
..service.code.user_code_service import UserCodeService +from ..service.code_history.code_history_service import CodeHistoryService +from ..service.data_subject.data_subject_member_service import DataSubjectMemberService +from ..service.data_subject.data_subject_service import DataSubjectService +from ..service.dataset.dataset_service import DatasetService +from ..service.enclave.enclave_service import EnclaveService +from ..service.job.job_service import JobService +from ..service.log.log_service import LogService +from ..service.metadata.metadata_service import MetadataService +from ..service.migration.migration_service import MigrationService +from ..service.network.network_service import NetworkService +from ..service.notification.notification_service import NotificationService +from ..service.notifier.notifier_service import NotifierService +from ..service.output.output_service import OutputService +from ..service.policy.policy_service import PolicyService +from ..service.project.project_service import ProjectService +from ..service.queue.queue_service import QueueService +from ..service.request.request_service import RequestService +from ..service.service import AbstractService +from ..service.settings.settings_service import SettingsService +from ..service.sync.sync_service import SyncService +from ..service.user.user_service import UserService +from ..service.worker.image_registry_service import SyftImageRegistryService +from ..service.worker.worker_image_service import SyftWorkerImageService +from ..service.worker.worker_pool_service import SyftWorkerPoolService +from ..service.worker.worker_service import WorkerService +from .node import Node + + +class ServiceRegistry: + # Services + action: ActionService + user: UserService + attestation: AttestationService + worker: WorkerService + settings: SettingsService + dataset: DatasetService + user_code: UserCodeService + log: LogService + request: RequestService + queue: QueueService + job: JobService + api: APIService + data_subject: DataSubjectService + network: NetworkService + policy: PolicyService + notifier: NotifierService + notification: NotificationService + data_subject_member: DataSubjectMemberService + project: ProjectService + enclave: EnclaveService + code_history: CodeHistoryService + metadata: MetadataService + blob_storage: BlobStorageService + migration: MigrationService + syft_worker_image: SyftWorkerImageService + syft_worker_pool: SyftWorkerPoolService + syft_image_registry: SyftImageRegistryService + sync: SyncService + output: OutputService + user_code_status: UserCodeStatusService + + def __init__(self, node: Node) -> None: + self.node = node + self.service_classes = self.get_service_classes() + self.services: list[AbstractService] = [] + self.service_path_map: dict[str, AbstractService] = {} + self._construct_services() + + @classmethod + def get_service_classes( + cls, + ) -> dict[str, type[AbstractService]]: + return { + name: cls + for name, cls in typing.get_type_hints(cls).items() + if issubclass(cls, AbstractService) + } + + def _construct_services(self) -> None: + for field_name, service_cls in self.get_service_classes().items(): + if issubclass(service_cls.store_type, ActionStore): + svc_kwargs = {"store": self.node.action_store} + else: + svc_kwargs = {"store": self.node.document_store} + + service = service_cls(**svc_kwargs) + setattr(self, field_name, service) + self.services.append(service) + self.service_path_map[service.__class__.__name__.lower()] = service + + def get_service(self, path_or_func: 
str | Callable) -> AbstractService: + if callable(path_or_func): + path_or_func = path_or_func.__qualname__ + return self._get_service_from_path(path_or_func) + + def _get_service_from_path(self, path: str) -> AbstractService: + try: + path_list = path.split(".") + if len(path_list) > 1: + _ = path_list.pop() + service_name = path_list.pop() + return self.service_path_map[service_name.lower()] + except KeyError: + raise ValueError(f"Service {path} not found.") diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index c4c54c85a24..1d0170d49d8 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -54,6 +54,8 @@ @serializable() class ActionService(AbstractService): + store_type = ActionStore + def __init__(self, store: ActionStore) -> None: self.store = store diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index c92695e2f6a..ddf989d397e 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -5,6 +5,7 @@ from collections import defaultdict from collections.abc import Callable from copy import deepcopy +import functools from functools import partial import inspect from inspect import Parameter @@ -31,6 +32,7 @@ from ..serde.signature import Signature from ..serde.signature import signature_remove_context from ..serde.signature import signature_remove_self +from ..store.document_store import DocumentStore from ..store.linked_obj import LinkedObject from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftBaseObject @@ -57,6 +59,7 @@ class AbstractService: node: AbstractNode node_uid: UID + store_type: type = DocumentStore def resolve_link( self, @@ -351,6 +354,7 @@ def wrapper(func: Any) -> Callable: input_signature = deepcopy(signature) + @functools.wraps(func) def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable: communication_protocol = kwargs.pop("communication_protocol", None) From ad3ac5ffcff49f38a8b997d02077b95c3e67b750 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 10 Jul 2024 12:42:29 +0200 Subject: [PATCH 300/309] fix circular import --- packages/syft/src/syft/node/service_registry.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/node/service_registry.py b/packages/syft/src/syft/node/service_registry.py index eb0ac666334..79c0a0e9a8f 100644 --- a/packages/syft/src/syft/node/service_registry.py +++ b/packages/syft/src/syft/node/service_registry.py @@ -1,6 +1,8 @@ # stdlib from collections.abc import Callable import typing +from typing import Any +from typing import TYPE_CHECKING # relative from ..service.action.action_service import ActionService @@ -35,7 +37,10 @@ from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_service import WorkerService -from .node import Node + +if TYPE_CHECKING: + # relative + from .node import Node class ServiceRegistry: @@ -71,7 +76,7 @@ class ServiceRegistry: output: OutputService user_code_status: UserCodeStatusService - def __init__(self, node: Node) -> None: + def __init__(self, node: "Node") -> None: self.node = node self.service_classes = self.get_service_classes() self.services: list[AbstractService] = [] @@ -90,10 +95,11 @@ def get_service_classes( def 
_construct_services(self) -> None: for field_name, service_cls in self.get_service_classes().items(): + svc_kwargs: dict[str, Any] = {} if issubclass(service_cls.store_type, ActionStore): - svc_kwargs = {"store": self.node.action_store} + svc_kwargs["store"] = self.node.action_store else: - svc_kwargs = {"store": self.node.document_store} + svc_kwargs["store"] = self.node.document_store service = service_cls(**svc_kwargs) setattr(self, field_name, service) From bf3e99f248a95dd649faf92928e754211fe898e9 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 10 Jul 2024 15:03:45 +0200 Subject: [PATCH 301/309] serializable --- packages/syft/src/syft/node/node.py | 5 +- .../syft/src/syft/node/service_registry.py | 48 +++++++++++++------ 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index b4e9383b69d..f13187852b1 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -394,7 +394,7 @@ def __init__( ) # construct services only after init stores - self.services: ServiceRegistry = ServiceRegistry(node=self) + self.services: ServiceRegistry = ServiceRegistry.for_node(self) create_admin_new( # nosec B106 name=root_username, @@ -865,9 +865,6 @@ def job_stash(self) -> JobStash: def worker_stash(self) -> WorkerStash: return self.get_service("workerservice").stash - def _construct_services(self) -> None: - self.services = ServiceRegistry(node=self) - @property def service_path_map(self) -> dict[str, AbstractService]: return self.services.service_path_map diff --git a/packages/syft/src/syft/node/service_registry.py b/packages/syft/src/syft/node/service_registry.py index 79c0a0e9a8f..4c50f6651c0 100644 --- a/packages/syft/src/syft/node/service_registry.py +++ b/packages/syft/src/syft/node/service_registry.py @@ -1,10 +1,13 @@ # stdlib from collections.abc import Callable +from dataclasses import dataclass +from dataclasses import field import typing from typing import Any from typing import TYPE_CHECKING # relative +from ..serde.serializable import serializable from ..service.action.action_service import ActionService from ..service.action.action_store import ActionStore from ..service.api.api_service import APIService @@ -43,8 +46,9 @@ from .node import Node +@serializable() +@dataclass class ServiceRegistry: - # Services action: ActionService user: UserService attestation: AttestationService @@ -76,12 +80,20 @@ class ServiceRegistry: output: OutputService user_code_status: UserCodeStatusService - def __init__(self, node: "Node") -> None: - self.node = node - self.service_classes = self.get_service_classes() - self.services: list[AbstractService] = [] - self.service_path_map: dict[str, AbstractService] = {} - self._construct_services() + services: list[AbstractService] = field(default_factory=list, init=False) + service_path_map: dict[str, AbstractService] = field( + default_factory=dict, init=False + ) + + @classmethod + def for_node(cls, node: "Node") -> "ServiceRegistry": + return cls(**cls._construct_services(node)) + + def __post_init__(self) -> None: + for name, service_cls in self.get_service_classes().items(): + service = getattr(self, name) + self.services.append(service) + self.service_path_map[service_cls.__name__.lower()] = service @classmethod def get_service_classes( @@ -93,18 +105,19 @@ def get_service_classes( if issubclass(cls, AbstractService) } - def _construct_services(self) -> None: - for field_name, service_cls in self.get_service_classes().items(): + 
@classmethod + def _construct_services(cls, node: "Node") -> dict[str, AbstractService]: + service_dict = {} + for field_name, service_cls in cls.get_service_classes().items(): svc_kwargs: dict[str, Any] = {} if issubclass(service_cls.store_type, ActionStore): - svc_kwargs["store"] = self.node.action_store + svc_kwargs["store"] = node.action_store else: - svc_kwargs["store"] = self.node.document_store + svc_kwargs["store"] = node.document_store service = service_cls(**svc_kwargs) - setattr(self, field_name, service) - self.services.append(service) - self.service_path_map[service.__class__.__name__.lower()] = service + service_dict[field_name] = service + return service_dict def get_service(self, path_or_func: str | Callable) -> AbstractService: if callable(path_or_func): @@ -120,3 +133,10 @@ def _get_service_from_path(self, path: str) -> AbstractService: return self.service_path_map[service_name.lower()] except KeyError: raise ValueError(f"Service {path} not found.") + + def to_dict(self) -> dict[str, AbstractService]: + d: dict[str, AbstractService] = {} + for name in self.get_service_classes().keys(): + service = getattr(self, name) + d[name] = service + return d From b1caaedbc0d6c094505c802519888f8af1bdd5f6 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 10 Jul 2024 15:47:52 +0200 Subject: [PATCH 302/309] remove comments --- packages/syft/src/syft/node/node.py | 59 ----------------------------- 1 file changed, 59 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index f13187852b1..2e5fbcbc3f5 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -872,65 +872,6 @@ def service_path_map(self) -> dict[str, AbstractService]: @property def initialized_services(self) -> list[AbstractService]: return self.services.services - # service_path_map: dict[str, AbstractService] = {} - # initialized_services: list[AbstractService] = [] - - # # A dict of service and init kwargs. 
- # # - "svc" expects a callable (class or function) - # # - The callable must return AbstractService or None - # # - "store" expects a store type - # # - By default all services get the document store - # # - Pass a custom "store" to override this - # default_services: list[dict] = [ - # {"svc": ActionService, "store": self.action_store}, - # {"svc": UserService}, - # {"svc": AttestationService}, - # {"svc": WorkerService}, - # {"svc": SettingsService}, - # {"svc": DatasetService}, - # {"svc": UserCodeService}, - # {"svc": LogService}, - # {"svc": RequestService}, - # {"svc": QueueService}, - # {"svc": JobService}, - # {"svc": APIService}, - # {"svc": DataSubjectService}, - # {"svc": NetworkService}, - # {"svc": PolicyService}, - # {"svc": NotifierService}, - # {"svc": NotificationService}, - # {"svc": DataSubjectMemberService}, - # {"svc": ProjectService}, - # {"svc": EnclaveService}, - # {"svc": CodeHistoryService}, - # {"svc": MetadataService}, - # {"svc": BlobStorageService}, - # {"svc": MigrationService}, - # {"svc": SyftWorkerImageService}, - # {"svc": SyftWorkerPoolService}, - # {"svc": SyftImageRegistryService}, - # {"svc": SyncService}, - # {"svc": OutputService}, - # {"svc": UserCodeStatusService}, # this is lazy - # ] - - # for svc_kwargs in default_services: - # ServiceCls = svc_kwargs.pop("svc") - # svc_kwargs.setdefault("store", self.document_store) - - # svc_instance = ServiceCls(**svc_kwargs) - # if not svc_instance: - # continue - # elif not isinstance(svc_instance, AbstractService): - # raise ValueError( - # f"Service {ServiceCls.__name__} must be an instance of AbstractService" - # ) - - # service_path_map[ServiceCls.__name__.lower()] = svc_instance - # initialized_services.append(ServiceCls) - - # self.services = initialized_services - # self.service_path_map = service_path_map def get_service_method(self, path_or_func: str | Callable) -> Callable: if callable(path_or_func): From a2c27c3cd1391dd8735f6e11fd0ab8ba072a7abe Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 10 Jul 2024 15:50:49 +0200 Subject: [PATCH 303/309] remove unused method --- packages/syft/src/syft/node/service_registry.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/packages/syft/src/syft/node/service_registry.py b/packages/syft/src/syft/node/service_registry.py index 4c50f6651c0..0d3c514394c 100644 --- a/packages/syft/src/syft/node/service_registry.py +++ b/packages/syft/src/syft/node/service_registry.py @@ -133,10 +133,3 @@ def _get_service_from_path(self, path: str) -> AbstractService: return self.service_path_map[service_name.lower()] except KeyError: raise ValueError(f"Service {path} not found.") - - def to_dict(self) -> dict[str, AbstractService]: - d: dict[str, AbstractService] = {} - for name in self.get_service_classes().keys(): - service = getattr(self, name) - d[name] = service - return d From e1d1d55594413413ff3e1a98706d04caf30a1490 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Mon, 8 Jul 2024 17:02:07 -0400 Subject: [PATCH 304/309] add first pass at conda install in CI --- .github/workflows/conda-install.yml | 55 +++++++++++++++++++++++++++++ .github/workflows/nightlies.yml | 5 +++ 2 files changed, 60 insertions(+) create mode 100644 .github/workflows/conda-install.yml diff --git a/.github/workflows/conda-install.yml b/.github/workflows/conda-install.yml new file mode 100644 index 00000000000..2f5666e3b88 --- /dev/null +++ b/.github/workflows/conda-install.yml @@ -0,0 +1,55 @@ +name: Conda Install - PySyft + +on: + workflow_call: + + pull_request: + branches: + - dev + 
- main + - "0.8" + + workflow_dispatch: + inputs: + none: + description: "Run Version Tests Manually" + required: false + +concurrency: + group: stack-${{ github.event_name == 'pull_request' && format('{0}-{1}', github.workflow, github.event.pull_request.number) || github.workflow_ref }} + cancel-in-progress: true + +jobs: + build: + strategy: + max-parallel: 99 + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.12"] + fail-fast: false + + runs-on: ${{matrix.os}} + + steps: + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + activate-environment: syft_conda_env + python-version: ${{ matrix.python-version }} + use-only-tar-bz2: true + - name: Conda info + shell: bash -el {0} + run: conda info + - name: Install syft (Windows) + if: matrix.os == 'windows-latest' + shell: pwsh + run: | + pip install ./packages/syft + echo "SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)')" >> $GITHUB_ENV + - name: Install syft (MacOS or Linux) + if: matrix.os != 'windows-latest' + shell: bash -el {0} + run: | + pip install ./packages/syft + echo "SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)')" >> $GITHUB_ENV diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 3db29a3663c..9a62e50301d 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -27,3 +27,8 @@ jobs: if: github.repository == 'OpenMined/PySyft' # don't run on forks uses: OpenMined/PySyft/.github/workflows/container-scan.yml@dev secrets: inherit + + call-conda-install: + if: github.repository == 'OpenMined/PySyft' # don't run on forks + uses: OpenMined/PySyft/.github/workflows/conda-install.yml@dev + secrets: inherit From 3b3e75f78b8a57aed7747572a92ecf48232886e7 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Mon, 8 Jul 2024 17:09:13 -0400 Subject: [PATCH 305/309] remove calling from PRs --- .github/workflows/conda-install.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/conda-install.yml b/.github/workflows/conda-install.yml index 2f5666e3b88..8f8ce660490 100644 --- a/.github/workflows/conda-install.yml +++ b/.github/workflows/conda-install.yml @@ -3,12 +3,6 @@ name: Conda Install - PySyft on: workflow_call: - pull_request: - branches: - - dev - - main - - "0.8" - workflow_dispatch: inputs: none: From 33b91c7ff78130bbcc2a1d24ee602b704af4cdc2 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Tue, 9 Jul 2024 11:53:37 -0400 Subject: [PATCH 306/309] clean up and add version check --- .github/workflows/conda-install.yml | 34 ++++++++++++++++++----------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/.github/workflows/conda-install.yml b/.github/workflows/conda-install.yml index 8f8ce660490..a53ac91698e 100644 --- a/.github/workflows/conda-install.yml +++ b/.github/workflows/conda-install.yml @@ -9,12 +9,8 @@ on: description: "Run Version Tests Manually" required: false -concurrency: - group: stack-${{ github.event_name == 'pull_request' && format('{0}-{1}', github.workflow, github.event.pull_request.number) || github.workflow_ref }} - cancel-in-progress: true - jobs: - build: + constall-install-syft: strategy: max-parallel: 99 matrix: @@ -31,19 +27,31 @@ jobs: auto-update-conda: true activate-environment: syft_conda_env python-version: ${{ matrix.python-version }} - use-only-tar-bz2: true - - name: Conda info - shell: bash -el {0} - run: conda info - name: Install syft (Windows) if: matrix.os == 'windows-latest' shell: 
pwsh run: | - pip install ./packages/syft - echo "SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)')" >> $GITHUB_ENV + python -m pip install ./packages/syft + EXPECTED_VERSION=$(python packages/grid/VERSION) + SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)') + # Compare the versions + if [ "$EXPECTED_VERSION" != "$SYFT_VERSION" ]; then + echo "Expected version: $EXPECTED_VERSION" + echo "Actual version: $SYFT_VERSION" + echo "Version mismatch. Failing the CI job." + exit 1 + fi - name: Install syft (MacOS or Linux) if: matrix.os != 'windows-latest' shell: bash -el {0} run: | - pip install ./packages/syft - echo "SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)')" >> $GITHUB_ENV + python -m pip install ./packages/syft + EXPECTED_VERSION=$(python packages/grid/VERSION) + SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)') + # Compare the versions + if [ "$EXPECTED_VERSION" != "$SYFT_VERSION" ]; then + echo "Expected version: $EXPECTED_VERSION" + echo "Actual version: $SYFT_VERSION" + echo "Version mismatch. Failing the CI job." + exit 1 + fi From 26240ce9e620df6e4463419edcae245adde26cf1 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Tue, 9 Jul 2024 15:33:19 -0400 Subject: [PATCH 307/309] update windows version check to powershell equivalent --- .github/workflows/conda-install.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/conda-install.yml b/.github/workflows/conda-install.yml index a53ac91698e..c7b7988cdf9 100644 --- a/.github/workflows/conda-install.yml +++ b/.github/workflows/conda-install.yml @@ -35,12 +35,12 @@ jobs: EXPECTED_VERSION=$(python packages/grid/VERSION) SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)') # Compare the versions - if [ "$EXPECTED_VERSION" != "$SYFT_VERSION" ]; then - echo "Expected version: $EXPECTED_VERSION" - echo "Actual version: $SYFT_VERSION" - echo "Version mismatch. Failing the CI job." + if ($expectedVersion -ne $syftVersion) { + Write-Output "Expected version: $expectedVersion" + Write-Output "Actual version: $syftVersion" + Write-Output "Version mismatch." exit 1 - fi + } - name: Install syft (MacOS or Linux) if: matrix.os != 'windows-latest' shell: bash -el {0} @@ -52,6 +52,6 @@ jobs: if [ "$EXPECTED_VERSION" != "$SYFT_VERSION" ]; then echo "Expected version: $EXPECTED_VERSION" echo "Actual version: $SYFT_VERSION" - echo "Version mismatch. Failing the CI job." + echo "Version mismatch." 
exit 1 fi From 2e29c61cf7668a9aabcd617a95e2bd7fe4f65a93 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Tue, 9 Jul 2024 15:42:11 -0400 Subject: [PATCH 308/309] try to fix powershell var setting --- .github/workflows/conda-install.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/conda-install.yml b/.github/workflows/conda-install.yml index c7b7988cdf9..9498ef77b91 100644 --- a/.github/workflows/conda-install.yml +++ b/.github/workflows/conda-install.yml @@ -32,8 +32,8 @@ jobs: shell: pwsh run: | python -m pip install ./packages/syft - EXPECTED_VERSION=$(python packages/grid/VERSION) - SYFT_VERSION=$(python -c 'import syft; print(syft.__version__)') + $expectedVersion = python packages/grid/VERSION + $syftVersion = python -c 'import syft; print(syft.__version__)' # Compare the versions if ($expectedVersion -ne $syftVersion) { Write-Output "Expected version: $expectedVersion" From 790174e0bc9161908d7d14b4a467b02ad954ad1f Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:45:33 +0530 Subject: [PATCH 309/309] Prevent server data reset due to uvicorn hotreload --- packages/syft/src/syft/node/node.py | 16 +++++------ packages/syft/src/syft/node/server.py | 15 ++++++++--- packages/syft/src/syft/node/utils.py | 38 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 packages/syft/src/syft/node/utils.py diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index f13187852b1..e34e5090cd2 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -12,10 +12,8 @@ import logging import os from pathlib import Path -import shutil import subprocess # nosec import sys -import tempfile from time import sleep import traceback from typing import Any @@ -122,6 +120,9 @@ from .credentials import SyftSigningKey from .credentials import SyftVerifyKey from .service_registry import ServiceRegistry +from .utils import get_named_node_uid +from .utils import get_temp_dir_for_node +from .utils import remove_temp_dir_for_node from .worker_settings import WorkerSettings logger = logging.getLogger(__name__) @@ -655,7 +656,7 @@ def named( association_request_auto_approval: bool = False, background_tasks: bool = False, ) -> Node: - uid = UID.with_seed(name) + uid = get_named_node_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() key = SyftSigningKey(signing_key=SigningKey(name_hash)) blob_storage_config = None @@ -952,18 +953,13 @@ def get_temp_dir(self, dir_name: str = "") -> Path: Get a temporary directory unique to the node. Provide all dbs, blob dirs, and locks using this directory. """ - root = os.getenv("SYFT_TEMP_ROOT", "syft") - p = Path(tempfile.gettempdir(), root, str(self.id), dir_name) - p.mkdir(parents=True, exist_ok=True) - return p + return get_temp_dir_for_node(self.id, dir_name) def remove_temp_dir(self) -> None: """ Remove the temporary directory for this node. 
""" - rootdir = self.get_temp_dir() - if rootdir.exists(): - shutil.rmtree(rootdir, ignore_errors=True) + remove_temp_dir_for_node(self.id) def update_self(self, settings: NodeSettings) -> None: updateable_attrs = ( diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 4628cae6005..a3451f304a7 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -31,6 +31,8 @@ from .gateway import Gateway from .node import NodeType from .routes import make_routes +from .utils import get_named_node_uid +from .utils import remove_temp_dir_for_node if os_name() == "macOS": # needed on MacOS to prevent [__NSCFConstantString initialize] may have been in @@ -73,12 +75,11 @@ def app_factory() -> FastAPI: kwargs = settings.model_dump() if settings.dev_mode: print( - f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. " + f"WARN: private key is based on node name: {settings.name} in dev_mode. " "Don't run this in production." ) worker = worker_class.named(**kwargs) else: - del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode worker = worker_class(**kwargs) app = FastAPI(title=settings.name) @@ -119,7 +120,15 @@ def run_uvicorn( starting_uvicorn_event: multiprocessing.synchronize.Event, **kwargs: Any, ) -> None: - if kwargs.get("reset"): + should_reset = kwargs.get("dev_mode") and kwargs.get("reset") + + if should_reset: + print("Found `reset=True` in the launch configuration. Resetting the node...") + named_node_uid = get_named_node_uid(kwargs.get("name")) + remove_temp_dir_for_node(named_node_uid) + # Explicitly set `reset` to False to prevent multiple resets during hot-reload + kwargs["reset"] = False + # Kill all old python processes try: python_pids = find_python_processes_on_port(port) for pid in python_pids: diff --git a/packages/syft/src/syft/node/utils.py b/packages/syft/src/syft/node/utils.py new file mode 100644 index 00000000000..3048fd3fa94 --- /dev/null +++ b/packages/syft/src/syft/node/utils.py @@ -0,0 +1,38 @@ +# future +from __future__ import annotations + +# stdlib +import os +from pathlib import Path +import shutil +import tempfile + +# relative +from ..types.uid import UID + + +def get_named_node_uid(name: str) -> UID: + """ + Get a unique identifier for a named node. + """ + return UID.with_seed(name) + + +def get_temp_dir_for_node(node_uid: UID, dir_name: str = "") -> Path: + """ + Get a temporary directory unique to the node. + Provide all dbs, blob dirs, and locks using this directory. + """ + root = os.getenv("SYFT_TEMP_ROOT", "syft") + p = Path(tempfile.gettempdir(), root, str(node_uid), dir_name) + p.mkdir(parents=True, exist_ok=True) + return p + + +def remove_temp_dir_for_node(node_uid: UID) -> None: + """ + Remove the temporary directory for this node. + """ + rootdir = get_temp_dir_for_node(node_uid) + if rootdir.exists(): + shutil.rmtree(rootdir, ignore_errors=True)