diff --git a/.github/workflows/container-scan.yml b/.github/workflows/container-scan.yml index 303eb11bc40..211a8022f62 100644 --- a/.github/workflows/container-scan.yml +++ b/.github/workflows/container-scan.yml @@ -224,7 +224,7 @@ jobs: name: syft.sbom.json path: syft.sbom.json - scan-mongo-latest-trivy: + scan-postgres-latest-trivy: permissions: contents: read # for actions/checkout to fetch code security-events: write # for github/codeql-action/upload-sarif to upload SARIF results @@ -238,24 +238,24 @@ jobs: continue-on-error: true uses: aquasecurity/trivy-action@master with: - image-ref: "mongo:7.0.0" + image-ref: "postgres:16.1" format: "cyclonedx" - output: "mongo-trivy-results.sbom.json" + output: "postgres-trivy-results.sbom.json" timeout: "10m0s" #Upload SBOM to GitHub Artifacts - name: Upload SBOM to GitHub Artifacts uses: actions/upload-artifact@v4 with: - name: mongo-trivy-results.sbom.json - path: mongo-trivy-results.sbom.json + name: postgres-trivy-results.sbom.json + path: postgres-trivy-results.sbom.json #Generate sarif file - name: Run Trivy vulnerability scanner continue-on-error: true uses: aquasecurity/trivy-action@master with: - image-ref: "mongo:7.0.0" + image-ref: "postgres:16.1" format: "sarif" output: "trivy-results.sarif" timeout: "10m0s" @@ -266,7 +266,7 @@ jobs: with: sarif_file: "trivy-results.sarif" - scan-mongo-latest-snyk: + scan-postgres-latest-snyk: permissions: contents: read # for actions/checkout to fetch code security-events: write # for github/codeql-action/upload-sarif to upload SARIF results @@ -281,7 +281,7 @@ jobs: # This is where you will need to introduce the Snyk API token created with your Snyk account SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} with: - image: mongo:7.0.0 + image: postgres:16.1 args: --sarif-file-output=snyk-code.sarif # Replace any "undefined" security severity values with 0. The undefined value is used in the case diff --git a/.isort.cfg b/.isort.cfg index aeb09bb8f36..26309a07039 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -20,4 +20,4 @@ import_heading_localfolder=relative ignore_comments=False force_grid_wrap=True honor_noqa=True -skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/*,packages/syft/tests/mongomock/* \ No newline at end of file +skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/* \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3487d8d0915..e1c50cd3b96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,41 +3,36 @@ repos: rev: v4.5.0 hooks: - id: check-ast - exclude: ^(packages/syft/tests/mongomock) always_run: true - id: trailing-whitespace always_run: true - exclude: ^(docs/|.+\.md|.bumpversion.cfg|packages/syft/tests/mongomock) + exclude: ^(docs/|.+\.md|.bumpversion.cfg) - id: check-docstring-first always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: check-json always_run: true - exclude: ^(packages/grid/frontend/|packages/syft/tests/mongomock|.vscode) + exclude: ^(packages/grid/frontend/|.vscode) - id: check-added-large-files always_run: true exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif) - id: check-yaml always_run: true - exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/|packages/syft/tests/mongomock) + exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/) - id: check-merge-conflict always_run: true args: ["--assume-in-merge"] - id: check-executables-have-shebangs always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: debug-statements always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: name-tests-test always_run: true - exclude: ^(.*/tests/utils/)|^(.*fixtures.py|packages/syft/tests/mongomock)|^(tests/scenarios/bigquery/helpers) + exclude: ^(.*/tests/utils/)|^(.*fixtures.py)|^(tests/scenarios/bigquery/helpers) - id: requirements-txt-fixer always_run: true - exclude: "packages/syft/tests/mongomock" - id: mixed-line-ending args: ["--fix=lf"] - exclude: '\.bat|\.csv|\.ps1$|packages/syft/tests/mongomock' + exclude: '\.bat|\.csv|\.ps1$' - repo: https://github.com/MarcoGorelli/absolufy-imports # This repository has been archived by the owner on Aug 15, 2023. It is now read-only. rev: v0.3.1 @@ -88,7 +83,6 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --show-fixes] - exclude: packages/syft/tests/mongomock types_or: [python, pyi, jupyter] - id: ruff-format types_or: [python, pyi, jupyter] @@ -178,7 +172,7 @@ repos: rev: "v3.0.0-alpha.9-for-vscode" hooks: - id: prettier - exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) + exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|.vscode) # - repo: meta # hooks: diff --git a/.vscode/launch.json b/.vscode/launch.json index bb5d6e9c00a..7e30fe06537 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,13 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, { "name": "Syft Debugger", "type": "debugpy", diff --git a/docs/source/api_reference/syft.store.mongo_client.rst b/docs/source/api_reference/syft.store.mongo_client.rst deleted file mode 100644 index a21d43700aa..00000000000 --- a/docs/source/api_reference/syft.store.mongo_client.rst +++ /dev/null @@ -1,31 +0,0 @@ -syft.store.mongo\_client -======================== - -.. automodule:: syft.store.mongo_client - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - MongoClient - MongoClientCache - MongoStoreClientConfig - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.mongo_codecs.rst b/docs/source/api_reference/syft.store.mongo_codecs.rst deleted file mode 100644 index 1d91b779e95..00000000000 --- a/docs/source/api_reference/syft.store.mongo_codecs.rst +++ /dev/null @@ -1,35 +0,0 @@ -syft.store.mongo\_codecs -======================== - -.. automodule:: syft.store.mongo_codecs - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - fallback_syft_encoder - - - - - - .. rubric:: Classes - - .. autosummary:: - - SyftMongoBinaryDecoder - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.mongo_document_store.rst b/docs/source/api_reference/syft.store.mongo_document_store.rst deleted file mode 100644 index 30fdb6bc6ca..00000000000 --- a/docs/source/api_reference/syft.store.mongo_document_store.rst +++ /dev/null @@ -1,40 +0,0 @@ -syft.store.mongo\_document\_store -================================= - -.. automodule:: syft.store.mongo_document_store - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - from_mongo - syft_obj_to_mongo - to_mongo - - - - - - .. rubric:: Classes - - .. autosummary:: - - MongoBsonObject - MongoDocumentStore - MongoStoreConfig - MongoStorePartition - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.rst b/docs/source/api_reference/syft.store.rst index b21cf230488..e83e8699025 100644 --- a/docs/source/api_reference/syft.store.rst +++ b/docs/source/api_reference/syft.store.rst @@ -32,8 +32,5 @@ syft.store.kv_document_store syft.store.linked_obj syft.store.locks - syft.store.mongo_client - syft.store.mongo_codecs - syft.store.mongo_document_store syft.store.sqlite_document_store diff --git a/notebooks/api/0.8/12-custom-api-endpoint.ipynb b/notebooks/api/0.8/12-custom-api-endpoint.ipynb index aa60e30dd87..c58c78c1795 100644 --- a/notebooks/api/0.8/12-custom-api-endpoint.ipynb +++ b/notebooks/api/0.8/12-custom-api-endpoint.ipynb @@ -657,6 +657,9 @@ "# stdlib\n", "import time\n", "\n", + "# syft absolute\n", + "from syft.service.job.job_stash import JobStatus\n", + "\n", "# Iterate over the Jobs waiting them to finish their pipelines.\n", "job_pool = [\n", " (log_call_job, \"Logging Private Function Call\"),\n", @@ -665,13 +668,19 @@ "]\n", "for job, expected_log in job_pool:\n", " updated_job = datasite_client.api.services.job.get(job.id)\n", - " while updated_job.status.value != \"completed\":\n", + " while updated_job.status in {JobStatus.CREATED, JobStatus.PROCESSING}:\n", " updated_job = datasite_client.api.services.job.get(job.id)\n", " time.sleep(1)\n", - " # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n", - " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n", - " _print=False\n", - " )" + "\n", + " assert (\n", + " updated_job.status == JobStatus.COMPLETED\n", + " ), f\"Job {updated_job.id} exited with status {updated_job.status} and result {updated_job.result}\"\n", + " if updated_job.status == JobStatus.COMPLETED:\n", + " print(f\"Job {updated_job.id} completed\")\n", + " # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n", + " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n", + " _print=False\n", + " )" ] }, { @@ -683,6 +692,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -693,7 +707,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb index c64e6c40f4a..a92f7987e68 100644 --- a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb +++ b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb @@ -78,7 +78,7 @@ "If you want to deploy your Kubernetes cluster in a resource-constrained environment, use the following flags to override the default configurations. Please note that you will need at least 1 CPU and 2 GB of RAM on Docker, and some tests may not work in such low-resource environments:\n", "\n", "```sh\n", - "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n", + "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n", "```\n", "\n", "\n", @@ -89,7 +89,7 @@ "If you would like to set your own default password even for the production style deployment, use the following command:\n", "\n", "```sh\n", - "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set mongo.secret.rootPassword=\"example\"\n", + "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set postgres.secret.rootPassword=\"example\"\n", "```\n", "\n" ] diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index c51ba31c8fd..984d2f174f4 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -9,7 +9,7 @@ ARG TORCH_VERSION="2.2.2" # ==================== [BUILD STEP] Python Dev Base ==================== # -FROM cgr.dev/chainguard/wolfi-base as syft_deps +FROM cgr.dev/chainguard/wolfi-base AS syft_deps ARG PYTHON_VERSION ARG UV_VERSION @@ -45,7 +45,7 @@ RUN --mount=type=cache,target=/root/.cache,sharing=locked \ # ==================== [Final] Setup Syft Server ==================== # -FROM cgr.dev/chainguard/wolfi-base as backend +FROM cgr.dev/chainguard/wolfi-base AS backend ARG PYTHON_VERSION ARG UV_VERSION @@ -84,9 +84,10 @@ ENV \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ STACK_API_KEY="changeme" \ - MONGO_HOST="localhost" \ - MONGO_PORT="27017" \ - MONGO_USERNAME="root" \ - MONGO_PASSWORD="example" + POSTGRESQL_DBNAME="syftdb_postgres" \ + POSTGRESQL_HOST="localhost" \ + POSTGRESQL_PORT="5432" \ + POSTGRESQL_USERNAME="syft_postgres" \ + POSTGRESQL_PASSWORD="example" CMD ["bash", "./grid/start.sh"] diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 63bda939c29..e92d6783ae7 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -126,10 +126,11 @@ def get_emails_enabled(self) -> Self: # NETWORK_CHECK_INTERVAL: int = int(os.getenv("NETWORK_CHECK_INTERVAL", 60)) # DATASITE_CHECK_INTERVAL: int = int(os.getenv("DATASITE_CHECK_INTERVAL", 60)) CONTAINER_HOST: str = str(os.getenv("CONTAINER_HOST", "docker")) - MONGO_HOST: str = str(os.getenv("MONGO_HOST", "")) - MONGO_PORT: int = int(os.getenv("MONGO_PORT", 27017)) - MONGO_USERNAME: str = str(os.getenv("MONGO_USERNAME", "")) - MONGO_PASSWORD: str = str(os.getenv("MONGO_PASSWORD", "")) + POSTGRESQL_DBNAME: str = str(os.getenv("POSTGRESQL_DBNAME", "")) + POSTGRESQL_HOST: str = str(os.getenv("POSTGRESQL_HOST", "")) + POSTGRESQL_PORT: int = int(os.getenv("POSTGRESQL_PORT", 5432)) + POSTGRESQL_USERNAME: str = str(os.getenv("POSTGRESQL_USERNAME", "")) + POSTGRESQL_PASSWORD: str = str(os.getenv("POSTGRESQL_PASSWORD", "")) DEV_MODE: bool = True if os.getenv("DEV_MODE", "false").lower() == "true" else False # ZMQ stuff QUEUE_PORT: int = int(os.getenv("QUEUE_PORT", 5556)) @@ -137,7 +138,7 @@ def get_emails_enabled(self) -> Self: True if os.getenv("CREATE_PRODUCER", "false").lower() == "true" else False ) N_CONSUMERS: int = int(os.getenv("N_CONSUMERS", 1)) - SQLITE_PATH: str = os.path.expandvars("$HOME/data/db/") + SQLITE_PATH: str = os.path.expandvars("/tmp/data/db") SINGLE_CONTAINER_MODE: bool = str_to_bool(os.getenv("SINGLE_CONTAINER_MODE", False)) CONSUMER_SERVICE_NAME: str | None = os.getenv("CONSUMER_SERVICE_NAME") INMEMORY_WORKERS: bool = str_to_bool(os.getenv("INMEMORY_WORKERS", True)) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 3f401d7e349..7d8d011de5d 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -1,3 +1,6 @@ +# stdlib +from pathlib import Path + # syft absolute from syft.abstract_server import ServerType from syft.server.datasite import Datasite @@ -14,10 +17,8 @@ from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoStoreConfig -from syft.store.sqlite_document_store import SQLiteStoreClientConfig -from syft.store.sqlite_document_store import SQLiteStoreConfig +from syft.store.db.postgres import PostgresDBConfig +from syft.store.db.sqlite import SQLiteDBConfig from syft.types.uid import UID # server absolute @@ -36,23 +37,26 @@ def queue_config() -> ZMQQueueConfig: return queue_config -def mongo_store_config() -> MongoStoreConfig: - mongo_client_config = MongoStoreClientConfig( - hostname=settings.MONGO_HOST, - port=settings.MONGO_PORT, - username=settings.MONGO_USERNAME, - password=settings.MONGO_PASSWORD, - ) - - return MongoStoreConfig(client_config=mongo_client_config) - +def sql_store_config() -> SQLiteDBConfig: + # Check if the directory exists, and create it if it doesn't + sqlite_path = Path(settings.SQLITE_PATH) + if not sqlite_path.exists(): + sqlite_path.mkdir(parents=True, exist_ok=True) -def sql_store_config() -> SQLiteStoreConfig: - client_config = SQLiteStoreClientConfig( + return SQLiteDBConfig( filename=f"{UID.from_string(get_server_uid_env())}.sqlite", path=settings.SQLITE_PATH, ) - return SQLiteStoreConfig(client_config=client_config) + + +def postgresql_store_config() -> PostgresDBConfig: + return PostgresDBConfig( + host=settings.POSTGRESQL_HOST, + port=settings.POSTGRESQL_PORT, + user=settings.POSTGRESQL_USERNAME, + password=settings.POSTGRESQL_PASSWORD, + database=settings.POSTGRESQL_DBNAME, + ) def seaweedfs_config() -> SeaweedFSConfig: @@ -87,20 +91,19 @@ def seaweedfs_config() -> SeaweedFSConfig: worker_class = worker_classes[server_type] single_container_mode = settings.SINGLE_CONTAINER_MODE -store_config = sql_store_config() if single_container_mode else mongo_store_config() +db_config = sql_store_config() if single_container_mode else postgresql_store_config() + blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() worker: Server = worker_class( name=server_name, server_side_type=server_side_type, - action_store_config=store_config, - document_store_config=store_config, enable_warnings=enable_warnings, blob_storage_config=blob_storage_config, local_db=single_container_mode, queue_config=queue_config, - migrate=True, + migrate=False, in_memory_workers=settings.INMEMORY_WORKERS, smtp_username=settings.SMTP_USERNAME, smtp_password=settings.SMTP_PASSWORD, @@ -109,4 +112,5 @@ def seaweedfs_config() -> SeaweedFSConfig: smtp_host=settings.SMTP_HOST, association_request_auto_approval=settings.ASSOCIATION_REQUEST_AUTO_APPROVAL, background_tasks=True, + db_config=db_config, ) diff --git a/packages/grid/default.env b/packages/grid/default.env index e1bc5c42557..49697538cd8 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -77,14 +77,6 @@ KANIKO_VERSION="v1.23.2" # Jax JAX_ENABLE_X64=True -# Mongo -MONGO_IMAGE=mongo -MONGO_VERSION="7.0.8" -MONGO_HOST=mongo -MONGO_PORT=27017 -MONGO_USERNAME=root -MONGO_PASSWORD=example - # Redis REDIS_PORT=6379 REDIS_STORE_DB_ID=0 @@ -110,4 +102,13 @@ ENABLE_SIGNUP=False DOCKER_IMAGE_ENCLAVE_ATTESTATION=openmined/syft-enclave-attestation # Rathole Config -RATHOLE_PORT=2333 \ No newline at end of file +RATHOLE_PORT=2333 + +# PostgresSQL Config +# POSTGRESQL_IMAGE=postgres +# export POSTGRESQL_VERSION="15" +POSTGRESQL_DBNAME=syftdb_postgres +POSTGRESQL_HOST=postgres +POSTGRESQL_PORT=5432 +POSTGRESQL_USERNAME=syft_postgres +POSTGRESQL_PASSWORD=example diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 08121fcaaa6..b9a31457d4e 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -79,12 +79,12 @@ deployments: - ./helm/examples/dev/base.yaml dev: - mongo: + postgres: labelSelector: app.kubernetes.io/name: syft - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres ports: - - port: "27017" + - port: "5432" seaweedfs: labelSelector: app.kubernetes.io/name: syft @@ -94,6 +94,7 @@ dev: - port: "8888" # filer - port: "8333" # S3 - port: "4001" # mount azure + - port: "5432" # mount postgres backend: labelSelector: app.kubernetes.io/name: syft @@ -205,8 +206,8 @@ profiles: path: dev.seaweedfs # Port Re-Mapping - op: replace - path: dev.mongo.ports[0].port - value: 27018:27017 + path: dev.postgres.ports[0].port + value: 5433:5432 - op: replace path: dev.backend.ports[0].port value: 5679:5678 @@ -268,8 +269,8 @@ profiles: value: ./helm/examples/dev/enclave.yaml # Port Re-Mapping - op: replace - path: dev.mongo.ports[0].port - value: 27019:27017 + path: dev.postgres.ports[0].port + value: 5434:5432 - op: replace path: dev.backend.ports[0].port value: 5680:5678 diff --git a/packages/grid/helm/examples/azure/azure.high.yaml b/packages/grid/helm/examples/azure/azure.high.yaml index 3234a62b757..4733fed4cc7 100644 --- a/packages/grid/helm/examples/azure/azure.high.yaml +++ b/packages/grid/helm/examples/azure/azure.high.yaml @@ -38,5 +38,5 @@ registry: frontend: resourcesPreset: medium -mongo: +postgres: resourcesPreset: large diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 4999ae40aed..3fc1ad5c4da 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -21,7 +21,7 @@ server: secret: defaultRootPassword: changethis -mongo: +postgres: resourcesPreset: null resources: null diff --git a/packages/grid/helm/examples/gcp/gcp.high.yaml b/packages/grid/helm/examples/gcp/gcp.high.yaml index efdbbe72e68..2a430807fac 100644 --- a/packages/grid/helm/examples/gcp/gcp.high.yaml +++ b/packages/grid/helm/examples/gcp/gcp.high.yaml @@ -97,7 +97,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git a/packages/grid/helm/examples/gcp/gcp.low.yaml b/packages/grid/helm/examples/gcp/gcp.low.yaml index 94cfc324b0b..8e9e3e7ba35 100644 --- a/packages/grid/helm/examples/gcp/gcp.low.yaml +++ b/packages/grid/helm/examples/gcp/gcp.low.yaml @@ -97,7 +97,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git a/packages/grid/helm/examples/gcp/gcp.nosync.yaml b/packages/grid/helm/examples/gcp/gcp.nosync.yaml index 02935edfd8f..8e622be5254 100644 --- a/packages/grid/helm/examples/gcp/gcp.nosync.yaml +++ b/packages/grid/helm/examples/gcp/gcp.nosync.yaml @@ -67,7 +67,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 3293056ba2a..2d1a6880c33 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -99,18 +99,20 @@ spec: - name: REVERSE_TUNNEL_ENABLED value: "true" {{- end }} - # MongoDB - - name: MONGO_PORT - value: {{ .Values.mongo.port | quote }} - - name: MONGO_HOST - value: "mongo" - - name: MONGO_USERNAME - value: {{ .Values.mongo.username | quote }} - - name: MONGO_PASSWORD + # Postgres + - name: POSTGRESQL_PORT + value: {{ .Values.postgres.port | quote }} + - name: POSTGRESQL_HOST + value: "postgres" + - name: POSTGRESQL_USERNAME + value: {{ .Values.postgres.username | quote }} + - name: POSTGRESQL_PASSWORD valueFrom: secretKeyRef: - name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} key: rootPassword + - name: POSTGRESQL_DBNAME + value: {{ .Values.postgres.dbname | quote }} # SMTP - name: SMTP_HOST value: {{ .Values.server.smtp.host | quote }} diff --git a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml b/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml deleted file mode 100644 index 91060b90a9b..00000000000 --- a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml +++ /dev/null @@ -1,69 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - name: mongo - labels: - {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo -spec: - replicas: 1 - updateStrategy: - type: RollingUpdate - selector: - matchLabels: - {{- include "common.selectorLabels" . | nindent 6 }} - app.kubernetes.io/component: mongo - serviceName: mongo-headless - podManagementPolicy: OrderedReady - template: - metadata: - labels: - {{- include "common.labels" . | nindent 8 }} - app.kubernetes.io/component: mongo - {{- if .Values.mongo.podLabels }} - {{- toYaml .Values.mongo.podLabels | nindent 8 }} - {{- end }} - {{- if .Values.mongo.podAnnotations }} - annotations: {{- toYaml .Values.mongo.podAnnotations | nindent 8 }} - {{- end }} - spec: - {{- if .Values.mongo.nodeSelector }} - nodeSelector: {{- .Values.mongo.nodeSelector | toYaml | nindent 8 }} - {{- end }} - containers: - - name: mongo-container - image: mongo:7 - imagePullPolicy: Always - resources: {{ include "common.resources.set" (dict "resources" .Values.mongo.resources "preset" .Values.mongo.resourcesPreset) | nindent 12 }} - env: - - name: MONGO_INITDB_ROOT_USERNAME - value: {{ .Values.mongo.username | required "mongo.username is required" | quote }} - - name: MONGO_INITDB_ROOT_PASSWORD - valueFrom: - secretKeyRef: - name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} - key: rootPassword - {{- if .Values.mongo.env }} - {{- toYaml .Values.mongo.env | nindent 12 }} - {{- end }} - volumeMounts: - - mountPath: /data/db - name: mongo-data - readOnly: false - subPath: '' - ports: - - name: mongo-port - containerPort: 27017 - terminationGracePeriodSeconds: 5 - volumeClaimTemplates: - - metadata: - name: mongo-data - labels: - {{- include "common.volumeLabels" . | nindent 8 }} - app.kubernetes.io/component: mongo - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: {{ .Values.mongo.storageSize | quote }} diff --git a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml similarity index 57% rename from packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml rename to packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml index 7cb97ee3592..4855a7868ff 100644 --- a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml +++ b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml @@ -1,15 +1,15 @@ apiVersion: v1 kind: Service metadata: - name: mongo-headless + name: postgres-headless labels: {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres spec: clusterIP: None ports: - - name: mongo - port: 27017 + - name: postgres + port: 5432 selector: {{- include "common.selectorLabels" . | nindent 4 }} - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml similarity index 69% rename from packages/grid/helm/syft/templates/mongo/mongo-secret.yaml rename to packages/grid/helm/syft/templates/postgres/postgres-secret.yaml index 02c58d276ca..63a990c0d9a 100644 --- a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml +++ b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml @@ -1,17 +1,17 @@ -{{- $secretName := "mongo-secret" }} +{{- $secretName := "postgres-secret" }} apiVersion: v1 kind: Secret metadata: name: {{ $secretName }} labels: {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres type: Opaque data: rootPassword: {{ include "common.secrets.set" (dict "secret" $secretName "key" "rootPassword" "randomDefault" .Values.global.randomizedSecrets - "default" .Values.mongo.secret.rootPassword + "default" .Values.postgres.secret.rootPassword "context" $) - }} + }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml similarity index 57% rename from packages/grid/helm/syft/templates/mongo/mongo-service.yaml rename to packages/grid/helm/syft/templates/postgres/postgres-service.yaml index a789f4e8f86..9cd8b156bdd 100644 --- a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml +++ b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml @@ -1,17 +1,17 @@ apiVersion: v1 kind: Service metadata: - name: mongo + name: postgres labels: {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres spec: type: ClusterIP selector: {{- include "common.selectorLabels" . | nindent 4 }} - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres ports: - - name: mongo - port: 27017 + - name: postgres + port: 5432 protocol: TCP - targetPort: 27017 + targetPort: 5432 \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml new file mode 100644 index 00000000000..986031b17e9 --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml @@ -0,0 +1,72 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: postgres + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: postgres + serviceName: postgres-headless + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: postgres + {{- if .Values.postgres.podLabels }} + {{- toYaml .Values.postgres.podLabels | nindent 8 }} + {{- end }} + {{- if .Values.postgres.podAnnotations }} + annotations: {{- toYaml .Values.postgres.podAnnotations | nindent 8 }} + {{- end }} + spec: + {{- if .Values.postgres.nodeSelector }} + nodeSelector: {{- .Values.postgres.nodeSelector | toYaml | nindent 8 }} + {{- end }} + containers: + - name: postgres-container + image: postgres:16.1 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.postgres.resources "preset" .Values.postgres.resourcesPreset) | nindent 12 }} + env: + - name: POSTGRES_USER + value: {{ .Values.postgres.username | required "postgres.username is required" | quote }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} + key: rootPassword + - name: POSTGRES_DB + value: {{ .Values.postgres.dbname | required "postgres.dbname is required" | quote }} + {{- if .Values.postgres.env }} + {{- toYaml .Values.postgres.env | nindent 12 }} + {{- end }} + volumeMounts: + - mountPath: tmp/data/db + name: postgres-data + readOnly: false + subPath: '' + ports: + - name: postgres-port + containerPort: 5432 + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: postgres-data + labels: + {{- include "common.volumeLabels" . | nindent 8 }} + app.kubernetes.io/component: postgres + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.postgres.storageSize | quote }} + diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index c1cec54b441..ba3eebffde8 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -12,10 +12,12 @@ global: # ================================================================================= -mongo: - # MongoDB config - port: 27017 - username: root +postgres: +# Postgres config + port: 5432 + username: syft_postgres + dbname: syftdb_postgres + host: postgres # Extra environment vars env: null @@ -35,12 +37,11 @@ mongo: storageSize: 5Gi # Mongo secret name. Override this if you want to use a self-managed secret. - secretKeyName: mongo-secret + secretKeyName: postgres-secret # default/custom secret raw values secret: - rootPassword: null - + rootPassword: null # ================================================================================= frontend: diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 547445d20bb..d4ee2cff521 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -31,11 +31,10 @@ syft = boto3==1.34.56 forbiddenfruit==0.1.4 packaging>=23.0 - pyarrow==15.0.0 + pyarrow==17.0.0 pycapnp==2.0.0 pydantic[email]==2.6.0 pydantic-settings==2.2.1 - pymongo==4.6.3 pynacl==1.5.0 pyzmq>=23.2.1,<=25.1.1 requests==2.32.3 @@ -66,8 +65,12 @@ syft = jinja2==3.1.4 tenacity==8.3.0 nh3==0.2.17 + psycopg[binary]==3.1.19 + psycopg[pool]==3.1.19 ipython<8.27.0 dynaconf==3.2.6 + sqlalchemy==2.0.32 + psycopg2-binary==2.9.9 install_requires = %(syft)s @@ -111,9 +114,9 @@ telemetry = opentelemetry-instrumentation==0.48b0 opentelemetry-instrumentation-requests==0.48b0 opentelemetry-instrumentation-fastapi==0.48b0 - opentelemetry-instrumentation-pymongo==0.48b0 opentelemetry-instrumentation-botocore==0.48b0 opentelemetry-instrumentation-logging==0.48b0 + opentelemetry-instrumentation-sqlalchemy==0.48b0 ; opentelemetry-instrumentation-asyncio==0.48b0 ; opentelemetry-instrumentation-sqlite3==0.48b0 ; opentelemetry-instrumentation-threading==0.48b0 diff --git a/packages/syft/src/syft/abstract_server.py b/packages/syft/src/syft/abstract_server.py index c222cf4ea5a..3b7885f0a0e 100644 --- a/packages/syft/src/syft/abstract_server.py +++ b/packages/syft/src/syft/abstract_server.py @@ -5,6 +5,7 @@ # relative from .serde.serializable import serializable +from .store.db.db import DBConfig from .types.uid import UID if TYPE_CHECKING: @@ -41,6 +42,7 @@ class AbstractServer: server_side_type: ServerSideType | None in_memory_workers: bool services: "ServiceRegistry" + db_config: DBConfig def get_service(self, path_or_func: str | Callable) -> "AbstractService": raise NotImplementedError diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index a5aa64ca736..85fe33545fa 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -29,8 +29,7 @@ from ..serde.serializable import serializable from ..serde.serialize import _serialize from ..serde.signature import Signature -from ..serde.signature import signature_remove_context -from ..serde.signature import signature_remove_self +from ..serde.signature import signature_remove from ..server.credentials import SyftSigningKey from ..server.credentials import SyftVerifyKey from ..service.context import AuthedServiceContext @@ -738,10 +737,17 @@ def __getattr__(self, name: str) -> Any: ) def __getitem__(self, key: str | int) -> Any: + if hasattr(self, "get_index"): + return self.get_index(key) if hasattr(self, "get_all"): return self.get_all()[key] raise NotImplementedError + def __iter__(self) -> Any: + if hasattr(self, "get_all"): + return iter(self.get_all()) + raise NotImplementedError + def _repr_html_(self) -> Any: if self.path == "settings": return self.get()._repr_html_() @@ -1117,9 +1123,10 @@ def build_endpoint_tree( api_module = APIModule(path="", refresh_callback=self.refresh_api_callback) for v in endpoints.values(): signature = v.signature + args_to_remove = ["context"] if not v.has_self: - signature = signature_remove_self(signature) - signature = signature_remove_context(signature) + args_to_remove.append("self") + signature = signature_remove(signature, args_to_remove) if isinstance(v, APIEndpoint): endpoint_function = generate_remote_function( self, diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 692bd017565..d6fd164f44f 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -605,6 +605,49 @@ def login( result = post_process_result(result, unwrap_on_success=True) return result + def forgot_password( + self, + email: str, + ) -> SyftSigningKey | None: + credentials = {"email": email} + if self.proxy_target_uid: + obj = forward_message_to_proxy( + self.make_call, + proxy_target_uid=self.proxy_target_uid, + path="forgot_password", + kwargs=credentials, + ) + else: + response = self.server.services.user.forgot_password( + context=ServerServiceContext(server=self.server), email=email + ) + obj = post_process_result(response, unwrap_on_success=True) + + return obj + + def reset_password( + self, + token: str, + new_password: str, + ) -> SyftSigningKey | None: + payload = {"token": token, "new_password": new_password} + if self.proxy_target_uid: + obj = forward_message_to_proxy( + self.make_call, + proxy_target_uid=self.proxy_target_uid, + path="reset_password", + kwargs=payload, + ) + else: + response = self.server.services.user.reset_password( + context=ServerServiceContext(server=self.server), + token=token, + new_password=new_password, + ) + obj = post_process_result(response, unwrap_on_success=True) + + return obj + def register(self, new_user: UserCreate) -> SyftSigningKey | None: if self.proxy_target_uid: response = forward_message_to_proxy( diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index 1cbdb44c488..6410c990eac 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -14,6 +14,7 @@ # relative from ..serde.serializable import serializable +from ..serde.serialize import _serialize from ..service.response import SyftSuccess from ..types.base import SyftBaseModel from ..types.errors import SyftException @@ -83,6 +84,10 @@ def merged_custom_cmds(self, sep: str = ";") -> str: class WorkerConfig(SyftBaseModel): pass + def hash(self) -> str: + _bytes = _serialize(self, to_bytes=True, for_hashing=True) + return sha256(_bytes).digest().hex() + @serializable(canonical_name="CustomWorkerConfig", version=1) class CustomWorkerConfig(WorkerConfig): diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 0d295b81982..efed6023ab8 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -184,6 +184,7 @@ def deploy_to_python( debug: bool = False, migrate: bool = False, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -216,6 +217,7 @@ def deploy_to_python( "migrate": migrate, "deployment_type": deployment_type_enum, "consumer_type": consumer_type, + "db_url": db_url, } if port: @@ -329,6 +331,7 @@ def launch( migrate: bool = False, from_state_folder: str | Path | None = None, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ) -> ServerHandle: if from_state_folder is not None: with open(f"{from_state_folder}/config.json") as f: @@ -378,11 +381,12 @@ def launch( debug=debug, migrate=migrate, consumer_type=consumer_type, + db_url=db_url, ) display( SyftInfo( message=f"You have launched a development server at http://{host}:{server_handle.port}." - + "It is intended only for local use." + + " It is intended only for local use." ) ) return server_handle diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 1ea9d1ae203..0c848585119 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -81,8 +81,14 @@ def handle_annotation_repr_(annotation: type) -> str: """Handle typing representation.""" origin = typing.get_origin(annotation) args = typing.get_args(annotation) + + def get_annotation_repr_for_arg(arg: type) -> str: + if hasattr(arg, "__canonical_name__"): + return arg.__canonical_name__ + return getattr(arg, "__name__", str(arg)) + if origin and args: - args_repr = ", ".join(getattr(arg, "__name__", str(arg)) for arg in args) + args_repr = ", ".join(get_annotation_repr_for_arg(arg) for arg in args) origin_repr = getattr(origin, "__name__", str(origin)) # Handle typing.Union and types.UnionType diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 57e224f83e5..023833d9887 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -8,6 +8,75 @@ "3": { "version": 3, "hash": "85e2f6d9a6a12e6029b153be4fd2b4efae3bb2b48e617c08b447c5db01bea6c4", + } + }, + "Notification": { + "2": { + "version": 2, + "hash": "812d3a612422fb1cf53caa13ec34a7bdfcf033a7c24b7518f527af144cb45f3c", + "action": "add" + } + }, + "SyftWorkerImage": { + "2": { + "version": 2, + "hash": "afd3a69719cd6d08b1121676ca8d80ca37be96ee5ed5893dc73733fbf47fd035", + "action": "add" + } + }, + "WorkerSettings": { + "2": { + "version": 2, + "hash": "91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa", + "action": "add" + } + }, + "MongoDict": { + "1": { + "version": 1, + "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", + "action": "remove" + } + }, + "MongoStoreConfig": { + "1": { + "version": 1, + "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", + "action": "remove" + } + }, + "JobItem": { + "1": { + "version": 1, + "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", + "action": "remove" + } + }, + "DictStoreConfig": { + "1": { + "version": 1, + "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", + "action": "remove" + } + }, + "QueueItem": { + "2": { + "version": 2, + "hash": "1d8615f6daabcd2a285b2f36fd7bef1df76cdd119dd49c02069c50fd1b9c3ff4", + "action": "add" + } + }, + "ActionQueueItem": { + "2": { + "version": 2, + "hash": "bfda6ef87e4045d663324bb91a215ea06e1f173aec1fb4d9ddd337cdc1f0787f", + "action": "add" + } + }, + "APIEndpointQueueItem": { + "2": { + "version": 2, + "hash": "3a46370205152fa23a7d2bfa47130dbf2e2bc7ef31f6d3fe4c92fd8d683770b5", "action": "add" } } diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py new file mode 100644 index 00000000000..ee86241716c --- /dev/null +++ b/packages/syft/src/syft/serde/json_serde.py @@ -0,0 +1,452 @@ +# stdlib +import base64 +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +import json +import typing +from typing import Any +from typing import Generic +from typing import TypeVar +from typing import Union +from typing import get_args +from typing import get_origin + +# third party +import pydantic + +# syft absolute +import syft as sy + +# relative +from ..server.credentials import SyftSigningKey +from ..server.credentials import SyftVerifyKey +from ..types.datetime import DateTime +from ..types.syft_object import BaseDateTime +from ..types.syft_object_registry import SyftObjectRegistry +from ..types.uid import LineageID +from ..types.uid import UID +from .recursive import DEFAULT_EXCLUDE_ATTRS + +T = TypeVar("T") + +JSON_CANONICAL_NAME_FIELD = "__canonical_name__" +JSON_VERSION_FIELD = "__version__" +JSON_DATA_FIELD = "data" + +JsonPrimitive = str | int | float | bool | None +Json = JsonPrimitive | list["Json"] | dict[str, "Json"] + + +def _noop_fn(obj: Any) -> Any: + return obj + + +@dataclass +class JSONSerde(Generic[T]): + klass: type[T] + serialize_fn: Callable[[T], Json] + deserialize_fn: Callable[[Json], T] + + def serialize(self, obj: T) -> Json: + return self.serialize_fn(obj) + + def deserialize(self, obj: Json) -> T: + return self.deserialize_fn(obj) + + +JSON_SERDE_REGISTRY: dict[type[T], JSONSerde[T]] = {} + + +def register_json_serde( + type_: type[T], + serialize: Callable[[T], Json] | None = None, + deserialize: Callable[[Json], T] | None = None, +) -> None: + if type_ in JSON_SERDE_REGISTRY: + raise ValueError(f"Type {type_} is already registered") + + if serialize is None: + serialize = _noop_fn + + if deserialize is None: + deserialize = _noop_fn + + JSON_SERDE_REGISTRY[type_] = JSONSerde( + klass=type_, + serialize_fn=serialize, + deserialize_fn=deserialize, + ) + + +# Standard JSON primitives +register_json_serde(int) +register_json_serde(str) +register_json_serde(bool) +register_json_serde(float) +register_json_serde(type(None)) +register_json_serde(pydantic.EmailStr) + +# Syft primitives +register_json_serde(UID, lambda uid: uid.no_dash, lambda s: UID(s)) +register_json_serde(LineageID, lambda uid: uid.no_dash, lambda s: LineageID(s)) +register_json_serde( + DateTime, lambda dt: dt.utc_timestamp, lambda f: DateTime(utc_timestamp=f) +) +register_json_serde( + BaseDateTime, lambda dt: dt.utc_timestamp, lambda f: BaseDateTime(utc_timestamp=f) +) +register_json_serde(SyftVerifyKey, lambda key: str(key), SyftVerifyKey.from_string) +register_json_serde(SyftSigningKey, lambda key: str(key), SyftSigningKey.from_string) + + +def _validate_json(value: T) -> T: + # Throws TypeError if value is not JSON-serializable + json.dumps(value) + return value + + +def _is_optional_annotation(annotation: Any) -> bool: + try: + return annotation | None == annotation + except TypeError: + return False + + +def _is_annotated_type(annotation: Any) -> bool: + return get_origin(annotation) == typing.Annotated + + +def _unwrap_optional_annotation(annotation: Any) -> Any: + """Return the type anntation with None type removed, if it is present. + + Args: + annotation (Any): type annotation + + Returns: + Any: type annotation without None type + """ + if _is_optional_annotation(annotation): + args = get_args(annotation) + return Union[tuple(arg for arg in args if arg is not type(None))] # noqa + return annotation + + +def _unwrap_annotated(annotation: Any) -> Any: + # Convert Annotated[T, ...] to T + return get_args(annotation)[0] + + +def _unwrap_type_annotation(annotation: Any) -> Any: + """ + recursively unwrap type annotations, removing Annotated and Optional types + """ + if _is_annotated_type(annotation): + res = _unwrap_annotated(annotation) + return _unwrap_type_annotation(res) + elif _is_optional_annotation(annotation): + res = _unwrap_optional_annotation(annotation) + return _unwrap_type_annotation(res) + return annotation + + +def _annotation_issubclass(annotation: Any, cls: type) -> bool: + # issubclass throws TypeError if annotation is not a valid type (eg Union) + try: + return issubclass(annotation, cls) + except TypeError: + return False + + +def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: + canonical_name, version = SyftObjectRegistry.get_canonical_name_version(obj) + serde_attributes = SyftObjectRegistry.get_serde_properties(canonical_name, version) + exclude_attrs = serde_attributes[4] + + result: dict[str, Json] = { + JSON_CANONICAL_NAME_FIELD: canonical_name, + JSON_VERSION_FIELD: version, + } + + all_exclude_attrs = set(exclude_attrs) | DEFAULT_EXCLUDE_ATTRS + + for key, type_ in obj.model_fields.items(): + if key in all_exclude_attrs: + continue + result[key] = serialize_json(getattr(obj, key), type_.annotation) + + result = _add_searchable_and_unique_attrs(obj, result, raise_errors=False) + + return result + + +def get_property_return_type(obj: Any, attr_name: str) -> Any: + """ + Get the return type annotation of a @property. + """ + cls = type(obj) + attr = getattr(cls, attr_name, None) + + if isinstance(attr, property): + return attr.fget.__annotations__.get("return", None) + + return None + + +def _add_searchable_and_unique_attrs( + obj: pydantic.BaseModel, obj_dict: dict[str, Json], raise_errors: bool = True +) -> dict[str, Json]: + """ + Add searchable attrs and unique attrs to the serialized object dict, if they are not already present. + Needed for adding non-field attributes (like @property) + + Args: + obj (pydantic.BaseModel): Object to serialize. + obj_dict (dict[str, Json]): Serialized object dict. Should contain the object's fields. + raise_errors (bool, optional): Raise errors if an attribute cannot be accessed. + If False, the attribute will be skipped. Defaults to True. + + Raises: + Exception: Any exception raised when accessing an attribute. + + Returns: + dict[str, Json]: Serialized object dict including searchable attributes. + """ + searchable_attrs: list[str] = getattr(obj, "__attr_searchable__", []) + unique_attrs: list[str] = getattr(obj, "__attr_unique__", []) + + attrs_to_add = set(searchable_attrs) | set(unique_attrs) + for attr in attrs_to_add: + if attr not in obj_dict: + try: + value = getattr(obj, attr) + except Exception as e: + if raise_errors: + raise e + else: + continue + property_annotation = get_property_return_type(obj, attr) + obj_dict[attr] = serialize_json( + value, validate=False, annotation=property_annotation + ) + + return obj_dict + + +def _deserialize_pydantic_from_json( + obj_dict: dict[str, Json], +) -> pydantic.BaseModel: + try: + canonical_name = obj_dict[JSON_CANONICAL_NAME_FIELD] + version = obj_dict[JSON_VERSION_FIELD] + obj_type = SyftObjectRegistry.get_serde_class(canonical_name, version) + + result = {} + for key, type_ in obj_type.model_fields.items(): + if key not in obj_dict: + continue + result[key] = deserialize_json(obj_dict[key], type_.annotation) + + return obj_type.model_validate(result) + except Exception as e: + print(f"Failed to deserialize Pydantic model: {e}") + print(json.dumps(obj_dict, indent=2)) + raise ValueError(f"Failed to deserialize Pydantic model: {e}") + + +def _is_serializable_iterable(annotation: Any) -> bool: + # we can only serialize typed iterables without Union/Any + # NOTE optional is allowed + + # 1. check if it is an iterable + if get_origin(annotation) not in {list, tuple, set, frozenset}: + return False + + # 2. check if iterable annotation is serializable + args = get_args(annotation) + if len(args) != 1: + return False + + inner_type = _unwrap_type_annotation(args[0]) + return inner_type in JSON_SERDE_REGISTRY or _annotation_issubclass( + inner_type, pydantic.BaseModel + ) + + +def _serialize_iterable_to_json(value: Any, annotation: Any) -> Json: + # No need to validate in recursive calls + return [serialize_json(v, validate=False) for v in value] + + +def _deserialize_iterable_from_json(value: Json, annotation: Any) -> Any: + if not isinstance(value, list): + raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") + + annotation = _unwrap_type_annotation(annotation) + + if not _is_serializable_iterable(annotation): + raise ValueError(f"Cannot deserialize {annotation} from JSON") + + inner_type = _unwrap_type_annotation(get_args(annotation)[0]) + return [deserialize_json(v, inner_type) for v in value] + + +def _is_serializable_mapping(annotation: Any) -> bool: + """ + Mapping is serializable if: + - it is a dict + - the key type is str + - the value type is serializable and not a Union + """ + if get_origin(annotation) != dict: + return False + + args = get_args(annotation) + if len(args) != 2: + return False + + key_type, value_type = args + # JSON only allows string keys + if not isinstance(key_type, str): + return False + + # check if value type is serializable + value_type = _unwrap_type_annotation(value_type) + return value_type in JSON_SERDE_REGISTRY or _annotation_issubclass( + value_type, pydantic.BaseModel + ) + + +def _serialize_mapping_to_json(value: Any, annotation: Any) -> Json: + _, value_type = get_args(annotation) + # No need to validate in recursive calls + return {k: serialize_json(v, value_type, validate=False) for k, v in value.items()} + + +def _deserialize_mapping_from_json(value: Json, annotation: Any) -> Any: + if not isinstance(value, dict): + raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") + + annotation = _unwrap_type_annotation(annotation) + + if not _is_serializable_mapping(annotation): + raise ValueError(f"Cannot deserialize {annotation} from JSON") + + _, value_type = get_args(annotation) + return {k: deserialize_json(v, value_type) for k, v in value.items()} + + +def _serialize_to_json_bytes(obj: Any) -> str: + obj_bytes = sy.serialize(obj, to_bytes=True) + return base64.b64encode(obj_bytes).decode("utf-8") + + +def _deserialize_from_json_bytes(obj: str) -> Any: + obj_bytes = base64.b64decode(obj) + return sy.deserialize(obj_bytes, from_bytes=True) + + +def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> Json: + """ + Serialize a value to a JSON-serializable object, using the schema defined by the + provided annotation. + + Serialization is always done according to the annotation, as the same annotation + is used for deserialization. If the annotation is not provided or is ambiguous, + the JSON serialization will fall back to serializing bytes. Examples: + - int, `list[int]` are strictly typed + - `str | int`, `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes + - Optional types (like int | None) are serialized to the not-None type + + The function chooses the appropriate serialization method in the following order: + 1. Method registered in `JSON_SERDE_REGISTRY` for the annotation type. + 2. Pydantic model serialization, including all `SyftObjects`. + 3. Iterable serialization, if the annotation is a strict iterable (e.g., `list[int]`). + 4. Mapping serialization, if the annotation is a strictly typed mapping with string keys. + 5. Serialize the object to bytes and encode it as base64. + + Args: + value (Any): Value to serialize. + annotation (Any, optional): Type annotation for the value. Defaults to None. + + Returns: + Json: JSON-serializable object. + """ + if annotation is None: + annotation = type(value) + + if value is None: + return None + + # Remove None type from annotation if it is present. + annotation = _unwrap_type_annotation(annotation) + + if annotation in JSON_SERDE_REGISTRY: + result = JSON_SERDE_REGISTRY[annotation].serialize(value) + elif _annotation_issubclass(annotation, pydantic.BaseModel): + result = _serialize_pydantic_to_json(value) + elif _annotation_issubclass(annotation, Enum): + result = value.name + + # JSON recursive types + # only strictly annotated iterables and mappings are supported + # example: list[int] is supported, but not list[int | str] + elif _is_serializable_iterable(annotation): + result = _serialize_iterable_to_json(value, annotation) + elif _is_serializable_mapping(annotation): + result = _serialize_mapping_to_json(value, annotation) + else: + result = _serialize_to_json_bytes(value) + + if validate: + _validate_json(result) + + return result + + +def deserialize_json(value: Json, annotation: Any = None) -> Any: + """Deserialize a JSON-serializable object to a value, using the schema defined by the + provided annotation. Inverse of `serialize_json`. + + Args: + value (Json): JSON-serializable object. + annotation (Any): Type annotation for the value. + + Returns: + Any: Deserialized value. + """ + if ( + isinstance(value, dict) + and JSON_CANONICAL_NAME_FIELD in value + and JSON_VERSION_FIELD in value + ): + return _deserialize_pydantic_from_json(value) + + if value is None: + return None + + # Remove None type from annotation if it is present. + if annotation is None: + raise ValueError("Annotation is required for deserialization") + + annotation = _unwrap_type_annotation(annotation) + + if annotation in JSON_SERDE_REGISTRY: + return JSON_SERDE_REGISTRY[annotation].deserialize(value) + elif _annotation_issubclass(annotation, pydantic.BaseModel): + return _deserialize_pydantic_from_json(value) + elif _annotation_issubclass(annotation, Enum): + return annotation[value] + elif isinstance(value, list): + return _deserialize_iterable_from_json(value, annotation) + elif isinstance(value, dict): + return _deserialize_mapping_from_json(value, annotation) + elif isinstance(value, str): + return _deserialize_from_json_bytes(value) + else: + raise ValueError(f"Cannot deserialize {value} to {annotation}") + + +def is_json_primitive(value: Any) -> bool: + serialized = serialize_json(value, validate=False) + return isinstance(serialized, JsonPrimitive) # type: ignore diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index ed0379ae51b..d5694b4efe6 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -25,6 +25,7 @@ recursive_scheme = get_capnp_schema("recursive_serde.capnp").RecursiveSerde SPOOLED_FILE_MAX_SIZE_SERDE = 50 * (1024**2) # 50MB +DEFAULT_EXCLUDE_ATTRS: set[str] = {"syft_pre_hooks__", "syft_post_hooks__"} def get_types(cls: type, keys: list[str] | None = None) -> list[type] | None: @@ -192,9 +193,7 @@ def recursive_serde_register( attribute_list.update(["value"]) exclude_attrs = [] if exclude_attrs is None else exclude_attrs - attribute_list = ( - attribute_list - set(exclude_attrs) - {"syft_pre_hooks__", "syft_post_hooks__"} - ) + attribute_list = attribute_list - set(exclude_attrs) - DEFAULT_EXCLUDE_ATTRS if inheritable_attrs and attribute_list and not is_pydantic: # only set __syft_serializable__ for non-pydantic classes because diff --git a/packages/syft/src/syft/serde/signature.py b/packages/syft/src/syft/serde/signature.py index 23b0a556fca..0887d148367 100644 --- a/packages/syft/src/syft/serde/signature.py +++ b/packages/syft/src/syft/serde/signature.py @@ -86,6 +86,15 @@ def signature_remove_context(signature: Signature) -> Signature: ) +def signature_remove(signature: Signature, args: list[str]) -> Signature: + params = dict(signature.parameters) + for arg in args: + params.pop(arg, None) + return Signature( + list(params.values()), return_annotation=signature.return_annotation + ) + + def get_str_signature_from_docstring(doc: str, callable_name: str) -> str | None: if not doc or callable_name not in doc: return None diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 89cc5ffaab5..6fadf6261f9 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -18,7 +18,6 @@ import pyarrow.parquet as pq import pydantic from pydantic._internal._model_construction import ModelMetaclass -from pymongo.collection import Collection # relative from ..types.dicttuple import DictTuple @@ -58,11 +57,6 @@ # exceptions recursive_serde_register(cls=TypeError, canonical_name="TypeError", version=1) -# mongo collection -recursive_serde_register_type( - Collection, canonical_name="pymongo_collection", version=1 -) - def serialize_dataframe(df: DataFrame) -> bytes: table = pa.Table.from_pandas(df) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 1285e787428..c3d25b10440 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -18,6 +18,7 @@ from time import sleep import traceback from typing import Any +from typing import TypeVar from typing import cast # third party @@ -39,10 +40,6 @@ from ..protocol.data_protocol import get_data_protocol from ..service.action.action_object import Action from ..service.action.action_object import ActionObject -from ..service.action.action_store import ActionStore -from ..service.action.action_store import DictActionStore -from ..service.action.action_store import MongoActionStore -from ..service.action.action_store import SQLiteActionStore from ..service.code.user_code_stash import UserCodeStash from ..service.context import AuthedServiceContext from ..service.context import ServerServiceContext @@ -55,6 +52,7 @@ from ..service.metadata.server_metadata import ServerMetadata from ..service.network.utils import PeerHealthCheckTask from ..service.notifier.notifier_service import NotifierService +from ..service.output.output_service import OutputStash from ..service.queue.base_queue import AbstractMessageHandler from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer @@ -75,12 +73,10 @@ from ..service.service import UserServiceConfigRegistry from ..service.settings.settings import ServerSettings from ..service.settings.settings import ServerSettingsUpdate -from ..service.settings.settings_stash import SettingsStash from ..service.user.user import User from ..service.user.user import UserCreate from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole -from ..service.user.user_stash import UserStash from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME from ..service.worker.utils import create_default_image @@ -92,12 +88,17 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ..store.dict_document_store import DictStoreConfig +from ..store.db.db import DBConfig +from ..store.db.db import DBManager +from ..store.db.postgres import PostgresDBConfig +from ..store.db.postgres import PostgresDBManager +from ..store.db.sqlite import SQLiteDBConfig +from ..store.db.sqlite import SQLiteDBManager +from ..store.db.stash import ObjectStash from ..store.document_store import StoreConfig from ..store.document_store_errors import NotFoundException from ..store.document_store_errors import StashException from ..store.linked_obj import LinkedObject -from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig from ..types.datetime import DATETIME_FORMAT @@ -128,6 +129,8 @@ logger = logging.getLogger(__name__) +SyftT = TypeVar("SyftT", bound=SyftObject) + # if user code needs to be serded and its not available we can call this to refresh # the code for a specific server UID and thread CODE_RELOADER: dict[int, Callable] = {} @@ -307,6 +310,7 @@ def __init__( signing_key: SyftSigningKey | SigningKey | None = None, action_store_config: StoreConfig | None = None, document_store_config: StoreConfig | None = None, + db_config: DBConfig | None = None, root_email: str | None = default_root_email, root_username: str | None = default_root_username, root_password: str | None = default_root_password, @@ -336,6 +340,7 @@ def __init__( association_request_auto_approval: bool = False, background_tasks: bool = False, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -351,6 +356,7 @@ def __init__( self.server_side_type = ServerSideType(server_side_type) self.client_cache: dict = {} self.peer_client_cache: dict = {} + self._settings = None if isinstance(server_type, str): server_type = ServerType(server_type) @@ -396,27 +402,33 @@ def __init__( if reset: self.remove_temp_dir() - use_sqlite = local_db or (processes > 0 and not is_subprocess) document_store_config = document_store_config or self.get_default_store( - use_sqlite=use_sqlite, store_type="Document Store", ) action_store_config = action_store_config or self.get_default_store( - use_sqlite=use_sqlite, store_type="Action Store", ) - self.init_stores( - action_store_config=action_store_config, - document_store_config=document_store_config, - ) + db_config = DBConfig.from_connection_string(db_url) if db_url else db_config + + if db_config is None: + db_config = SQLiteDBConfig( + filename=f"{self.id}_json.db", + path=self.get_temp_dir("db"), + ) + + self.db_config = db_config + + self.db = self.init_stores(db_config=self.db_config) # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) + self.db.init_tables(reset=reset) + self.action_store = self.services.action.stash - create_admin_new( # nosec B106 + create_root_admin_if_not_exists( name=root_username, email=root_email, - password=root_password, + password=root_password, # nosec server=self, ) @@ -520,21 +532,19 @@ def runs_in_docker(self) -> bool: and any("docker" in line for line in open(path)) ) - def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig: - if use_sqlite: - path = self.get_temp_dir("db") - file_name: str = f"{self.id}.sqlite" - if self.dev_mode: - # leave this until the logger shows this in the notebook - print(f"{store_type}'s SQLite DB path: {path/file_name}") - logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") - return SQLiteStoreConfig( - client_config=SQLiteStoreClientConfig( - filename=file_name, - path=path, - ) + def get_default_store(self, store_type: str) -> StoreConfig: + path = self.get_temp_dir("db") + file_name: str = f"{self.id}.sqlite" + # if self.dev_mode: + # leave this until the logger shows this in the notebook + # print(f"{store_type}'s SQLite DB path: {path/file_name}") + # logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") + return SQLiteStoreConfig( + client_config=SQLiteStoreClientConfig( + filename=file_name, + path=path, ) - return DictStoreConfig() + ) def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: if config is None: @@ -648,6 +658,7 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None: worker_stash=self.worker_stash, ) producer.run() + address = producer.address else: port = queue_config.client_config.queue_port @@ -754,6 +765,8 @@ def named( association_request_auto_approval: bool = False, background_tasks: bool = False, consumer_type: ConsumerType | None = None, + db_url: str | None = None, + db_config: DBConfig | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -785,6 +798,8 @@ def named( association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, consumer_type=consumer_type, + db_url=db_url, + db_config=db_config, ) def is_root(self, credentials: SyftVerifyKey) -> bool: @@ -906,59 +921,36 @@ def reload_user_code() -> None: if ti is not None: CODE_RELOADER[ti] = reload_user_code - def init_stores( - self, - document_store_config: StoreConfig, - action_store_config: StoreConfig, - ) -> None: - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - if isinstance(document_store_config, MongoStoreConfig): - document_store_config.client_config.server_obj_python_id = id(self) - - self.document_store_config = document_store_config - self.document_store = document_store_config.store_type( - server_uid=self.id, - root_verify_key=self.verify_key, - store_config=document_store_config, - ) - - if isinstance(action_store_config, SQLiteStoreConfig): - self.action_store: ActionStore = SQLiteActionStore( + def init_stores(self, db_config: DBConfig) -> DBManager: + if isinstance(db_config, SQLiteDBConfig): + db = SQLiteDBManager( + config=db_config, server_uid=self.id, - store_config=action_store_config, root_verify_key=self.verify_key, - document_store=self.document_store, ) - elif isinstance(action_store_config, MongoStoreConfig): - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - action_store_config.client_config.server_obj_python_id = id(self) - - self.action_store = MongoActionStore( + elif isinstance(db_config, PostgresDBConfig): + db = PostgresDBManager( # type: ignore + config=db_config, server_uid=self.id, root_verify_key=self.verify_key, - store_config=action_store_config, - document_store=self.document_store, ) else: - self.action_store = DictActionStore( - server_uid=self.id, - root_verify_key=self.verify_key, - document_store=self.document_store, - ) + raise SyftException(public_message=f"Unsupported DB config: {db_config}") + + self.queue_stash = QueueStash(store=db) - self.action_store_config = action_store_config - self.queue_stash = QueueStash(store=self.document_store) + print(f"Using {db_config.__class__.__name__} and {db_config.connection_string}") + + return db @property def job_stash(self) -> JobStash: return self.services.job.stash + @property + def output_stash(self) -> OutputStash: + return self.services.output.stash + @property def worker_stash(self) -> WorkerStash: return self.services.worker.stash @@ -979,6 +971,12 @@ def get_service_method(self, path_or_func: str | Callable) -> Callable: def get_service(self, path_or_func: str | Callable) -> AbstractService: return self.services.get_service(path_or_func) + @as_result(ValueError) + def get_stash(self, object_type: SyftT) -> ObjectStash[SyftT]: + if object_type not in self.services.stashes: + raise ValueError(f"Stash for {object_type} not found.") + return self.services.stashes[object_type] + def _get_service_method_from_path(self, path: str) -> Callable: path_list = path.split(".") method_name = path_list.pop() @@ -1016,10 +1014,12 @@ def update_self(self, settings: ServerSettings) -> None: # it should be removed once the settings are refactored and the inconsistencies between # settings and services are resolved. def get_settings(self) -> ServerSettings | None: + if self._settings: + return self._settings # type: ignore if self.signing_key is None: raise ValueError(f"{self} has no signing key") - settings_stash = SettingsStash(store=self.document_store) + settings_stash = self.services.settings.stash try: settings = settings_stash.get_all(self.signing_key.verify_key).unwrap() @@ -1027,6 +1027,7 @@ def get_settings(self) -> ServerSettings | None: if len(settings) > 0: setting = settings[0] self.update_self(setting) + self._settings = setting return setting else: return None @@ -1039,7 +1040,7 @@ def settings(self) -> ServerSettings: if self.signing_key is None: raise ValueError(f"{self} has no signing key") - settings_stash = SettingsStash(store=self.document_store) + settings_stash = self.services.settings.stash error_msg = f"Cannot get server settings for '{self.name}'" all_settings = settings_stash.get_all(self.signing_key.verify_key).unwrap( @@ -1479,7 +1480,9 @@ def add_queueitem_to_queue( result_obj.syft_server_location = self.id result_obj.syft_client_verify_key = credentials - if not self.services.action.store.exists(uid=action.result_id): + if not self.services.action.stash.exists( + credentials=credentials, uid=action.result_id + ): self.services.action.set_result_to_store( result_action_object=result_obj, context=context, @@ -1687,7 +1690,7 @@ def get_unauthed_context( @as_result(SyftException, StashException) def create_initial_settings(self, admin_email: str) -> ServerSettings: - settings_stash = SettingsStash(store=self.document_store) + settings_stash = self.services.settings.stash if self.signing_key is None: logger.debug("create_initial_settings failed as there is no signing key") @@ -1741,17 +1744,32 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings: ).unwrap() -def create_admin_new( +def create_root_admin_if_not_exists( name: str, email: str, password: str, - server: AbstractServer, + server: Server, ) -> User | None: - user_stash = UserStash(store=server.document_store) + """ + If no root admin exists: + - all exists checks on the user stash will fail, as we cannot get the role for the admin to check if it exists + - result: a new admin is always created + + If a root admin exists with a different email: + - cause: DEFAULT_USER_EMAIL env variable is set to a different email than the root admin in the db + - verify_key_exists will return True + - result: no new admin is created, as the server already has a root admin + """ + user_stash = server.services.user.stash + + email_exists = user_stash.email_exists(email=email).unwrap() + if email_exists: + logger.debug("Admin not created, a user with this email already exists") + return None - user_exists = user_stash.email_exists(email=email).unwrap() - if user_exists: - logger.debug("Admin not created, admin already exists") + verify_key_exists = user_stash.verify_key_exists(server.verify_key).unwrap() + if verify_key_exists: + logger.debug("Admin not created, this server already has a root admin") return None create_user = UserCreate( @@ -1766,12 +1784,12 @@ def create_admin_new( # 🟡 TODO: change later but for now this gives the main user super user automatically user = create_user.to(User) user.signing_key = server.signing_key - user.verify_key = user.signing_key.verify_key + user.verify_key = server.verify_key new_user = user_stash.set( - credentials=server.signing_key.verify_key, + credentials=server.verify_key, obj=user, - ignore_duplicates=True, + ignore_duplicates=False, ).unwrap() logger.debug(f"Created admin {new_user.email}") diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index d7c3555f10c..dfb7f331972 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -3,15 +3,15 @@ from dataclasses import dataclass from dataclasses import field import typing -from typing import Any from typing import TYPE_CHECKING +from typing import TypeVar # relative from ..serde.serializable import serializable from ..service.action.action_service import ActionService -from ..service.action.action_store import ActionStore from ..service.api.api_service import APIService from ..service.attestation.attestation_service import AttestationService +from ..service.blob_storage.remote_profile import RemoteProfileService from ..service.blob_storage.service import BlobStorageService from ..service.code.status_service import UserCodeStatusService from ..service.code.user_code_service import UserCodeService @@ -40,12 +40,17 @@ from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_service import WorkerService +from ..store.db.stash import ObjectStash +from ..types.syft_object import SyftObject if TYPE_CHECKING: # relative from .server import Server +StashT = TypeVar("StashT", bound=SyftObject) + + @serializable(canonical_name="ServiceRegistry", version=1) @dataclass class ServiceRegistry: @@ -79,11 +84,13 @@ class ServiceRegistry: sync: SyncService output: OutputService user_code_status: UserCodeStatusService + remote_profile: RemoteProfileService services: list[AbstractService] = field(default_factory=list, init=False) service_path_map: dict[str, AbstractService] = field( default_factory=dict, init=False ) + stashes: dict[StashT, ObjectStash[StashT]] = field(default_factory=dict, init=False) @classmethod def for_server(cls, server: "Server") -> "ServiceRegistry": @@ -95,6 +102,11 @@ def __post_init__(self) -> None: self.services.append(service) self.service_path_map[service_cls.__name__.lower()] = service + # TODO ActionService now has same stash, but interface is still different. Fix this. + if hasattr(service, "stash") and not issubclass(service_cls, ActionService): + stash: ObjectStash = service.stash + self.stashes[stash.object_type] = stash + @classmethod def get_service_classes( cls, @@ -109,13 +121,7 @@ def get_service_classes( def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: service_dict = {} for field_name, service_cls in cls.get_service_classes().items(): - svc_kwargs: dict[str, Any] = {} - if issubclass(service_cls.store_type, ActionStore): - svc_kwargs["store"] = server.action_store - else: - svc_kwargs["store"] = server.document_store - - service = service_cls(**svc_kwargs) + service = service_cls(store=server.db) # type: ignore service_dict[field_name] = service return service_dict @@ -133,3 +139,6 @@ def _get_service_from_path(self, path: str) -> AbstractService: return self.service_path_map[service_name.lower()] except KeyError: raise ValueError(f"Service {path} not found.") + + def __iter__(self) -> typing.Iterator[AbstractService]: + return iter(self.services) diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 80f15a6d5ba..e1982953a32 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -1,6 +1,8 @@ # stdlib from collections.abc import Callable from contextlib import asynccontextmanager +import json +import logging import multiprocessing import multiprocessing.synchronize import os @@ -25,6 +27,7 @@ from ..abstract_server import ServerSideType from ..client.client import API_PATH from ..deployment_type import DeploymentType +from ..store.db.db import DBConfig from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT from ..util.telemetry import TRACING_ENABLED @@ -46,6 +49,9 @@ WAIT_TIME_SECONDS = 20 +logger = logging.getLogger("uvicorn") + + class AppSettings(BaseSettings): name: str server_type: ServerType = ServerType.DATASITE @@ -61,6 +67,8 @@ class AppSettings(BaseSettings): n_consumers: int = 0 association_request_auto_approval: bool = False background_tasks: bool = False + db_config: DBConfig | None = None + db_url: str | None = None model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -91,6 +99,10 @@ def app_factory() -> FastAPI: worker_class = worker_classes[settings.server_type] kwargs = settings.model_dump() + + logger.info( + f"Starting server with settings: {kwargs} and worker class: {worker_class}" + ) if settings.dev_mode: print( f"WARN: private key is based on server name: {settings.name} in dev_mode. " @@ -175,6 +187,8 @@ def run_uvicorn( env_prefix = AppSettings.model_config.get("env_prefix", "") for key, value in kwargs.items(): key_with_prefix = f"{env_prefix}{key.upper()}" + if isinstance(value, dict): + value = json.dumps(value) os.environ[key_with_prefix] = str(value) # The `serve_server` function calls `run_uvicorn` in a separate process using `multiprocessing.Process`. @@ -220,6 +234,7 @@ def serve_server( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, + db_url: str | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() @@ -249,6 +264,7 @@ def serve_server( "debug": debug, "starting_uvicorn_event": starting_uvicorn_event, "deployment_type": deployment_type, + "db_url": db_url, }, ) diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index 57a69d2a4eb..3e10cc7d5fa 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -1,6 +1,9 @@ # future from __future__ import annotations +# stdlib +from collections.abc import Callable + # third party from typing_extensions import Self @@ -13,16 +16,21 @@ from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig +from ..store.db.db import DBConfig from ..store.document_store import StoreConfig +from ..types.syft_migration import migrate from ..types.syft_object import SYFT_OBJECT_VERSION_1 +from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject +from ..types.transforms import TransformContext +from ..types.transforms import drop from ..types.uid import UID @serializable() class WorkerSettings(SyftObject): __canonical_name__ = "WorkerSettings" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID name: str @@ -30,28 +38,58 @@ class WorkerSettings(SyftObject): server_side_type: ServerSideType deployment_type: DeploymentType = DeploymentType.REMOTE signing_key: SyftSigningKey - document_store_config: StoreConfig - action_store_config: StoreConfig + db_config: DBConfig blob_store_config: BlobStorageConfig | None = None queue_config: QueueConfig | None = None log_level: int | None = None @classmethod def from_server(cls, server: AbstractServer) -> Self: - if server.server_side_type: - server_side_type: str = server.server_side_type.value - else: - server_side_type = ServerSideType.HIGH_SIDE + server_side_type = server.server_side_type or ServerSideType.HIGH_SIDE return cls( id=server.id, name=server.name, server_type=server.server_type, signing_key=server.signing_key, - document_store_config=server.document_store_config, - action_store_config=server.action_store_config, + db_config=server.db_config, server_side_type=server_side_type, blob_store_config=server.blob_store_config, queue_config=server.queue_config, log_level=server.log_level, deployment_type=server.deployment_type, ) + + +@serializable() +class WorkerSettingsV1(SyftObject): + __canonical_name__ = "WorkerSettings" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + name: str + server_type: ServerType + server_side_type: ServerSideType + deployment_type: DeploymentType = DeploymentType.REMOTE + signing_key: SyftSigningKey + document_store_config: StoreConfig + action_store_config: StoreConfig + blob_store_config: BlobStorageConfig | None = None + queue_config: QueueConfig | None = None + log_level: int | None = None + + +def set_db_config(context: TransformContext) -> TransformContext: + if context.output: + context.output["db_config"] = ( + context.server.db_config if context.server is not None else DBConfig() + ) + return context + + +@migrate(WorkerSettingsV1, WorkerSettings) +def migrate_workersettings_v1_to_v2() -> list[Callable]: + return [ + drop("document_store_config"), + drop("action_store_config"), + set_db_config, + ] diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index d7e566b733a..275767665c8 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -254,8 +254,8 @@ class ActionObjectPointer: "created_date", # syft "updated_date", # syft "deleted_date", # syft - "to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs "__attr_searchable__", # syft + "__attr_unique__", # syft "__canonical_name__", # syft "__version__", # syft "__args__", # pydantic diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 03992eeab07..ab6f9b7ce9a 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -17,12 +17,29 @@ class ActionPermission(Enum): ALL_WRITE = 32 EXECUTE = 64 ALL_EXECUTE = 128 + ALL_OWNER = 256 + + @property + def as_compound(self) -> "ActionPermission": + if self in COMPOUND_ACTION_PERMISSION: + return self + elif self == ActionPermission.READ: + return ActionPermission.ALL_READ + elif self == ActionPermission.WRITE: + return ActionPermission.ALL_WRITE + elif self == ActionPermission.EXECUTE: + return ActionPermission.ALL_EXECUTE + elif self == ActionPermission.OWNER: + return ActionPermission.ALL_OWNER + else: + raise Exception(f"Invalid compound permission {self}") COMPOUND_ACTION_PERMISSION = { ActionPermission.ALL_READ, ActionPermission.ALL_WRITE, ActionPermission.ALL_EXECUTE, + ActionPermission.ALL_OWNER, } @@ -64,6 +81,10 @@ def permission_string(self) -> str: return f"{self.credentials.verify}_{self.permission.name}" return f"{self.permission.name}" + @property + def compound_permission_string(self) -> str: + return self.permission.as_compound.name + def _coll_repr_(self) -> dict[str, Any]: return { "uid": str(self.uid), @@ -122,3 +143,7 @@ def _coll_repr_(self) -> dict[str, Any]: "uid": str(self.uid), "server_uid": str(self.server_uid), } + + @property + def permission_string(self) -> str: + return str(self.server_uid) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 94935631f01..bd871a18164 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -9,6 +9,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime @@ -43,8 +44,8 @@ from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD from .action_permissions import ActionPermission -from .action_store import ActionStore -from .action_store import KeyValueActionStore +from .action_permissions import StoragePermission +from .action_store import ActionObjectStash from .action_types import action_type_for_type from .numpy import NumpyArrayObject from .pandas import PandasDataFrameObject # noqa: F401 @@ -55,10 +56,10 @@ @serializable(canonical_name="ActionService", version=1) class ActionService(AbstractService): - store_type = ActionStore + stash: ActionObjectStash - def __init__(self, store: KeyValueActionStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + self.stash = ActionObjectStash(store) @service_method(path="action.np_array", name="np_array") def np_array(self, context: AuthedServiceContext, data: Any) -> Any: @@ -178,7 +179,7 @@ def _set( or has_result_read_permission ) - self.store.set( + self.stash.set_or_update( uid=action_object.id, credentials=context.credentials, syft_object=action_object, @@ -237,7 +238,7 @@ def resolve_links( ) -> ActionObject: """Get an object from the action store""" # If user has permission to get the object / object exists - result = self.store.get(uid=uid, credentials=context.credentials).unwrap() + result = self.stash.get(uid=uid, credentials=context.credentials).unwrap() # If it's not a leaf if result.is_link: @@ -271,7 +272,7 @@ def _get( resolve_nested: bool = True, ) -> ActionObject | TwinObject: """Get an object from the action store""" - obj = self.store.get( + obj = self.stash.get( uid=uid, credentials=context.credentials, has_permission=has_permission ).unwrap() @@ -314,7 +315,7 @@ def get_pointer( self, context: AuthedServiceContext, uid: UID ) -> ActionObjectPointer: """Get a pointer from the action store""" - obj = self.store.get_pointer( + obj = self.stash.get_pointer( uid=uid, credentials=context.credentials, server_uid=context.server.id ).unwrap() @@ -328,7 +329,7 @@ def get_pointer( @service_method(path="action.get_mock", name="get_mock", roles=GUEST_ROLE_LEVEL) def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: """Get a pointer from the action store""" - return self.store.get_mock(uid=uid).unwrap() + return self.stash.get_mock(credentials=context.credentials, uid=uid).unwrap() @service_method( path="action.has_storage_permission", @@ -336,10 +337,12 @@ def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: roles=GUEST_ROLE_LEVEL, ) def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> bool: - return self.store.has_storage_permission(uid) + return self.stash.has_storage_permission( + StoragePermission(uid=uid, server_uid=context.server.id) + ) def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool: - return self.store.has_permissions( + return self.stash.has_permissions( [ActionObjectREAD(uid=uid, credentials=context.credentials)] ) @@ -558,7 +561,7 @@ def blob_permission( if len(output_readers) > 0: store_permissions = [store_permission(x) for x in output_readers] - self.store.add_permissions(store_permissions) + self.stash.add_permissions(store_permissions) if result_blob_id is not None: blob_permissions = [blob_permission(x) for x in output_readers] @@ -880,12 +883,12 @@ def has_read_permission_for_action_result( ActionObjectREAD(uid=_id, credentials=context.credentials) for _id in action_obj_ids ] - return self.store.has_permissions(permissions) + return self.stash.has_permissions(permissions) @service_method(path="action.exists", name="exists", roles=GUEST_ROLE_LEVEL) def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool: """Checks if the given object id exists in the Action Store""" - return self.store.exists(obj_id) + return self.stash.exists(context.credentials, obj_id) @service_method( path="action.delete", @@ -896,7 +899,7 @@ def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool: def delete( self, context: AuthedServiceContext, uid: UID, soft_delete: bool = False ) -> SyftSuccess: - obj = self.store.get(uid=uid, credentials=context.credentials).unwrap() + obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap() return_msg = [] @@ -952,7 +955,7 @@ def _delete_from_action_store( soft_delete: bool = False, ) -> SyftSuccess: if soft_delete: - obj = self.store.get(uid=uid, credentials=context.credentials).unwrap() + obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap() if isinstance(obj, TwinObject): self._soft_delete_action_obj( @@ -964,7 +967,7 @@ def _delete_from_action_store( if isinstance(obj, ActionObject): self._soft_delete_action_obj(context=context, action_obj=obj).unwrap() else: - self.store.delete(credentials=context.credentials, uid=uid).unwrap() + self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap() return SyftSuccess(message=f"Action object with uid '{uid}' deleted.") diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 0228165edb9..e6597d19d25 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -1,128 +1,53 @@ # future from __future__ import annotations -# stdlib -import threading - # relative from ...serde.serializable import serializable -from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.dict_document_store import DictStoreConfig -from ...store.document_store import BasePartitionSettings -from ...store.document_store import DocumentStore -from ...store.document_store import StoreConfig +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException -from ...store.document_store_errors import ObjectCRUDPermissionException from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result from ...types.syft_object import SyftObject from ...types.twin_object import TwinObject -from ...types.uid import LineageID from ...types.uid import UID +from .action_object import ActionObject from .action_object import is_action_data_empty from .action_permissions import ActionObjectEXECUTE -from .action_permissions import ActionObjectOWNER from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD from .action_permissions import ActionObjectWRITE -from .action_permissions import ActionPermission from .action_permissions import StoragePermission -lock = threading.RLock() - - -class ActionStore: - pass - - -@serializable(canonical_name="KeyValueActionStore", version=1) -class KeyValueActionStore(ActionStore): - """Generic Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including connection configuration, database name, or client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - def __init__( - self, - server_uid: UID, - store_config: StoreConfig, - root_verify_key: SyftVerifyKey | None = None, - document_store: DocumentStore | None = None, - ) -> None: - self.server_uid = server_uid - self.store_config = store_config - self.settings = BasePartitionSettings(name="Action") - self.data = self.store_config.backing_store( - "data", self.settings, self.store_config - ) - self.permissions = self.store_config.backing_store( - "permissions", self.settings, self.store_config, ddtype=set - ) - self.storage_permissions = self.store_config.backing_store( - "storage_permissions", self.settings, self.store_config, ddtype=set - ) - if root_verify_key is None: - root_verify_key = SyftSigningKey.generate().verify_key - self.root_verify_key = root_verify_key - - self.__user_stash = None - if document_store is not None: - # relative - from ...service.user.user_stash import UserStash - - self.__user_stash = UserStash(store=document_store) +@serializable(canonical_name="ActionObjectSQLStore", version=1) +class ActionObjectStash(ObjectStash[ActionObject]): + # We are storing ActionObject, Action, TwinObject + allow_any_type = True @as_result(NotFoundException, SyftException) def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False - ) -> SyftObject: + ) -> ActionObject: uid = uid.id # We only need the UID from LineageID or UID - - # if you get something you need READ permission - read_permission = ActionObjectREAD(uid=uid, credentials=credentials) - - if not has_permission and not self.has_permission(read_permission): - raise SyftException(public_message=f"Permission: {read_permission} denied") - - # TODO: Remove try/except? - try: - if isinstance(uid, LineageID): - syft_object = self.data[uid.id] - elif isinstance(uid, UID): - syft_object = self.data[uid] - else: - raise SyftException( - public_message=f"Unrecognized UID type: {type(uid)}" - ) - return syft_object - except Exception as e: - raise NotFoundException.from_exception( - e, public_message=f"Object {uid} not found" - ) + # TODO remove and use get_by_uid instead + return self.get_by_uid( + credentials=credentials, + uid=uid, + has_permission=has_permission, + ).unwrap() @as_result(NotFoundException, SyftException) - def get_mock(self, uid: UID) -> SyftObject: + def get_mock(self, credentials: SyftVerifyKey, uid: UID) -> SyftObject: uid = uid.id # We only need the UID from LineageID or UID - try: - syft_object = self.data[uid] - - if isinstance(syft_object, TwinObject) and not is_action_data_empty( - syft_object.mock - ): - return syft_object.mock - raise NotFoundException(public_message=f"No mock found for object {uid}") - except Exception as e: - raise NotFoundException.from_exception( - e, public_message=f"Object {uid} not found" - ) + obj = self.get_by_uid( + credentials=credentials, uid=uid, has_permission=True + ).unwrap() + if isinstance(obj, TwinObject) and not is_action_data_empty(obj.mock): + return obj.mock + raise NotFoundException(public_message=f"No mock found for object {uid}") @as_result(NotFoundException, SyftException) def get_pointer( @@ -133,34 +58,27 @@ def get_pointer( ) -> SyftObject: uid = uid.id # We only need the UID from LineageID or UID - try: - if uid not in self.data: - raise SyftException(public_message="Permission denied") - - obj = self.data[uid] - read_permission = ActionObjectREAD(uid=uid, credentials=credentials) - - # if you have permission you can have private data - if self.has_permission(read_permission): - if isinstance(obj, TwinObject): - return obj.private.syft_point_to(server_uid) - return obj.syft_point_to(server_uid) + obj = self.get_by_uid( + credentials=credentials, uid=uid, has_permission=True + ).unwrap() + has_permissions = self.has_permission( + ActionObjectREAD(uid=uid, credentials=credentials) + ) - # if its a twin with a mock anyone can have this + if has_permissions: if isinstance(obj, TwinObject): - return obj.mock.syft_point_to(server_uid) + return obj.private.syft_point_to(server_uid) + return obj.syft_point_to(server_uid) # type: ignore - # finally worst case you get ActionDataEmpty so you can still trace - return obj.as_empty().syft_point_to(server_uid) - # TODO: Check if this can be removed - except Exception as e: - raise SyftException(public_message=str(e)) + # if its a twin with a mock anyone can have this + if isinstance(obj, TwinObject): + return obj.mock.syft_point_to(server_uid) - def exists(self, uid: UID) -> bool: - return uid.id in self.data # We only need the UID from LineageID or UID + # finally worst case you get ActionDataEmpty so you can still trace + return obj.as_empty().syft_point_to(server_uid) # type: ignore @as_result(SyftException, StashException) - def set( + def set_or_update( # type: ignore self, uid: UID, credentials: SyftVerifyKey, @@ -170,284 +88,47 @@ def set( ) -> UID: uid = uid.id # We only need the UID from LineageID or UID - # if you set something you need WRITE permission - write_permission = ActionObjectWRITE(uid=uid, credentials=credentials) - can_write = self.has_permission(write_permission) - - if not self.exists(uid=uid): - # attempt to claim it for writing + if self.exists(credentials=credentials, uid=uid): + permissions: list[ActionObjectPermission] = [] if has_result_read_permission: - ownership_result = self.take_ownership(uid=uid, credentials=credentials) - can_write = True if ownership_result.is_ok() else False + permissions.append(ActionObjectREAD(uid=uid, credentials=credentials)) else: - # root takes owneship, but you can still write - ownership_result = self.take_ownership( - uid=uid, credentials=self.root_verify_key + permissions.extend( + [ + ActionObjectWRITE(uid=uid, credentials=credentials), + ActionObjectEXECUTE(uid=uid, credentials=credentials), + ] ) - can_write = True if ownership_result.is_ok() else False - - if not can_write: - raise SyftException(public_message=f"Permission: {write_permission} denied") - - self.data[uid] = syft_object - if uid not in self.permissions: - # create default permissions - self.permissions[uid] = set() - if has_result_read_permission: - self.add_permission(ActionObjectREAD(uid=uid, credentials=credentials)) - else: - self.add_permissions( - [ - ActionObjectWRITE(uid=uid, credentials=credentials), - ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] - ) - - if uid not in self.storage_permissions: - # create default storage permissions - self.storage_permissions[uid] = set() - if add_storage_permission: - self.add_storage_permission( - StoragePermission(uid=uid, server_uid=self.server_uid) - ) - - return uid - - @as_result(SyftException) - def take_ownership(self, uid: UID, credentials: SyftVerifyKey) -> bool: - uid = uid.id # We only need the UID from LineageID or UID - - # first person using this UID can claim ownership - if uid in self.permissions or uid in self.data: - raise SyftException(public_message=f"Object {uid} already owned") - - self.add_permissions( - [ - ActionObjectOWNER(uid=uid, credentials=credentials), - ActionObjectWRITE(uid=uid, credentials=credentials), - ActionObjectREAD(uid=uid, credentials=credentials), - ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] - ) - - return True - - @as_result(StashException) - def delete(self, uid: UID, credentials: SyftVerifyKey) -> UID: - uid = uid.id # We only need the UID from LineageID or UID - - # if you delete something you need OWNER permission - # is it bad to evict a key and have someone else reuse it? - # perhaps we should keep permissions but no data? - owner_permission = ActionObjectOWNER(uid=uid, credentials=credentials) - - if not self.has_permission(owner_permission): - raise StashException( - public_message=f"Permission: {owner_permission} denied" - ) - - if uid in self.data: - del self.data[uid] - if uid in self.permissions: - del self.permissions[uid] - - return uid - - def has_permission(self, permission: ActionObjectPermission) -> bool: - if not isinstance(permission.permission, ActionPermission): - # If we reached this point, it's a malformed object error, let it bubble up - raise TypeError(f"ObjectPermission type: {permission.permission} not valid") - - if ( - permission.credentials is not None - and self.root_verify_key.verify == permission.credentials.verify - ): - return True - - if self.__user_stash is not None: - # relative - from ...service.user.user_roles import ServiceRole - - res = self.__user_stash.get_by_verify_key( - credentials=permission.credentials, - verify_key=permission.credentials, - ) - - if ( - res.is_ok() - and (user := res.ok()) is not None - and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - ): - return True - - if ( - permission.uid in self.permissions - and permission.permission_string in self.permissions[permission.uid] - ): - return True - - # 🟡 TODO 14: add ALL_READ, ALL_EXECUTE etc - if permission.permission == ActionPermission.OWNER: - pass - elif permission.permission == ActionPermission.READ: - pass - elif permission.permission == ActionPermission.WRITE: - pass - elif permission.permission == ActionPermission.EXECUTE: - pass - - return False - - def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: - return all(self.has_permission(p) for p in permissions) - - def add_permission(self, permission: ActionObjectPermission) -> None: - permissions = self.permissions[permission.uid] - permissions.add(permission.permission_string) - self.permissions[permission.uid] = permissions - - def remove_permission(self, permission: ActionObjectPermission) -> None: - permissions = self.permissions[permission.uid] - permissions.remove(permission.permission_string) - self.permissions[permission.uid] = permissions - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - for permission in permissions: - self.add_permission(permission) - - @as_result(ObjectCRUDPermissionException) - def _get_permissions_for_uid(self, uid: UID) -> set[str]: - if uid in self.permissions: - return self.permissions[uid] - raise ObjectCRUDPermissionException( - public_message=f"No permissions found for uid: {uid}" - ) - - @as_result(SyftException) - def get_all_permissions(self) -> dict[UID, set[str]]: - return dict(self.permissions.items()) - - def add_storage_permission(self, permission: StoragePermission) -> None: - permissions = self.storage_permissions[permission.uid] - permissions.add(permission.server_uid) - self.storage_permissions[permission.uid] = permissions - - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: - for permission in permissions: - self.add_storage_permission(permission) - - def remove_storage_permission(self, permission: StoragePermission) -> None: - permissions = self.storage_permissions[permission.uid] - permissions.remove(permission.server_uid) - self.storage_permissions[permission.uid] = permissions - - def has_storage_permission(self, permission: StoragePermission | UID) -> bool: - if isinstance(permission, UID): - permission = StoragePermission(uid=permission, server_uid=self.server_uid) - - if permission.uid in self.storage_permissions: - return permission.server_uid in self.storage_permissions[permission.uid] - - return False - - @as_result(ObjectCRUDPermissionException) - def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]: - if uid in self.storage_permissions: - return self.storage_permissions[uid] - raise ObjectCRUDPermissionException(f"No storage permissions found for {uid}") - - @as_result(SyftException) - def get_all_storage_permissions(self) -> dict[UID, set[UID]]: - return dict(self.storage_permissions.items()) - - def _all( - self, - credentials: SyftVerifyKey, - has_permission: bool | None = False, - ) -> list[SyftObject]: - # this checks permissions - res = [self.get(uid, credentials, has_permission) for uid in self.data.keys()] - return [x.ok() for x in res if x.is_ok()] # type: ignore - - @as_result(ObjectCRUDPermissionException) - def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey) -> bool: - has_root_permission = credentials == self.root_verify_key - - if not has_root_permission: - raise ObjectCRUDPermissionException( - public_message="You don't have permissions to migrate data." - ) - - for key, value in self.data.items(): - try: - if value.__canonical_name__ != to_klass.__canonical_name__: - continue - migrated_value = value.migrate_to(to_klass.__version__) - except Exception as e: - raise SyftException.from_exception( - e, - public_message=f"Failed to migrate data to {to_klass} for qk: {key}", + storage_permission = [] + if add_storage_permission: + storage_permission.append( + StoragePermission(uid=uid, server_uid=self.server_uid) ) - self.set( - uid=key, + self.update( credentials=credentials, - syft_object=migrated_value, + obj=syft_object, ).unwrap() + self.add_permissions(permissions).unwrap() + self.add_storage_permissions(storage_permission).unwrap() + return uid - return True - - -@serializable(canonical_name="DictActionStore", version=1) -class DictActionStore(KeyValueActionStore): - """Dictionary-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - def __init__( - self, - server_uid: UID, - store_config: StoreConfig | None = None, - root_verify_key: SyftVerifyKey | None = None, - document_store: DocumentStore | None = None, - ) -> None: - store_config = store_config if store_config is not None else DictStoreConfig() - super().__init__( - server_uid=server_uid, - store_config=store_config, - root_verify_key=root_verify_key, - document_store=document_store, + owner_credentials = ( + credentials if has_result_read_permission else self.root_verify_key ) + # if not has_result_read_permission + # root takes owneship, but you can still write and execute + super().set( + credentials=owner_credentials, + obj=syft_object, + add_permissions=[ + ActionObjectWRITE(uid=uid, credentials=credentials), + ActionObjectEXECUTE(uid=uid, credentials=credentials), + ], + add_storage_permission=add_storage_permission, + ).unwrap() + return uid -@serializable(canonical_name="SQLiteActionStore", version=1) -class SQLiteActionStore(KeyValueActionStore): - """SQLite-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - SQLite specific configuration, including connection settings or client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - pass - - -@serializable(canonical_name="MongoActionStore", version=1) -class MongoActionStore(KeyValueActionStore): - """Mongo-Based Action store. - - Parameters: - store_config: StoreConfig - Mongo specific configuration. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - pass + def set(self, *args, **kwargs): # type: ignore + raise Exception("Use `ActionObjectStash.set_or_update` instead.") diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 44567e8d8ef..1cebbd1e0fc 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -575,7 +575,7 @@ def exec_code( api_service = context.server.get_service("apiservice") api_service.stash.upsert( - context.server.services.user.admin_verify_key(), self + context.server.services.user.root_verify_key, self ).unwrap() print = original_print # type: ignore @@ -650,7 +650,7 @@ def code_string(context: TransformContext) -> TransformContext: ) context.server = cast(AbstractServer, context.server) - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key # If endpoint exists **AND** (has visible access **OR** the user is admin) if endpoint_type is not None and ( diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index a8c443a6271..8df26f0ac15 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -10,7 +10,7 @@ from ...serde.serializable import serializable from ...service.action.action_endpoint import CustomEndpointActionObject from ...service.action.action_object import ActionObject -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -37,11 +37,9 @@ @serializable(canonical_name="APIService", version=1) class APIService(AbstractService): - store: DocumentStore stash: TwinAPIEndpointStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = TwinAPIEndpointStash(store=store) @service_method( @@ -263,7 +261,7 @@ def api_endpoints( context: AuthedServiceContext, ) -> list[TwinAPIEndpointView]: """Retrieves a list of available API endpoints view available to the user.""" - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key all_api_endpoints = self.stash.get_all(admin_key).unwrap() api_endpoint_view = [ @@ -587,7 +585,7 @@ def execute_server_side_endpoint_mock_by_id( def get_endpoint_by_uid( self, context: AuthedServiceContext, uid: UID ) -> TwinAPIEndpoint: - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key return self.stash.get_by_uid(admin_key, uid).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 3a610c23cef..0c0c6f73020 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -1,11 +1,7 @@ -# stdlib - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -14,24 +10,21 @@ MISSING_PATH_STRING = "Endpoint path: {path} does not exist." -@serializable(canonical_name="TwinAPIEndpointStash", version=1) -class TwinAPIEndpointStash(NewBaseUIDStoreStash): - object_type = TwinAPIEndpoint - settings: PartitionSettings = PartitionSettings( - name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="TwinAPIEndpointSQLStash", version=1) +class TwinAPIEndpointStash(ObjectStash[TwinAPIEndpoint]): @as_result(StashException, NotFoundException) def get_by_path(self, credentials: SyftVerifyKey, path: str) -> TwinAPIEndpoint: - endpoints = self.get_all(credentials=credentials).unwrap() - for endpoint in endpoints: - if endpoint.path == path: - return endpoint + # TODO standardize by returning None if endpoint doesnt exist. + res = self.get_one( + credentials=credentials, + filters={"path": path}, + ) - raise NotFoundException(public_message=MISSING_PATH_STRING.format(path=path)) + if res.is_err(): + raise NotFoundException( + public_message=MISSING_PATH_STRING.format(path=path) + ) + return res.unwrap() @as_result(StashException) def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool: @@ -49,11 +42,9 @@ def upsert( has_permission: bool = False, ) -> TwinAPIEndpoint: """Upsert an endpoint.""" - path_exists = self.path_exists( - credentials=credentials, path=endpoint.path - ).unwrap() + exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap() - if path_exists: + if exists: super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap() return ( diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index debb838c4b1..93110f72d74 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.result import as_result from ...util.util import str_to_bool @@ -24,8 +24,8 @@ class AttestationService(AbstractService): """This service is responsible for getting all sorts of attestations for any client.""" - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass @as_result(SyftException) def perform_request( diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py index 25d6042bd09..76abe869ae6 100644 --- a/packages/syft/src/syft/service/blob_storage/remote_profile.py +++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py @@ -1,10 +1,10 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject +from ..service import AbstractService @serializable() @@ -24,12 +24,14 @@ class AzureRemoteProfile(RemoteProfile): container_name: str -@serializable(canonical_name="RemoteProfileStash", version=1) -class RemoteProfileStash(NewBaseUIDStoreStash): - object_type = RemoteProfile - settings: PartitionSettings = PartitionSettings( - name=RemoteProfile.__canonical_name__, object_type=RemoteProfile - ) +@serializable(canonical_name="RemoteProfileSQLStash", version=1) +class RemoteProfileStash(ObjectStash[RemoteProfile]): + pass - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) + +@serializable(canonical_name="RemoteProfileService", version=1) +class RemoteProfileService(AbstractService): + stash: RemoteProfileStash + + def __init__(self, store: DBManager) -> None: + self.stash = RemoteProfileStash(store=store) diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index c4bc955e29d..055e4d946e4 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,8 +11,7 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.document_store import DocumentStore -from ...store.document_store import UIDPartitionKey +from ...store.db.db import DBManager from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry @@ -38,12 +37,10 @@ @serializable(canonical_name="BlobStorageService", version=1) class BlobStorageService(AbstractService): - store: DocumentStore stash: BlobStorageStash remote_profile_stash: RemoteProfileStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = BlobStorageStash(store=store) self.remote_profile_stash = RemoteProfileStash(store=store) @@ -139,8 +136,8 @@ def mount_azure( def get_files_from_bucket( self, context: AuthedServiceContext, bucket_name: str ) -> list: - bse_list = self.stash.find_all( - context.credentials, bucket_name=bucket_name + bse_list = self.stash.get_all( + context.credentials, filters={"bucket_name": bucket_name} ).unwrap() blob_files = [] @@ -323,8 +320,8 @@ def delete(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess: public_message=f"Failed to delete blob file with id '{uid}'. Error: {e}" ) - self.stash.delete( - context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True + self.stash.delete_by_uid( + context.credentials, uid, has_permission=True ).unwrap() except Exception as e: raise SyftException( diff --git a/packages/syft/src/syft/service/blob_storage/stash.py b/packages/syft/src/syft/service/blob_storage/stash.py index 0ab4d7a8aa5..9cb002b7eb9 100644 --- a/packages/syft/src/syft/service/blob_storage/stash.py +++ b/packages/syft/src/syft/service/blob_storage/stash.py @@ -1,17 +1,9 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings +from ...store.db.stash import ObjectStash from ...types.blob_storage import BlobStorageEntry -@serializable(canonical_name="BlobStorageStash", version=1) -class BlobStorageStash(NewBaseUIDStoreStash): - object_type = BlobStorageEntry - settings: PartitionSettings = PartitionSettings( - name=BlobStorageEntry.__canonical_name__, object_type=BlobStorageEntry - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="BlobStorageSQLStash", version=1) +class BlobStorageStash(ObjectStash[BlobStorageEntry]): + pass diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 1ffb70ebb6f..d6c1a56e801 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -4,14 +4,9 @@ # relative from ...serde.serializable import serializable -from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey -from ...store.document_store_errors import StashException -from ...types.result import as_result from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -23,35 +18,19 @@ from .user_code import UserCodeStatusCollection -@serializable(canonical_name="StatusStash", version=1) -class StatusStash(NewBaseUIDStoreStash): - object_type = UserCodeStatusCollection +@serializable(canonical_name="StatusSQLStash", version=1) +class StatusStash(ObjectStash[UserCodeStatusCollection]): settings: PartitionSettings = PartitionSettings( name=UserCodeStatusCollection.__canonical_name__, object_type=UserCodeStatusCollection, ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store) - self.store = store - self.settings = self.settings - self._object_type = self.object_type - - @as_result(StashException) - def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID - ) -> UserCodeStatusCollection: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @serializable(canonical_name="UserCodeStatusService", version=1) class UserCodeStatusService(AbstractService): - store: DocumentStore stash: StatusStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: DBManager): self.stash = StatusStash(store=store) @service_method(path="code_status.create", name="create", roles=ADMIN_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 0921ea9e704..81e586d5c02 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -46,7 +46,6 @@ from ...serde.signature import signature_remove_context from ...serde.signature import signature_remove_self from ...server.credentials import SyftVerifyKey -from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.dicttuple import DictTuple @@ -107,11 +106,6 @@ # relative from ...service.sync.diff_state import AttrDiff -UserVerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey) -CodeHashPartitionKey = PartitionKey(key="code_hash", type_=str) -ServiceFuncNamePartitionKey = PartitionKey(key="service_func_name", type_=str) -SubmitTimePartitionKey = PartitionKey(key="submit_time", type_=DateTime) - PyCodeObject = Any diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 32863990a6a..79490223f45 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -60,11 +60,9 @@ class IsExecutionAllowedEnum(str, Enum): @serializable(canonical_name="UserCodeService", version=1) class UserCodeService(AbstractService): - store: DocumentStore stash: UserCodeStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserCodeStash(store=store) @service_method( @@ -77,11 +75,12 @@ def submit( self, context: AuthedServiceContext, code: SubmitUserCode ) -> SyftSuccess: """Add User Code""" - user_code = self._submit(context, code, exists_ok=False) + user_code = self._submit(context, code, exists_ok=False).unwrap() return SyftSuccess( message="User Code Submitted", require_api_update=True, value=user_code ) + @as_result(SyftException) def _submit( self, context: AuthedServiceContext, @@ -107,17 +106,17 @@ def _submit( context.credentials, code_hash=get_code_hash(submit_code.code, context.credentials), ).unwrap() - - if not exists_ok: + # no exception, code exists + if exists_ok: + return existing_code + else: raise SyftException( - public_message="The code to be submitted already exists" + public_message="UserCode with this code already exists" ) - return existing_code except NotFoundException: pass code = submit_code.to(UserCode, context=context) - result = self._post_user_code_transform_ops(context, code) if result.is_err(): @@ -281,17 +280,11 @@ def _get_or_submit_user_code( - If the code is a SubmitUserCode and the code hash does not exist, submit the code """ if isinstance(code, UserCode): - # Get existing UserCode - try: - return self.stash.get_by_uid(context.credentials, code.id).unwrap() - except NotFoundException as exc: - raise NotFoundException.from_exception( - exc, public_message=f"UserCode {code.id} not found on this server" - ) + return self.stash.get_by_uid(context.credentials, code.id).unwrap() else: # code: SubmitUserCode # Submit new UserCode, or get existing UserCode with the same code hash # TODO: Why is this tagged as unreachable? - return self._submit(context, code, exists_ok=True) # type: ignore[unreachable] + return self._submit(context, code, exists_ok=True).unwrap() # type: ignore[unreachable] @service_method( path="code.request_code_execution", diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 308de4d28bf..232342bd8d5 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -1,47 +1,32 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from .user_code import CodeHashPartitionKey -from .user_code import ServiceFuncNamePartitionKey -from .user_code import SubmitTimePartitionKey from .user_code import UserCode -from .user_code import UserVerifyKeyPartitionKey -@serializable(canonical_name="UserCodeStash", version=1) -class UserCodeStash(NewBaseUIDStoreStash): - object_type = UserCode +@serializable(canonical_name="UserCodeSQLStash", version=1) +class UserCodeStash(ObjectStash[UserCode]): settings: PartitionSettings = PartitionSettings( name=UserCode.__canonical_name__, object_type=UserCode ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - @as_result(StashException, NotFoundException) - def get_all_by_user_verify_key( - self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey - ) -> list[UserCode]: - qks = QueryKeys(qks=[UserVerifyKeyPartitionKey.with_obj(user_verify_key)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @as_result(StashException, NotFoundException) def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode: - qks = QueryKeys(qks=[CodeHashPartitionKey.with_obj(code_hash)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"code_hash": code_hash}, + ).unwrap() @as_result(StashException) def get_by_service_func_name( self, credentials: SyftVerifyKey, service_func_name: str ) -> list[UserCode]: - qks = QueryKeys(qks=[ServiceFuncNamePartitionKey.with_obj(service_func_name)]) - return self.query_all( - credentials=credentials, qks=qks, order_by=SubmitTimePartitionKey + return self.get_all( + credentials=credentials, + filters={"service_func_name": service_func_name}, ).unwrap() diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index a3d06bdb4fa..ff2967f169b 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...types.uid import UID from ..code.user_code import SubmitUserCode @@ -24,11 +24,9 @@ @serializable(canonical_name="CodeHistoryService", version=1) class CodeHistoryService(AbstractService): - store: DocumentStore stash: CodeHistoryStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = CodeHistoryStash(store=store) @service_method( @@ -46,7 +44,7 @@ def submit_version( if isinstance(code, SubmitUserCode): code = context.server.services.user_code._submit( context=context, submit_code=code - ) + ).unwrap() try: code_history = self.stash.get_by_service_func_name_and_verify_key( @@ -189,11 +187,13 @@ def get_by_func_name_and_user_email( ) -> list[CodeHistory]: user_verify_key = context.server.services.user.user_verify_key(user_email) - kwargs = { + filters = { "id": user_id, "email": user_email, "verify_key": user_verify_key, "service_func_name": service_func_name, } - return self.stash.find_all(credentials=context.credentials, **kwargs).unwrap() + return self.stash.get_all( + credentials=context.credentials, filters=filters + ).unwrap() diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py index d03358feb83..69dfd272717 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -1,56 +1,43 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result from .code_history import CodeHistory -NamePartitionKey = PartitionKey(key="service_func_name", type_=str) -VerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey) - - -@serializable(canonical_name="CodeHistoryStash", version=1) -class CodeHistoryStash(NewBaseUIDStoreStash): - object_type = CodeHistory - settings: PartitionSettings = PartitionSettings( - name=CodeHistory.__canonical_name__, object_type=CodeHistory - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="CodeHistoryStashSQL", version=1) +class CodeHistoryStash(ObjectStash[CodeHistory]): @as_result(StashException) def get_by_service_func_name_and_verify_key( self, credentials: SyftVerifyKey, service_func_name: str, user_verify_key: SyftVerifyKey, - ) -> list[CodeHistory]: - qks = QueryKeys( - qks=[ - NamePartitionKey.with_obj(service_func_name), - VerifyKeyPartitionKey.with_obj(user_verify_key), - ] - ) - return self.query_one(credentials=credentials, qks=qks).unwrap() + ) -> CodeHistory: + return self.get_one( + credentials=credentials, + filters={ + "user_verify_key": user_verify_key, + "service_func_name": service_func_name, + }, + ).unwrap() @as_result(StashException) def get_by_service_func_name( self, credentials: SyftVerifyKey, service_func_name: str ) -> list[CodeHistory]: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(service_func_name)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"service_func_name": service_func_name}, + ).unwrap() @as_result(StashException) def get_by_verify_key( self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey ) -> list[CodeHistory]: - if isinstance(user_verify_key, str): - user_verify_key = SyftVerifyKey.from_string(user_verify_key) - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(user_verify_key)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"user_verify_key": user_verify_key}, + ).unwrap() diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index c54e8bc67ea..e7e482a4337 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,10 +3,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -14,44 +12,35 @@ from ..service import AbstractService from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE -from .data_subject_member import ChildPartitionKey from .data_subject_member import DataSubjectMemberRelationship -from .data_subject_member import ParentPartitionKey -@serializable(canonical_name="DataSubjectMemberStash", version=1) -class DataSubjectMemberStash(NewBaseUIDStoreStash): - object_type = DataSubjectMemberRelationship - settings: PartitionSettings = PartitionSettings( - name=DataSubjectMemberRelationship.__canonical_name__, - object_type=DataSubjectMemberRelationship, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="DataSubjectMemberSQLStash", version=1) +class DataSubjectMemberStash(ObjectStash[DataSubjectMemberRelationship]): @as_result(StashException) def get_all_for_parent( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - qks = QueryKeys(qks=[ParentPartitionKey.with_obj(name)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"parent": name}, + ).unwrap() @as_result(StashException) def get_all_for_child( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - qks = QueryKeys(qks=[ChildPartitionKey.with_obj(name)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"child": name}, + ).unwrap() @serializable(canonical_name="DataSubjectMemberService", version=1) class DataSubjectMemberService(AbstractService): - store: DocumentStore stash: DataSubjectMemberStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DataSubjectMemberStash(store=store) def add( diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index b8d5e6b8528..ecde100edf5 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,10 +5,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -19,43 +17,23 @@ from ..service import service_method from .data_subject import DataSubject from .data_subject import DataSubjectCreate -from .data_subject import NamePartitionKey -@serializable(canonical_name="DataSubjectStash", version=1) -class DataSubjectStash(NewBaseUIDStoreStash): - object_type = DataSubject - settings: PartitionSettings = PartitionSettings( - name=DataSubject.__canonical_name__, object_type=DataSubject - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="DataSubjectSQLStash", version=1) +class DataSubjectStash(ObjectStash[DataSubject]): @as_result(StashException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) - return self.query_one(credentials, qks=qks).unwrap() - - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - data_subject: DataSubject, - has_permission: bool = False, - ) -> DataSubject: - res = self.check_type(data_subject, DataSubject).unwrap() - # we dont use and_then logic here as it is hard because of the order of the arguments - return super().update(credentials=credentials, obj=res).unwrap() + return self.get_one( + credentials=credentials, + filters={"name": name}, + ).unwrap() @serializable(canonical_name="DataSubjectService", version=1) class DataSubjectService(AbstractService): - store: DocumentStore stash: DataSubjectStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DataSubjectStash(store=store) @service_method(path="data_subject.add", name="add_data_subject") diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 6c4a91e06c2..fbc5318be57 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -7,6 +7,7 @@ from typing import Any # third party +from IPython.display import display import markdown import pandas as pd from pydantic import ConfigDict @@ -291,8 +292,8 @@ def _private_data(self) -> Any: def data(self) -> Any: try: return self._private_data().unwrap() - except SyftException as e: - print(e) + except SyftException: + display(SyftError(message="You have no access to the private data")) return None @@ -533,6 +534,7 @@ def _repr_html_(self) -> Any: {self.assets._repr_html_()} """ + @property def action_ids(self) -> list[UID]: return [asset.action_id for asset in self.asset_list if asset.action_id] diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 43cbfacb117..a25a49ee60d 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.dicttuple import DictTuple from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission @@ -25,7 +25,6 @@ from .dataset import CreateDataset from .dataset import Dataset from .dataset import DatasetPageView -from .dataset import DatasetUpdate from .dataset_stash import DatasetStash logger = logging.getLogger(__name__) @@ -69,11 +68,9 @@ def _paginate_dataset_collection( @serializable(canonical_name="DatasetService", version=1) class DatasetService(AbstractService): - store: DocumentStore stash: DatasetStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DatasetStash(store=store) @service_method( @@ -117,13 +114,11 @@ def get_all( page_index: int | None = 0, ) -> DatasetPageView | DictTuple[str, Dataset]: """Get a Dataset""" - datasets = self.stash.get_all(context.credentials).unwrap() + datasets = self.stash.get_all_active(context.credentials).unwrap() for dataset in datasets: if context.server is not None: dataset.server_uid = context.server.id - if dataset.to_be_deleted: - datasets.remove(dataset) return _paginate_dataset_collection( datasets=datasets, page_size=page_size, page_index=page_index @@ -234,11 +229,9 @@ def delete( return_msg.append(f"Asset with id '{asset.id}' successfully deleted.") # soft delete the dataset object from the store - dataset_update = DatasetUpdate( - id=uid, name=f"_deleted_{dataset.name}_{uid}", to_be_deleted=True - ) - self.stash.update(context.credentials, dataset_update).unwrap() - + dataset.name = f"_deleted_{dataset.name}_{uid}" + dataset.to_be_deleted = True + self.stash.update(context.credentials, dataset).unwrap() return_msg.append(f"Dataset with id '{uid}' successfully deleted.") return SyftSuccess(message="\n".join(return_msg)) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 19fc33c5906..aee2a280372 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -1,64 +1,56 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID from .dataset import Dataset -from .dataset import DatasetUpdate -NamePartitionKey = PartitionKey(key="name", type_=str) -ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) - - -@serializable(canonical_name="DatasetStash", version=1) -class DatasetStash(NewBaseUIDStoreStash): - object_type = Dataset - settings: PartitionSettings = PartitionSettings( - name=Dataset.__canonical_name__, object_type=Dataset - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="DatasetStashSQL", version=1) +class DatasetStash(ObjectStash[Dataset]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one(credentials=credentials, filters={"name": name}).unwrap() @as_result(StashException) def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]: - qks = QueryKeys(qks=[ActionIDsPartitionKey.with_obj(uid)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() - - @as_result(StashException) - def get_all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool = False, - ) -> list: - result = super().get_all(credentials, order_by, has_permission).unwrap() - filtered_datasets = [dataset for dataset in result if not dataset.to_be_deleted] - return filtered_datasets + return self.get_all_active( + credentials=credentials, + filters={"action_ids__contains": uid}, + ).unwrap() - # FIX: This shouldn't be the update method, it just marks the dataset for deletion @as_result(StashException) - def update( + def get_all_active( self, credentials: SyftVerifyKey, - obj: DatasetUpdate, has_permission: bool = False, - ) -> Dataset: - _obj = self.check_type(obj, DatasetUpdate).unwrap() - # FIX: This method needs a revamp - qk = self.partition.store_query_key(obj) - return self.partition.update( - credentials=credentials, qk=qk, obj=_obj, has_permission=has_permission - ).unwrap() + order_by: str | None = None, + sort_order: str | None = None, + limit: int | None = None, + offset: int | None = None, + filters: dict | None = None, + ) -> list[Dataset]: + # TODO standardize soft delete and move to ObjectStash.get_all + default_filters = {"to_be_deleted": False} + filters = filters or {} + filters.update(default_filters) + + if offset is None: + offset = 0 + + return ( + super() + .get_all( + credentials=credentials, + filters=filters, + has_permission=has_permission, + order_by=order_by, + sort_order=sort_order, + limit=limit, + offset=offset, + ) + .unwrap() + ) diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index 2f88c60e123..064c8806f91 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -2,13 +2,11 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ..service import AbstractService @serializable(canonical_name="EnclaveService", version=1) class EnclaveService(AbstractService): - store: DocumentStore - - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index fbbdfd7d856..5cff02ecb74 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..action.action_object import ActionObject @@ -40,11 +40,9 @@ def wait_until(predicate: Callable[[], bool], timeout: int = 10) -> SyftSuccess: @serializable(canonical_name="JobService", version=1) class JobService(AbstractService): - store: DocumentStore stash: JobStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = JobStash(store=store) @service_method( @@ -132,7 +130,7 @@ def restart(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess: context.credentials, queue_item ).unwrap() - context.server.job_stash.set(context.credentials, job).unwrap() + self.stash.set(context.credentials, job).unwrap() context.server.services.log.restart(context, job.log_id) return SyftSuccess(message="Great Success!") @@ -218,7 +216,7 @@ def add_read_permission_job_for_code_owner( job.id, ActionPermission.READ, user_code.user_verify_key ) # TODO: make add_permission wrappable - return self.stash.add_permission(permission=permission) + return self.stash.add_permission(permission=permission).unwrap() @service_method( path="job.add_read_permission_log_for_code_owner", @@ -232,7 +230,7 @@ def add_read_permission_log_for_code_owner( ActionObjectPermission( log_id, ActionPermission.READ, user_code.user_verify_key ) - ) + ).unwrap() @service_method( path="job.create_job_for_user_code_id", diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index fc83b675503..358834470fb 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -1,5 +1,4 @@ # stdlib -from collections.abc import Callable from datetime import datetime from datetime import timedelta from datetime import timezone @@ -20,26 +19,19 @@ from ...server.credentials import SyftVerifyKey from ...service.context import AuthedServiceContext from ...service.worker.worker_pool import SyftWorker -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException -from ...store.document_store_errors import TooManyItemsFoundException from ...types.datetime import DateTime from ...types.datetime import format_timedelta from ...types.errors import SyftException from ...types.result import Err from ...types.result import as_result -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject -from ...types.transforms import make_set_default from ...types.uid import UID from ...util.markdown import as_markdown_code from ...util.util import prompt_warning_message @@ -125,6 +117,7 @@ class Job(SyncableSyftObject): "user_code_id", "result_id", ] + __repr_attrs__ = [ "id", "result", @@ -133,6 +126,7 @@ class Job(SyncableSyftObject): "creation_time", "user_code_name", ] + __exclude_sync_diff_attrs__ = ["action", "server_uid"] __table_coll_widths__ = [ "min-content", @@ -740,16 +734,12 @@ def from_job( return info -@serializable(canonical_name="JobStash", version=1) -class JobStash(NewBaseUIDStoreStash): - object_type = Job +@serializable(canonical_name="JobStashSQL", version=1) +class JobStash(ObjectStash[Job]): settings: PartitionSettings = PartitionSettings( name=Job.__canonical_name__, object_type=Job ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - @as_result(StashException) def set_result( self, @@ -764,98 +754,40 @@ def set_result( and item.result.syft_blob_storage_entry_id is not None ): item.result._clear_cache() - return ( - super() - .update(credentials, item, add_permissions) - .unwrap(public_message="Failed to update") - ) - - @as_result(StashException) - def get_by_result_id( - self, - credentials: SyftVerifyKey, - result_id: UID, - ) -> Job: - qks = QueryKeys( - qks=[PartitionKey(key="result_id", type_=UID).with_obj(result_id)] - ) - res = self.query_all(credentials=credentials, qks=qks).unwrap() - - if len(res) == 0: - raise NotFoundException() - elif len(res) > 1: - raise TooManyItemsFoundException() - else: - return res[0] - - @as_result(StashException) - def get_by_parent_id(self, credentials: SyftVerifyKey, uid: UID) -> list[Job]: - qks = QueryKeys( - qks=[PartitionKey(key="parent_job_id", type_=UID).with_obj(uid)] + return self.update(credentials, item, add_permissions).unwrap( + public_message="Failed to update" ) - return self.query_all(credentials=credentials, qks=qks).unwrap() - - @as_result(StashException) - def delete_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> bool: # type: ignore[override] - qk = UIDPartitionKey.with_obj(uid) - return super().delete(credentials=credentials, qk=qk).unwrap() - @as_result(StashException) def get_active(self, credentials: SyftVerifyKey) -> list[Job]: - qks = QueryKeys( - qks=[ - PartitionKey(key="status", type_=JobStatus).with_obj( - JobStatus.PROCESSING - ) - ] - ) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"status": JobStatus.CREATED}, + ).unwrap() - @as_result(StashException) def get_by_worker(self, credentials: SyftVerifyKey, worker_id: str) -> list[Job]: - qks = QueryKeys( - qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)] - ) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"job_worker_id": worker_id}, + ).unwrap() @as_result(StashException) def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> list[Job]: - qks = QueryKeys( - qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)] - ) - return self.query_all(credentials=credentials, qks=qks).unwrap() - - -@serializable() -class JobV1(SyncableSyftObject): - __canonical_name__ = "JobItem" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - server_uid: UID - result: Any | None = None - resolved: bool = False - status: JobStatus = JobStatus.CREATED - log_id: UID | None = None - parent_job_id: UID | None = None - n_iters: int | None = 0 - current_iter: int | None = None - creation_time: str | None = Field( - default_factory=lambda: str(datetime.now(tz=timezone.utc)) - ) - action: Action | None = None - job_pid: int | None = None - job_worker_id: UID | None = None - updated_at: DateTime | None = None - user_code_id: UID | None = None - requested_by: UID | None = None - job_type: JobType = JobType.JOB + return self.get_all( + credentials=credentials, + filters={"user_code_id": user_code_id}, + ).unwrap() + @as_result(StashException) + def get_by_parent_id(self, credentials: SyftVerifyKey, uid: UID) -> list[Job]: + return self.get_all( + credentials=credentials, + filters={"parent_job_id": uid}, + ).unwrap() -@migrate(JobV1, Job) -def migrate_job_update_v1_current() -> list[Callable]: - return [ - make_set_default("endpoint", None), - ] + @as_result(StashException) + def get_by_result_id(self, credentials: SyftVerifyKey, uid: UID) -> Job: + return self.get_one( + credentials=credentials, filters={"result_id": uid} + ).unwrap() diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index d3529b0906f..d4b96a0deed 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -1,6 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.uid import UID from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext @@ -16,11 +16,9 @@ @serializable(canonical_name="LogService", version=1) class LogService(AbstractService): - store: DocumentStore stash: LogStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = LogStash(store=store) @service_method(path="log.add", name="add", roles=DATA_SCIENTIST_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index c4072bfcfa5..ef50b081c24 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -1,17 +1,9 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings +from ...store.db.stash import ObjectStash from .log import SyftLog @serializable(canonical_name="LogStash", version=1) -class LogStash(NewBaseUIDStoreStash): - object_type = SyftLog - settings: PartitionSettings = PartitionSettings( - name=SyftLog.__canonical_name__, object_type=SyftLog - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +class LogStash(ObjectStash[SyftLog]): + pass diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index 70453d9b084..b7b450b037b 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method @@ -12,8 +12,8 @@ @serializable(canonical_name="MetadataService", version=1) class MetadataService(AbstractService): - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass @service_method( path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 09848789559..b1d461e4e80 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,25 +1,25 @@ # stdlib from collections import defaultdict -from typing import cast # syft absolute import syft # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import StorePartition +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException from ...types.result import as_result from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry +from ...types.twin_object import TwinObject from ..action.action_object import Action from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import StoragePermission -from ..action.action_store import KeyValueActionStore +from ..action.action_store import ActionObjectStash from ..context import AuthedServiceContext from ..response import SyftSuccess from ..service import AbstractService @@ -34,11 +34,9 @@ @serializable(canonical_name="MigrationService", version=1) class MigrationService(AbstractService): - store: DocumentStore stash: SyftMigrationStateStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftMigrationStateStash(store=store) @service_method(path="migration", name="get_version") @@ -75,9 +73,7 @@ def register_migration_state( obj = SyftObjectMigrationState( current_version=current_version, canonical_name=canonical_name ) - return self.stash.set( - migration_state=obj, credentials=context.credentials - ).unwrap() + return self.stash.set(obj=obj, credentials=context.credentials).unwrap() @as_result(SyftException, NotFoundException) def _find_klasses_pending_for_migration( @@ -120,83 +116,37 @@ def get_all_store_metadata( return self._get_all_store_metadata( context, document_store_object_types=document_store_object_types, - include_action_store=include_action_store, ).unwrap() - @as_result(SyftException) - def _get_partition_from_type( - self, - context: AuthedServiceContext, - object_type: type[SyftObject], - ) -> KeyValueActionStore | StorePartition: - object_partition: KeyValueActionStore | StorePartition | None = None - if issubclass(object_type, ActionObject): - object_partition = cast(KeyValueActionStore, context.server.action_store) - else: - canonical_name = object_type.__canonical_name__ # type: ignore[unreachable] - object_partition = self.store.partitions.get(canonical_name) - - if object_partition is None: - raise SyftException( - public_message=f"Object partition not found for {object_type}" - ) # type: ignore - - return object_partition - - @as_result(SyftException) - def _get_store_metadata( - self, - context: AuthedServiceContext, - object_type: type[SyftObject], - ) -> StoreMetadata: - object_partition = self._get_partition_from_type(context, object_type).unwrap() - permissions = dict(object_partition.get_all_permissions().unwrap()) - storage_permissions = dict( - object_partition.get_all_storage_permissions().unwrap() - ) - return StoreMetadata( - object_type=object_type, - permissions=permissions, - storage_permissions=storage_permissions, - ) - @as_result(SyftException) def _get_all_store_metadata( self, context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, - include_action_store: bool = True, ) -> dict[type[SyftObject], StoreMetadata]: - if document_store_object_types is None: - document_store_object_types = self.store.get_partition_object_types() - + # metadata = permissions + storage permissions + stashes = context.server.services.stashes store_metadata = {} - for klass in document_store_object_types: - store_metadata[klass] = self._get_store_metadata(context, klass).unwrap() - if include_action_store: - store_metadata[ActionObject] = self._get_store_metadata( - context, ActionObject - ).unwrap() - return store_metadata + for klass, stash in stashes.items(): + if ( + document_store_object_types is not None + and klass not in document_store_object_types + ): + continue + store_metadata[klass] = StoreMetadata( + object_type=klass, + permissions=stash.get_all_permissions().unwrap(), + storage_permissions=stash.get_all_storage_permissions().unwrap(), + ) - @service_method( - path="migration.update_store_metadata", - name="update_store_metadata", - roles=ADMIN_ROLE_LEVEL, - ) - def update_store_metadata( - self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] - ) -> None: - return self._update_store_metadata(context, store_metadata).unwrap() + return store_metadata @as_result(SyftException) def _update_store_metadata_for_klass( self, context: AuthedServiceContext, metadata: StoreMetadata ) -> None: - object_partition = self._get_partition_from_type( - context, metadata.object_type - ).unwrap() + stash = self._search_stash_for_klass(context, metadata.object_type).unwrap() permissions = [ ActionObjectPermission.from_permission_string(uid, perm_str) for uid, perm_strs in metadata.permissions.items() @@ -209,8 +159,8 @@ def _update_store_metadata_for_klass( for server_uid in server_uids ] - object_partition.add_permissions(permissions) - object_partition.add_storage_permissions(storage_permissions) + stash.add_permissions(permissions, ignore_missing=True).unwrap() + stash.add_storage_permissions(storage_permissions, ignore_missing=True).unwrap() @as_result(SyftException) def _update_store_metadata( @@ -220,21 +170,6 @@ def _update_store_metadata( for metadata in store_metadata.values(): self._update_store_metadata_for_klass(context, metadata).unwrap() - @service_method( - path="migration.get_migration_objects", - name="get_migration_objects", - roles=ADMIN_ROLE_LEVEL, - ) - def get_migration_objects( - self, - context: AuthedServiceContext, - document_store_object_types: list[type[SyftObject]] | None = None, - get_all: bool = False, - ) -> dict: - return self._get_migration_objects( - context, document_store_object_types, get_all - ).unwrap() - @as_result(SyftException) def _get_migration_objects( self, @@ -243,7 +178,7 @@ def _get_migration_objects( get_all: bool = False, ) -> dict[type[SyftObject], list[SyftObject]]: if document_store_object_types is None: - document_store_object_types = self.store.get_partition_object_types() + document_store_object_types = list(context.server.services.stashes.keys()) if get_all: klasses_to_migrate = document_store_object_types @@ -255,14 +190,12 @@ def _get_migration_objects( result = defaultdict(list) for klass in klasses_to_migrate: - canonical_name = klass.__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) - if object_partition is None: + stash_or_err = self._search_stash_for_klass(context, klass) + if stash_or_err.is_err(): continue - objects = object_partition.all( - context.credentials, has_permission=True - ).unwrap() - for object in objects: + stash = stash_or_err.unwrap() + + for object in stash._data: actual_klass = type(object) use_klass = ( klass @@ -274,24 +207,33 @@ def _get_migration_objects( return dict(result) @as_result(SyftException) - def _search_partition_for_object( - self, context: AuthedServiceContext, obj: SyftObject - ) -> StorePartition: - klass = type(obj) + def _search_stash_for_klass( + self, context: AuthedServiceContext, klass: type[SyftObject] + ) -> ObjectStash: + if issubclass(klass, ActionObject | TwinObject | Action): + return context.server.services.action.stash + + stashes: dict[str, ObjectStash] = { # type: ignore + t.__canonical_name__: stash + for t, stash in context.server.services.stashes.items() + } + mro = klass.__mro__ class_index = 0 - object_partition = None + object_stash = None while len(mro) > class_index: - canonical_name = mro[class_index].__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) - if object_partition is not None: + try: + canonical_name = mro[class_index].__canonical_name__ + except AttributeError: + # Classes without cname dont have a stash + break + object_stash = stashes.get(canonical_name) + if object_stash is not None: break class_index += 1 - if object_partition is None: - raise SyftException( - public_message=f"Object partition not found for {klass}" - ) - return object_partition + if object_stash is None: + raise SyftException(public_message=f"Object stash not found for {klass}") + return object_stash @service_method( path="migration.create_migrated_objects", @@ -316,12 +258,11 @@ def _create_migrated_objects( ignore_existing: bool = True, ) -> SyftSuccess: for migrated_object in migrated_objects: - object_partition = self._search_partition_for_object( - context, migrated_object + stash = self._search_stash_for_klass( + context, type(migrated_object) ).unwrap() - # upsert the object - result = object_partition.set( + result = stash.set( context.credentials, obj=migrated_object, ) @@ -340,34 +281,20 @@ def _create_migrated_objects( result.unwrap() # this will raise the exception inside the wrapper return SyftSuccess(message="Created migrate objects!") - @service_method( - path="migration.update_migrated_objects", - name="update_migrated_objects", - roles=ADMIN_ROLE_LEVEL, - ) - def update_migrated_objects( - self, context: AuthedServiceContext, migrated_objects: list[SyftObject] - ) -> None: - self._update_migrated_objects(context, migrated_objects).unwrap() - @as_result(SyftException) def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - object_partition = self._search_partition_for_object( - context, migrated_object + stash = self._search_stash_for_klass( + context, type(migrated_object) ).unwrap() - qk = object_partition.settings.store_key.with_obj(migrated_object.id) - object_partition._update( + stash.update( context.credentials, - qk=qk, obj=migrated_object, - has_permission=True, - overwrite=True, - allow_missing_keys=True, ).unwrap() + return SyftSuccess(message="Updated migration objects!") @as_result(SyftException) @@ -405,33 +332,18 @@ def migrate_data( context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, ) -> SyftSuccess: - # Track all object type that need migration for document store - - # get all objects, keyed by type (because we might want to have different rules for different types) - # Q: will this be tricky with the protocol???? - # A: For now we will assume that the client will have the same version - - # Then, locally we write stuff that says - # for klass, objects in migration_dict.items(): - # for object in objects: - # if isinstance(object, X): - # do something custom - # else: - # migrated_value = object.migrate_to(klass.__version__, context) - # - # migrated_values = [SyftObject] - # client.migration.write_migrated_values(migrated_values) - migration_objects = self._get_migration_objects( context, document_store_object_types ).unwrap() migrated_objects = self._migrate_objects(context, migration_objects).unwrap() self._update_migrated_objects(context, migrated_objects).unwrap() + migration_actionobjects = self._get_migration_actionobjects(context).unwrap() migrated_actionobjects = self._migrate_objects( context, migration_actionobjects ).unwrap() self._update_migrated_actionobjects(context, migrated_actionobjects).unwrap() + return SyftSuccess(message="Data upgraded to the latest version") @service_method( @@ -449,7 +361,7 @@ def _get_migration_actionobjects( self, context: AuthedServiceContext, get_all: bool = False ) -> dict[type[SyftObject], list[SyftObject]]: # Track all object types from action store - action_object_types = [Action, ActionObject] + action_object_types = [Action, ActionObject, TwinObject] action_object_types.extend(ActionObject.__subclasses__()) klass_by_canonical_name: dict[str, type[SyftObject]] = { klass.__canonical_name__: klass for klass in action_object_types @@ -459,10 +371,8 @@ def _get_migration_actionobjects( context=context, object_types=action_object_types ).unwrap() result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list) - action_store = context.server.action_store - action_store_objects = action_store._all( - context.credentials, has_permission=True - ) + action_stash = context.server.services.action.stash + action_store_objects = action_stash.get_all(context.credentials).unwrap() for obj in action_store_objects: if get_all or type(obj) in action_object_pending_migration: @@ -470,26 +380,16 @@ def _get_migration_actionobjects( result_dict[klass].append(obj) # type: ignore return dict(result_dict) - @service_method( - path="migration.update_migrated_actionobjects", - name="update_migrated_actionobjects", - roles=ADMIN_ROLE_LEVEL, - ) - def update_migrated_actionobjects( - self, context: AuthedServiceContext, objects: list[SyftObject] - ) -> SyftSuccess: - self._update_migrated_actionobjects(context, objects).unwrap() - return SyftSuccess(message="succesfully migrated actionobjects") - @as_result(SyftException) def _update_migrated_actionobjects( self, context: AuthedServiceContext, objects: list[SyftObject] ) -> str: - # Track all object types from action store - action_store = context.server.action_store + action_store: ActionObjectStash = context.server.services.action.stash for obj in objects: - action_store.set( - uid=obj.id, credentials=context.credentials, syft_object=obj + action_store.set_or_update( + uid=obj.id, + credentials=context.credentials, + syft_object=obj, ).unwrap() return "success" diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index b7fa2bb7cbd..22363d867f2 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -15,10 +15,8 @@ from ...serde.serialize import _serialize from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from ...types.blob_storage import CreateBlobStorageEntry @@ -34,7 +32,6 @@ from ...types.transforms import make_set_default from ...types.uid import UID from ...util.util import prompt_warning_message -from ..action.action_permissions import ActionObjectPermission from ..response import SyftSuccess from ..worker.utils import DEFAULT_WORKER_POOL_NAME from ..worker.worker_image import SyftWorkerImage @@ -70,45 +67,16 @@ def supported_versions(self) -> list: KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) -@serializable(canonical_name="SyftMigrationStateStash", version=1) -class SyftMigrationStateStash(NewBaseStash): - object_type = SyftObjectMigrationState - settings: PartitionSettings = PartitionSettings( - name=SyftObjectMigrationState.__canonical_name__, - object_type=SyftObjectMigrationState, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - @as_result(SyftException) - def set( # type: ignore [override] - self, - credentials: SyftVerifyKey, - migration_state: SyftObjectMigrationState, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> SyftObjectMigrationState: - obj = self.check_type(migration_state, self.object_type).unwrap() - return ( - super() - .set( - credentials=credentials, - obj=obj, - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ignore_duplicates=ignore_duplicates, - ) - .unwrap() - ) - +@serializable(canonical_name="SyftMigrationStateSQLStash", version=1) +class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]): @as_result(SyftException, NotFoundException) def get_by_name( self, canonical_name: str, credentials: SyftVerifyKey ) -> SyftObjectMigrationState: - qks = KlassNamePartitionKey.with_obj(canonical_name) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"canonical_name": canonical_name}, + ).unwrap() @serializable() @@ -276,6 +244,14 @@ def get_items_by_canonical_name(self, canonical_name: str) -> list[SyftObject]: return v return [] + def get_metadata_by_canonical_name(self, canonical_name: str) -> StoreMetadata: + for k, v in self.metadata.items(): + if k.__canonical_name__ == canonical_name: + return v + return StoreMetadata( + object_type=SyftObject, permissions={}, storage_permissions={} + ) + def copy_without_workerpools(self) -> "MigrationData": items_to_exclude = [ WorkerPool.__canonical_name__, diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 428501fb92d..aa2264dbe10 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,11 +15,8 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -36,7 +33,6 @@ from ...util.util import prompt_warning_message from ...util.util import str_to_bool from ..context import AuthedServiceContext -from ..data_subject.data_subject import NamePartitionKey from ..metadata.server_metadata import ServerMetadata from ..request.request import Request from ..request.request import RequestStatus @@ -61,10 +57,6 @@ logger = logging.getLogger(__name__) -VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) -ServerTypePartitionKey = PartitionKey(key="server_type", type_=ServerType) -OrderByNamePartitionKey = PartitionKey(key="name", type_=str) - REVERSE_TUNNEL_ENABLED = "REVERSE_TUNNEL_ENABLED" @@ -79,40 +71,20 @@ class ServerPeerAssociationStatus(Enum): PEER_NOT_FOUND = "PEER_NOT_FOUND" -@serializable(canonical_name="NetworkStash", version=1) -class NetworkStash(NewBaseUIDStoreStash): - object_type = ServerPeer - settings: PartitionSettings = PartitionSettings( - name=ServerPeer.__canonical_name__, object_type=ServerPeer - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="NetworkSQLStash", version=1) +class NetworkStash(ObjectStash[ServerPeer]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> ServerPeer: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) try: - return self.query_one(credentials=credentials, qks=qks).unwrap() - except NotFoundException as exc: + return self.get_one( + credentials=credentials, + filters={"name": name}, + ).unwrap() + except NotFoundException as e: raise NotFoundException.from_exception( - exc, public_message=f"ServerPeer with {name} not found" + e, public_message=f"ServerPeer with {name} not found" ) - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - peer_update: ServerPeerUpdate, - has_permission: bool = False, - ) -> ServerPeer: - self.check_type(peer_update, ServerPeerUpdate).unwrap() - return ( - super() - .update(credentials, peer_update, has_permission=has_permission) - .unwrap() - ) - @as_result(StashException) def create_or_update_peer( self, credentials: SyftVerifyKey, peer: ServerPeer @@ -148,28 +120,26 @@ def create_or_update_peer( def get_by_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> ServerPeer: - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_one(credentials, qks).unwrap( - private_message=f"ServerPeer with {verify_key} not found" - ) + return self.get_one( + credentials=credentials, + filters={"verify_key": verify_key}, + ).unwrap() @as_result(StashException) def get_by_server_type( self, credentials: SyftVerifyKey, server_type: ServerType ) -> list[ServerPeer]: - qks = QueryKeys(qks=[ServerTypePartitionKey.with_obj(server_type)]) - return self.query_all( - credentials=credentials, qks=qks, order_by=OrderByNamePartitionKey + return self.get_all( + credentials=credentials, + filters={"server_type": server_type}, ).unwrap() @serializable(canonical_name="NetworkService", version=1) class NetworkService(AbstractService): - store: DocumentStore stash: NetworkStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NetworkStash(store=store) if reverse_tunnel_enabled(): self.rtunnel_service = ReverseTunnelService() @@ -407,7 +377,8 @@ def get_all_peers(self, context: AuthedServiceContext) -> list[ServerPeer]: """Get all Peers""" return self.stash.get_all( credentials=context.server.verify_key, - order_by=OrderByNamePartitionKey, + order_by="name", + sort_order="asc", ).unwrap() @service_method( @@ -447,7 +418,7 @@ def update_peer( peer = self.stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, ).unwrap() self.set_reverse_tunnel_config(context=context, remote_server_peer=peer) @@ -494,7 +465,6 @@ def set_reverse_tunnel_config( def delete_peer_by_id(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess: """Delete Server Peer""" peer_to_delete = self.stash.get_by_uid(context.credentials, uid).unwrap() - peer_to_delete = cast(ServerPeer, peer_to_delete) server_side_type = cast(ServerType, context.server.server_type) if server_side_type.value == ServerType.GATEWAY.value: @@ -608,7 +578,7 @@ def add_route( ) self.stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, ).unwrap() return SyftSuccess( @@ -736,7 +706,7 @@ def delete_route( id=remote_server_peer.id, server_routes=remote_server_peer.server_routes ) self.stash.update( - credentials=context.server.verify_key, peer_update=peer_update + credentials=context.server.verify_key, obj=peer_update ).unwrap() return SyftSuccess(message=return_message) diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 5cd7a5f2136..f6de35f75fd 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -130,8 +130,7 @@ def server(self) -> AbstractServer | None: server_type=self.worker_settings.server_type, server_side_type=self.worker_settings.server_side_type, signing_key=self.worker_settings.signing_key, - document_store_config=self.worker_settings.document_store_config, - action_store_config=self.worker_settings.action_store_config, + db_config=self.worker_settings.db_config, processes=1, ) return server diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index 280e836f17b..655d40b9b3e 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -88,7 +88,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> None: result = network_stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, has_permission=True, ) diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py index 2ebc0908a88..c53da5fabef 100644 --- a/packages/syft/src/syft/service/notification/email_templates.py +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -133,7 +133,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) -> @staticmethod def email_body(notification: "Notification", context: AuthedServiceContext) -> str: user_service = context.server.services.user - admin_verify_key = user_service.admin_verify_key() + admin_verify_key = user_service.root_verify_key user = user_service.stash.get_by_verify_key( credentials=admin_verify_key, verify_key=notification.to_user_verify_key ).unwrap() @@ -224,7 +224,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) -> @staticmethod def email_body(notification: "Notification", context: AuthedServiceContext) -> str: user_service = context.server.services.user - admin_verify_key = user_service.admin_verify_key() + admin_verify_key = user_service.root_verify_key admin = user_service.get_by_verify_key(admin_verify_key).unwrap() admin_name = admin.name diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 15b2b9725d8..7d6cb83e2f1 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result @@ -28,11 +28,9 @@ @serializable(canonical_name="NotificationService", version=1) class NotificationService(AbstractService): - store: DocumentStore stash: NotificationStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NotificationStash(store=store) @service_method(path="notifications.send", name="send") diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index fd41ad0dda6..029cf1b325a 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -1,80 +1,49 @@ -# stdlib - -# third party - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject -from ...types.datetime import DateTime from ...types.result import as_result from ...types.uid import UID from .notifications import Notification from .notifications import NotificationStatus -FromUserVerifyKeyPartitionKey = PartitionKey( - key="from_user_verify_key", type_=SyftVerifyKey -) -ToUserVerifyKeyPartitionKey = PartitionKey( - key="to_user_verify_key", type_=SyftVerifyKey -) -StatusPartitionKey = PartitionKey(key="status", type_=NotificationStatus) - -OrderByCreatedAtTimeStampPartitionKey = PartitionKey(key="created_at", type_=DateTime) - -LinkedObjectPartitionKey = PartitionKey(key="linked_obj", type_=LinkedObject) - - -@serializable(canonical_name="NotificationStash", version=1) -class NotificationStash(NewBaseUIDStoreStash): - object_type = Notification - settings: PartitionSettings = PartitionSettings( - name=Notification.__canonical_name__, - object_type=Notification, - ) +@serializable(canonical_name="NotificationSQLStash", version=1) +class NotificationStash(ObjectStash[Notification]): @as_result(StashException) def get_all_inbox_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> list[Notification]: - qks = QueryKeys( - qks=[ - ToUserVerifyKeyPartitionKey.with_obj(verify_key), - ] - ) - return self.get_all_for_verify_key( - credentials=credentials, verify_key=verify_key, qks=qks + if not isinstance(verify_key, SyftVerifyKey | str): + raise AttributeError("verify_key must be of type SyftVerifyKey or str") + return self.get_all( + credentials, + filters={"to_user_verify_key": verify_key}, ).unwrap() @as_result(StashException) def get_all_sent_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> list[Notification]: - qks = QueryKeys( - qks=[ - FromUserVerifyKeyPartitionKey.with_obj(verify_key), - ] - ) - return self.get_all_for_verify_key( - credentials, verify_key=verify_key, qks=qks + if not isinstance(verify_key, SyftVerifyKey | str): + raise AttributeError("verify_key must be of type SyftVerifyKey or str") + return self.get_all( + credentials, + filters={"from_user_verify_key": verify_key}, ).unwrap() @as_result(StashException) def get_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, qks: QueryKeys + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> list[Notification]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - return self.query_all( + if not isinstance(verify_key, SyftVerifyKey | str): + raise AttributeError("verify_key must be of type SyftVerifyKey or str") + return self.get_all( credentials, - qks=qks, - order_by=OrderByCreatedAtTimeStampPartitionKey, + filters={"from_user_verify_key": verify_key}, ).unwrap() @as_result(StashException) @@ -84,16 +53,14 @@ def get_all_by_verify_key_for_status( verify_key: SyftVerifyKey, status: NotificationStatus, ) -> list[Notification]: - qks = QueryKeys( - qks=[ - ToUserVerifyKeyPartitionKey.with_obj(verify_key), - StatusPartitionKey.with_obj(status), - ] - ) - return self.query_all( + if not isinstance(verify_key, SyftVerifyKey | str): + raise AttributeError("verify_key must be of type SyftVerifyKey or str") + return self.get_all( credentials, - qks=qks, - order_by=OrderByCreatedAtTimeStampPartitionKey, + filters={ + "to_user_verify_key": str(verify_key), + "status": status.name, + }, ).unwrap() @as_result(StashException, NotFoundException) @@ -102,14 +69,12 @@ def get_notification_for_linked_obj( credentials: SyftVerifyKey, linked_obj: LinkedObject, ) -> Notification: - qks = QueryKeys( - qks=[ - LinkedObjectPartitionKey.with_obj(linked_obj), - ] - ) - return self.query_one(credentials=credentials, qks=qks).unwrap( - public_message=f"Notifications for Linked Object {linked_obj} not found" - ) + return self.get_one( + credentials, + filters={ + "linked_obj.id": linked_obj.id, + }, + ).unwrap() @as_result(StashException, NotFoundException) def update_notification_status( @@ -123,6 +88,8 @@ def update_notification_status( def delete_all_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> bool: + if not isinstance(verify_key, SyftVerifyKey | str): + raise AttributeError("verify_key must be of type SyftVerifyKey or str") notifications = self.get_all_inbox_for_verify_key( credentials, verify_key=verify_key, diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index dffad4e414e..3fbddf6eb98 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -8,7 +8,9 @@ from ...server.credentials import SyftVerifyKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_credentials_for_key @@ -48,7 +50,7 @@ class ReplyNotification(SyftObject): @serializable() -class Notification(SyftObject): +class NotificationV1(SyftObject): __canonical_name__ = "Notification" __version__ = SYFT_OBJECT_VERSION_1 @@ -71,6 +73,32 @@ class Notification(SyftObject): __repr_attrs__ = ["subject", "status", "created_at", "linked_obj"] __table_sort_attr__ = "Created at" + +@serializable() +class Notification(SyftObject): + __canonical_name__ = "Notification" + __version__ = SYFT_OBJECT_VERSION_2 + + subject: str + server_uid: UID + from_user_verify_key: SyftVerifyKey + to_user_verify_key: SyftVerifyKey + created_at: DateTime + status: NotificationStatus = NotificationStatus.UNREAD + linked_obj: LinkedObject | None = None + notifier_types: list[NOTIFIERS] = [] + email_template: type[EmailTemplate] | None = None + replies: list[ReplyNotification] = [] + + __attr_searchable__ = [ + "from_user_verify_key", + "to_user_verify_key", + "status", + ] + __repr_attrs__ = ["subject", "status", "created_at", "linked_obj"] + __table_sort_attr__ = "Created at" + __order_by__ = ("created_at", "asc") + def _repr_html_(self) -> str: return f"""
@@ -145,3 +173,8 @@ def createnotification_to_notification() -> list[Callable]: add_credentials_for_key("from_user_verify_key"), add_server_uid_for_key("server_uid"), ] + + +@migrate(NotificationV1, Notification) +def migrate_nofitication_v1_to_v2() -> list[Callable]: + return [] # skip migration, no changes in the class diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index dbde4905972..31cb439a614 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import AbstractServer from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -35,11 +35,9 @@ class RateLimitException(SyftException): @serializable(canonical_name="NotifierService", version=1) class NotifierService(AbstractService): - store: DocumentStore - stash: NotifierStash # Which stash should we use? + stash: NotifierStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NotifierStash(store=store) @as_result(StashException) @@ -258,7 +256,7 @@ def init_notifier( """ try: # Create a new NotifierStash since its a static method. - notifier_stash = NotifierStash(store=server.document_store) + notifier_stash = NotifierStash(store=server.db) should_update = False # Get the notifier @@ -325,7 +323,7 @@ def set_email_rate_limit( def dispatch_notification( self, context: AuthedServiceContext, notification: Notification ) -> SyftSuccess: - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key # Silently fail on notification not delivered try: diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 8dbe8e31e8f..861f3a89975 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -5,54 +5,29 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseStash -from ...store.document_store import PartitionKey +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...types.uid import UID -from ..action.action_permissions import ActionObjectPermission +from ...util.telemetry import instrument from .notifier import NotifierSettings -NamePartitionKey = PartitionKey(key="name", type_=str) -ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) - -@serializable(canonical_name="NotifierStash", version=1) -class NotifierStash(NewBaseStash): - object_type = NotifierSettings +@instrument +@serializable(canonical_name="NotifierSQLStash", version=1) +class NotifierStash(ObjectStash[NotifierSettings]): settings: PartitionSettings = PartitionSettings( name=NotifierSettings.__canonical_name__, object_type=NotifierSettings ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - def admin_verify_key(self) -> SyftVerifyKey: - return self.partition.root_verify_key - - # TODO: should this method behave like a singleton? @as_result(StashException, NotFoundException) def get(self, credentials: SyftVerifyKey) -> NotifierSettings: """Get Settings""" - settings: list[NotifierSettings] = self.get_all(credentials).unwrap() - if len(settings) == 0: - raise NotFoundException - return settings[0] - - @as_result(StashException) - def set( - self, - credentials: SyftVerifyKey, - settings: NotifierSettings, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> NotifierSettings: - result = self.check_type(settings, self.object_type).unwrap() - # we dont use and_then logic here as it is hard because of the order of the arguments - return ( - super().set(credentials=credentials, obj=result).unwrap() - ) # TODO check if result isInstance(Ok) + # actually get latest settings + result = self.get_all(credentials, limit=1, sort_order="desc").unwrap() + if len(result) > 0: + return result[0] + raise NotFoundException( + public_message="No settings found for the current user." + ) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 422788d22f5..5d26ff2cb3e 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -7,14 +7,10 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException -from ...store.document_store_errors import TooManyItemsFoundException from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.result import as_result @@ -183,65 +179,41 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return res -@serializable(canonical_name="OutputStash", version=1) -class OutputStash(NewBaseUIDStoreStash): - object_type = ExecutionOutput - settings: PartitionSettings = PartitionSettings( - name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store) - self.store = store - self.settings = self.settings - self._object_type = self.object_type - +@serializable(canonical_name="OutputStashSQL", version=1) +class OutputStash(ObjectStash[ExecutionOutput]): @as_result(StashException) def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> list[ExecutionOutput]: - qks = QueryKeys( - qks=[UserCodeIdPartitionKey.with_obj(user_code_id)], - ) - return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + return self.get_all( + credentials=credentials, + filters={"user_code_id": user_code_id}, ).unwrap() @as_result(StashException) def get_by_job_id( - self, credentials: SyftVerifyKey, user_code_id: UID - ) -> ExecutionOutput: - qks = QueryKeys( - qks=[JobIdPartitionKey.with_obj(user_code_id)], - ) - res = self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + self, credentials: SyftVerifyKey, job_id: UID + ) -> ExecutionOutput | None: + return self.get_one( + credentials=credentials, + filters={"job_id": job_id}, ).unwrap() - if len(res) == 0: - raise NotFoundException() - elif len(res) > 1: - raise TooManyItemsFoundException() - return res[0] @as_result(StashException) def get_by_output_policy_id( self, credentials: SyftVerifyKey, output_policy_id: UID ) -> list[ExecutionOutput]: - qks = QueryKeys( - qks=[OutputPolicyIdPartitionKey.with_obj(output_policy_id)], - ) - return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + return self.get_all( + credentials=credentials, + filters={"output_policy_id": output_policy_id}, ).unwrap() @serializable(canonical_name="OutputService", version=1) class OutputService(AbstractService): - store: DocumentStore stash: OutputStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: DBManager): self.stash = OutputStash(store=store) @service_method( @@ -310,7 +282,7 @@ def has_output_read_permissions( ActionObjectREAD(uid=_id.id, credentials=user_verify_key) for _id in result_ids ] - if context.server.services.action.store.has_permissions(permissions): + if context.server.services.action.stash.has_permissions(permissions): return True return False @@ -321,11 +293,11 @@ def has_output_read_permissions( roles=ADMIN_ROLE_LEVEL, ) def get_by_job_id( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, job_id: UID ) -> ExecutionOutput: return self.stash.get_by_job_id( credentials=context.server.verify_key, # type: ignore - user_code_id=user_code_id, + job_id=job_id, ).unwrap() @service_method( diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 3b1f33c0a08..1e33755418e 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -174,7 +174,6 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str from ..action.action_object import ActionObject # fetches the all the current api's connected - api_list = APIRegistry.get_all_api() output_kwargs = {} for k, v in kwargs.items(): uid = v @@ -190,7 +189,7 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str raise Exception(f"Input {k} must have a UID not {type(v)}") _obj_exists = False - for api in api_list: + for identity, api in APIRegistry.__api_registry__.items(): try: if api.services.action.exists(uid): server_identity = ServerIdentity.from_api(api) @@ -205,6 +204,9 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str # To handle the cases , where there an old api objects in # in APIRegistry continue + except Exception as e: + print(f"Error in partition_by_server with identity {identity}", e) + raise e if not _obj_exists: raise Exception(f"Input data {k}:{uid} does not belong to any Datasite") @@ -335,7 +337,7 @@ class UserOwned(PolicyRule): def is_owned( self, context: AuthedServiceContext, action_object: ActionObject ) -> bool: - action_store = context.server.services.action.store + action_store = context.server.services.action.stash return action_store.has_permission( ActionObjectPermission( action_object.id, ActionPermission.OWNER, context.credentials diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index fe32c226dbf..1e5f430d109 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -16,11 +16,9 @@ @serializable(canonical_name="PolicyService", version=1) class PolicyService(AbstractService): - store: DocumentStore stash: UserPolicyStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserPolicyStash(store=store) @service_method(path="policy.get_all", name="get_all") diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py index 38ab7f54c06..9e3a103280b 100644 --- a/packages/syft/src/syft/service/policy/user_policy_stash.py +++ b/packages/syft/src/syft/service/policy/user_policy_stash.py @@ -3,30 +3,20 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from .policy import PolicyUserVerifyKeyPartitionKey from .policy import UserPolicy -@serializable(canonical_name="UserPolicyStash", version=1) -class UserPolicyStash(NewBaseUIDStoreStash): - object_type = UserPolicy - settings: PartitionSettings = PartitionSettings( - name=UserPolicy.__canonical_name__, object_type=UserPolicy - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="UserPolicySQLStash", version=1) +class UserPolicyStash(ObjectStash[UserPolicy]): @as_result(StashException, NotFoundException) def get_all_by_user_verify_key( self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey ) -> list[UserPolicy]: - qks = QueryKeys(qks=[PolicyUserVerifyKeyPartitionKey.with_obj(user_verify_key)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"user_verify_key": user_verify_key}, + ).unwrap() diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 0b64a9d8870..5ed21f5007d 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -1247,12 +1247,12 @@ def send(self, return_all_projects: bool = False) -> Project | list[Project]: projects_map = self._create_projects(self.clients) # bootstrap project with pending events on leader server's project - self._bootstrap_events(projects_map[leader]) + self._bootstrap_events(projects_map[leader.id]) # type: ignore if return_all_projects: return list(projects_map.values()) - return projects_map[leader] + return projects_map[leader.id] # type: ignore def _pre_submit_checks(self, clients: list[SyftClient]) -> bool: try: @@ -1277,12 +1277,12 @@ def _exchange_routes(self, leader: SyftClient, followers: list[SyftClient]) -> N self.leader_server_route = connection_to_route(leader.connection) - def _create_projects(self, clients: list[SyftClient]) -> dict[SyftClient, Project]: - projects: dict[SyftClient, Project] = {} + def _create_projects(self, clients: list[SyftClient]) -> dict[UID, Project]: + projects: dict[UID, Project] = {} for client in clients: result = client.api.services.project.create_project(project=self).value - projects[client] = result + projects[client.id] = result # type: ignore return projects diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 3a2543e38a1..2df2fd42e1c 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -32,11 +32,9 @@ @serializable(canonical_name="ProjectService", version=1) class ProjectService(AbstractService): - store: DocumentStore stash: ProjectStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = ProjectStash(store=store) @as_result(SyftException) diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index bf81bd5b9b1..13dab37bdea 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -5,49 +5,27 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...types.uid import UID -from ..request.request import Request from .project import Project -# TODO: Move to a partitions file? -VerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey) -NamePartitionKey = PartitionKey(key="name", type_=str) - -@serializable(canonical_name="ProjectStash", version=1) -class ProjectStash(NewBaseUIDStoreStash): - object_type = Project - settings: PartitionSettings = PartitionSettings( - name=Project.__canonical_name__, object_type=Project - ) - - # TODO: Shouldn't this be a list of projects? +@serializable(canonical_name="ProjectSQLStash", version=1) +class ProjectStash(ObjectStash[Project]): @as_result(StashException) def get_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: VerifyKeyPartitionKey - ) -> list[Request]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_all( + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey + ) -> list[Project]: + return self.get_all( credentials=credentials, - qks=qks, + filters={"user_verify_key": verify_key}, ).unwrap() - @as_result(StashException, NotFoundException) - def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> Project: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, project_name: str) -> Project: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(project_name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"name": project_name}, + ).unwrap() diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index ae170cad95b..b2807389cf1 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -6,6 +6,7 @@ from threading import Thread import time from typing import Any +from typing import cast # third party import psutil @@ -178,8 +179,7 @@ def handle_message_multiprocessing( id=worker_settings.id, name=worker_settings.name, signing_key=worker_settings.signing_key, - document_store_config=worker_settings.document_store_config, - action_store_config=worker_settings.action_store_config, + db_config=worker_settings.db_config, blob_storage_config=worker_settings.blob_store_config, server_side_type=worker_settings.server_side_type, queue_config=queue_config, @@ -256,7 +256,7 @@ def handle_message_multiprocessing( public_message=f"Job {queue_item.job_id} not found!" ) - job_item.server_uid = worker.id + job_item.server_uid = worker.id # type: ignore[assignment] job_item.result = result job_item.resolved = True job_item.status = job_status @@ -282,7 +282,10 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: from ...server.server import Server queue_item = deserialize(message, from_bytes=True) + queue_item = cast(QueueItem, queue_item) worker_settings = queue_item.worker_settings + if worker_settings is None: + raise ValueError("Worker settings are missing in the queue item.") queue_config = worker_settings.queue_config queue_config.client_config.create_producer = False @@ -292,9 +295,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: id=worker_settings.id, name=worker_settings.name, signing_key=worker_settings.signing_key, - document_store_config=worker_settings.document_store_config, - action_store_config=worker_settings.action_store_config, - blob_storage_config=worker_settings.blob_store_config, + db_config=worker_settings.db_config, server_side_type=worker_settings.server_side_type, deployment_type=worker_settings.deployment_type, queue_config=queue_config, diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index c898893ee35..b98f344745d 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -2,30 +2,14 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...types.uid import UID -from ..context import AuthedServiceContext +from ...store.db.db import DBManager from ..service import AbstractService -from ..service import service_method -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from .queue_stash import QueueItem from .queue_stash import QueueStash @serializable(canonical_name="QueueService", version=1) class QueueService(AbstractService): - store: DocumentStore stash: QueueStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = QueueStash(store=store) - - @service_method( - path="queue.get_subjobs", - name="get_subjobs", - roles=DATA_SCIENTIST_ROLE_LEVEL, - ) - def get_subjobs(self, context: AuthedServiceContext, uid: UID) -> list[QueueItem]: - # FIX: There is no get_by_parent_id in QueueStash - return self.stash.get_by_parent_id(context.credentials, uid=uid) diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 251f4a9fb63..aa5b872b226 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from enum import Enum from typing import Any @@ -6,22 +7,24 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseStash +from ...server.worker_settings import WorkerSettingsV1 +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import TransformContext from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission +__all__ = ["QueueItem"] + @serializable(canonical_name="Status", version=1) class Status(str, Enum): @@ -37,7 +40,7 @@ class Status(str, Enum): @serializable() -class QueueItem(SyftObject): +class QueueItemV1(SyftObject): __canonical_name__ = "QueueItem" __version__ = SYFT_OBJECT_VERSION_1 @@ -49,6 +52,29 @@ class QueueItem(SyftObject): resolved: bool = False status: Status = Status.CREATED + method: str + service: str + args: list + kwargs: dict[str, Any] + job_id: UID | None = None + worker_settings: WorkerSettingsV1 | None = None + has_execute_permissions: bool = False + worker_pool: LinkedObject + + +@serializable() +class QueueItem(SyftObject): + __canonical_name__ = "QueueItem" + __version__ = SYFT_OBJECT_VERSION_2 + + __attr_searchable__ = ["status", "worker_pool_id"] + + id: UID + server_uid: UID + result: Any | None = None + resolved: bool = False + status: Status = Status.CREATED + method: str service: str args: list @@ -64,6 +90,10 @@ def __repr__(self) -> str: def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return f": {self.status}" + @property + def worker_pool_id(self) -> UID: + return self.worker_pool.object_uid + @property def is_action(self) -> bool: return self.service_path == "Action" and self.method_name == "execute" @@ -78,7 +108,7 @@ def action(self) -> Any: @serializable() class ActionQueueItem(QueueItem): __canonical_name__ = "ActionQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 method: str = "execute" service: str = "actionservice" @@ -87,22 +117,32 @@ class ActionQueueItem(QueueItem): @serializable() class APIEndpointQueueItem(QueueItem): __canonical_name__ = "APIEndpointQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 method: str service: str = "apiservice" -@serializable(canonical_name="QueueStash", version=1) -class QueueStash(NewBaseStash): - object_type = QueueItem - settings: PartitionSettings = PartitionSettings( - name=QueueItem.__canonical_name__, object_type=QueueItem - ) +@serializable() +class ActionQueueItemV1(QueueItemV1): + __canonical_name__ = "ActionQueueItem" + __version__ = SYFT_OBJECT_VERSION_1 + + method: str = "execute" + service: str = "actionservice" + + +@serializable() +class APIEndpointQueueItemV1(QueueItemV1): + __canonical_name__ = "APIEndpointQueueItem" + __version__ = SYFT_OBJECT_VERSION_1 + + method: str + service: str = "apiservice" - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="QueueSQLStash", version=1) +class QueueStash(ObjectStash[QueueItem]): # FIX: Check the return value for None. set_result is used extensively @as_result(StashException) def set_result( @@ -133,11 +173,6 @@ def set_placeholder( return super().set(credentials, item, add_permissions).unwrap() return item - @as_result(StashException) - def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @as_result(StashException) def pop(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem | None: try: @@ -156,23 +191,51 @@ def pop_on_complete(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: self.delete_by_uid(credentials=credentials, uid=uid) return queue_item - @as_result(StashException) - def delete_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> UID: - qk = UIDPartitionKey.with_obj(uid) - super().delete(credentials=credentials, qk=qk).unwrap() - return uid - @as_result(StashException) def get_by_status( self, credentials: SyftVerifyKey, status: Status ) -> list[QueueItem]: - qks = QueryKeys(qks=StatusPartitionKey.with_obj(status)) - - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"status": status}, + ).unwrap() @as_result(StashException) def _get_by_worker_pool( self, credentials: SyftVerifyKey, worker_pool: LinkedObject ) -> list[QueueItem]: - qks = QueryKeys(qks=_WorkerPoolPartitionKey.with_obj(worker_pool)) - return self.query_all(credentials=credentials, qks=qks).unwrap() + worker_pool_id = worker_pool.object_uid + + return self.get_all( + credentials=credentials, + filters={"worker_pool_id": worker_pool_id}, + ).unwrap() + + +def upgrade_worker_settings_for_queue(context: TransformContext) -> TransformContext: + if context.output and context.output["worker_settings"] is not None: + worker_settings_old: WorkerSettingsV1 | None = context.output["worker_settings"] + if worker_settings_old is None: + return context + + worker_settings = worker_settings_old.migrate_to( + WorkerSettings.__version__, context=context.to_server_context() + ) + context.output["worker_settings"] = worker_settings + + return context + + +@migrate(QueueItemV1, QueueItem) +def migrate_queue_item_from_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] + + +@migrate(ActionQueueItemV1, ActionQueueItem) +def migrate_action_queue_item_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] + + +@migrate(APIEndpointQueueItemV1, APIEndpointQueueItem) +def migrate_api_endpoint_queue_item_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index f6993d6b032..b2f6d4b9a6e 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -288,7 +288,7 @@ def _set_worker_job(self, job_id: UID | None) -> None: ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING ) res = self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, + credentials=self.worker_stash.root_verify_key, worker_uid=self.syft_worker_id, consumer_state=consumer_state, ) diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py index 5cb5056f8a2..197f87ab283 100644 --- a/packages/syft/src/syft/service/queue/zmq_producer.py +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -159,7 +159,7 @@ def read_items(self) -> None: # Items to be queued items_to_queue = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, + self.queue_stash.root_verify_key, status=Status.CREATED, ).unwrap() @@ -167,7 +167,7 @@ def read_items(self) -> None: # Queue Items that are in the processing state items_processing = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, + self.queue_stash.root_verify_key, status=Status.PROCESSING, ).unwrap() @@ -281,14 +281,14 @@ def update_consumer_state_for_worker( try: try: self.worker_stash.get_by_uid( - credentials=self.worker_stash.partition.root_verify_key, + credentials=self.worker_stash.root_verify_key, uid=syft_worker_id, ).unwrap() except Exception: return None self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, + credentials=self.worker_stash.root_verify_key, worker_uid=syft_worker_id, consumer_state=consumer_state, ).unwrap() diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 1a492ea1d53..674c66a019a 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -40,8 +40,8 @@ from ...util.notebook_ui.icons import Icon from ...util.util import prompt_warning_message from ..action.action_object import ActionObject -from ..action.action_store import ActionObjectPermission -from ..action.action_store import ActionPermission +from ..action.action_permissions import ActionObjectPermission +from ..action.action_permissions import ActionPermission from ..code.user_code import UserCode from ..code.user_code import UserCodeStatus from ..code.user_code import UserCodeStatusCollection @@ -112,7 +112,7 @@ class ActionStoreChange(Change): @as_result(SyftException) def _run(self, context: ChangeContext, apply: bool) -> SyftSuccess: - action_store = context.server.services.action.store + action_store = context.server.services.action.stash # can we ever have a lineage ID in the store? obj_uid = self.linked_obj.object_uid @@ -362,6 +362,7 @@ class Request(SyncableSyftObject): __attr_searchable__ = [ "requesting_user_verify_key", "approving_user_verify_key", + "code_id", ] __attr_unique__ = ["request_hash"] __repr_attrs__ = [ diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 6343007f6ca..ed94c185689 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result @@ -37,11 +37,9 @@ @serializable(canonical_name="RequestService", version=1) class RequestService(AbstractService): - store: DocumentStore stash: RequestStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = RequestStash(store=store) @service_method(path="request.submit", name="submit", roles=GUEST_ROLE_LEVEL) @@ -59,7 +57,7 @@ def submit( request, ).unwrap() - root_verify_key = context.server.services.user.admin_verify_key() + root_verify_key = context.server.services.user.root_verify_key if send_message: message_subject = f"Result to request {str(request.id)[:4]}...{str(request.id)[-3:]}\ diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index 19bb2f5720e..a28fd5842e1 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -1,59 +1,31 @@ -# stdlib - -# third party - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...types.datetime import DateTime +from ...store.db.stash import ObjectStash from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID from .request import Request -RequestingUserVerifyKeyPartitionKey = PartitionKey( - key="requesting_user_verify_key", type_=SyftVerifyKey -) - -OrderByRequestTimeStampPartitionKey = PartitionKey(key="request_time", type_=DateTime) - - -@serializable(canonical_name="RequestStash", version=1) -class RequestStash(NewBaseUIDStoreStash): - object_type = Request - settings: PartitionSettings = PartitionSettings( - name=Request.__canonical_name__, object_type=Request - ) +@serializable(canonical_name="RequestStashSQL", version=1) +class RequestStash(ObjectStash[Request]): @as_result(SyftException) def get_all_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> list[Request]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_all( + return self.get_all( credentials=credentials, - qks=qks, - order_by=OrderByRequestTimeStampPartitionKey, + filters={"requesting_user_verify_key": verify_key}, ).unwrap() @as_result(SyftException) def get_by_usercode_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> list[Request]: - all_requests = self.get_all(credentials=credentials).unwrap() - res = [] - for r in all_requests: - try: - if r.code_id == user_code_id: - res.append(r) - except SyftException: - pass - return res + return self.get_all( + credentials=credentials, + filters={"code_id": user_code_id}, + ).unwrap() diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index 784eca2e340..49749711853 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -37,6 +37,7 @@ from ..serde.signature import signature_remove_context from ..serde.signature import signature_remove_self from ..server.credentials import SyftVerifyKey +from ..store.db.stash import ObjectStash from ..store.document_store import DocumentStore from ..store.linked_obj import LinkedObject from ..types.errors import SyftException @@ -71,6 +72,7 @@ class AbstractService: server: AbstractServer server_uid: UID store_type: type = DocumentStore + stash: ObjectStash @as_result(SyftException) def resolve_link( diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 43fe685b7e5..10890350e2d 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import ServerSideType from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.sqlite_document_store import SQLiteStoreConfig @@ -48,11 +48,9 @@ @serializable(canonical_name="SettingsService", version=1) class SettingsService(AbstractService): - store: DocumentStore stash: SettingsStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SettingsStash(store=store) @service_method(path="settings.get", name="get") @@ -123,14 +121,16 @@ def update( def _update( self, context: AuthedServiceContext, settings: ServerSettingsUpdate ) -> ServerSettings: - all_settings = self.stash.get_all(context.credentials).unwrap() + all_settings = self.stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() if len(all_settings) > 0: new_settings = all_settings[0].model_copy( update=settings.to_dict(exclude_empty=True) ) ServerSettings.model_validate(new_settings.to_dict()) update_result = self.stash.update( - context.credentials, settings=new_settings + context.credentials, obj=new_settings ).unwrap() # If notifications_enabled is present in the update, we need to update the notifier settings @@ -173,10 +173,12 @@ def set_server_side_type_dangerous( public_message=f"Not a valid server_side_type, please use one of the options from: {side_type_options}" ) - current_settings = self.stash.get_all(context.credentials).unwrap() + current_settings = self.stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() if len(current_settings) > 0: new_settings = current_settings[0] - new_settings.server_side_type = server_side_type + new_settings.server_side_type = ServerSideType(server_side_type) updated_settings = self.stash.update( context.credentials, new_settings ).unwrap() diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index aa02847504a..c22c08045f3 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -1,36 +1,11 @@ # relative from ...serde.serializable import serializable -from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store_errors import StashException -from ...types.result import as_result -from ...types.uid import UID +from ...store.db.stash import ObjectStash +from ...util.telemetry import instrument from .settings import ServerSettings -NamePartitionKey = PartitionKey(key="name", type_=str) -ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) - -@serializable(canonical_name="SettingsStash", version=1) -class SettingsStash(NewBaseUIDStoreStash): - object_type = ServerSettings - settings: PartitionSettings = PartitionSettings( - name=ServerSettings.__canonical_name__, object_type=ServerSettings - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - # Should we have this at all? - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - settings: ServerSettings, - has_permission: bool = False, - ) -> ServerSettings: - obj = self.check_type(settings, self.object_type).unwrap() - return super().update(credentials=credentials, obj=obj).unwrap() +@instrument +@serializable(canonical_name="SettingsStashSQL", version=1) +class SettingsStash(ObjectStash[ServerSettings]): + pass diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 75959be55e5..ddafd86b1d3 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -6,7 +6,8 @@ # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException from ...store.linked_obj import LinkedObject @@ -36,22 +37,20 @@ logger = logging.getLogger(__name__) -def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> Any: +def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash: if isinstance(item, ActionObject): service = context.server.services.action # type: ignore - return service.store # type: ignore + return service.stash # type: ignore service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore - return service.stash.partition + return service.stash @instrument @serializable(canonical_name="SyncService", version=1) class SyncService(AbstractService): - store: DocumentStore stash: SyncStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: DBManager): self.stash = SyncStash(store=store) def add_actionobject_read_permissions( @@ -60,14 +59,14 @@ def add_actionobject_read_permissions( action_object: ActionObject, new_permissions: list[ActionObjectPermission], ) -> None: - store_to = context.server.services.action.store # type: ignore + action_stash = context.server.services.action.stash for permission in new_permissions: if permission.permission == ActionPermission.READ: - store_to.add_permission(permission) + action_stash.add_permission(permission) blob_id = action_object.syft_blob_storage_entry_id if blob_id: - store_to_blob = context.server.services.blob_sotrage.stash.partition # type: ignore + blob_stash = context.server.services.blob_storage.stash for permission in new_permissions: if permission.permission == ActionPermission.READ: permission_blob = ActionObjectPermission( @@ -75,7 +74,7 @@ def add_actionobject_read_permissions( permission=permission.permission, credentials=permission.credentials, ) - store_to_blob.add_permission(permission_blob) + blob_stash.add_permission(permission_blob) def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: if hasattr(x, "__dict__") and isinstance(x, SyftObject): @@ -246,9 +245,9 @@ def get_permissions( self, context: AuthedServiceContext, items: list[SyncableSyftObject], - ) -> tuple[dict[UID, set[str]], dict[UID, set[str]]]: - permissions = {} - storage_permissions = {} + ) -> tuple[dict[UID, set[str]], dict[UID, set[UID]]]: + permissions: dict[UID, set[str]] = {} + storage_permissions: dict[UID, set[UID]] = {} for item in items: store = get_store(context, item) @@ -369,7 +368,9 @@ def build_current_state( storage_permissions = {} try: - previous_state = self.stash.get_latest(context=context).unwrap() + previous_state = self.stash.get_latest( + credentials=context.credentials + ).unwrap() except NotFoundException: previous_state = None diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 8633f31f130..114ee209af1 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -1,69 +1,37 @@ -# stdlib -import threading - # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey +from ...server.credentials import SyftVerifyKey +from ...store.db.db import DBManager +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...store.document_store_errors import StashException -from ...types.datetime import DateTime from ...types.result import as_result -from ..context import AuthedServiceContext from .sync_state import SyncState -OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime) - @serializable(canonical_name="SyncStash", version=1) -class SyncStash(NewBaseUIDStoreStash): - object_type = SyncState +class SyncStash(ObjectStash[SyncState]): settings: PartitionSettings = PartitionSettings( name=SyncState.__canonical_name__, object_type=SyncState, ) - def __init__(self, store: DocumentStore): + def __init__(self, store: DBManager) -> None: super().__init__(store) - self.store = store - self.settings = self.settings - self._object_type = self.object_type self.last_state: SyncState | None = None @as_result(StashException) - def get_latest(self, context: AuthedServiceContext) -> SyncState | None: + def get_latest(self, credentials: SyftVerifyKey) -> SyncState | None: if self.last_state is not None: return self.last_state - all_states = self.get_all( - credentials=context.server.verify_key, # type: ignore - order_by=OrderByDatePartitionKey, + + states = self.get_all( + credentials=credentials, + order_by="created_at", + sort_order="desc", + limit=1, ).unwrap() - if len(all_states) > 0: - self.last_state = all_states[-1] - return all_states[-1] + if len(states) > 0: + return states[0] return None - - def unwrap_set(self, context: AuthedServiceContext, item: SyncState) -> SyncState: - return super().set(context, item).unwrap() - - @as_result(StashException) - def set( # type: ignore - self, - context: AuthedServiceContext, - item: SyncState, - **kwargs, - ) -> SyncState: - self.last_state = item - - # use threading - threading.Thread( - target=self.unwrap_set, - args=( - context, - item, - ), - kwargs=kwargs, - ).start() - return item diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index dab3b6cccd1..a0d27a00786 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -11,7 +11,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -81,23 +81,21 @@ def _paginate( @serializable(canonical_name="UserService", version=1) class UserService(AbstractService): - store: DocumentStore stash: UserStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserStash(store=store) @as_result(StashException) def _add_user(self, credentials: SyftVerifyKey, user: User) -> User: - action_object_permissions = ActionObjectPermission( - uid=user.id, permission=ActionPermission.ALL_READ - ) - return self.stash.set( credentials=credentials, obj=user, - add_permissions=[action_object_permissions], + add_permissions=[ + ActionObjectPermission( + uid=user.id, permission=ActionPermission.ALL_READ + ), + ], ).unwrap() def _check_if_email_exists(self, credentials: SyftVerifyKey, email: str) -> bool: @@ -132,7 +130,7 @@ def forgot_password( "If the email is valid, we sent a password " + "reset token to your email or a password request to the admin." ) - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) @@ -243,7 +241,7 @@ def reset_password( self, context: UnauthedServiceContext, token: str, new_password: str ) -> SyftSuccess: """Resets a certain user password using a temporary token.""" - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) try: @@ -321,21 +319,36 @@ def view(self, context: AuthedServiceContext, uid: UID) -> UserView: def get_all( self, context: AuthedServiceContext, + order_by: str | None = None, + sort_order: str | None = None, page_size: int | None = 0, page_index: int | None = 0, ) -> list[UserView]: - if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: - users = self.stash.get_all( - context.credentials, has_permission=True - ).unwrap() - else: - users = self.stash.get_all(context.credentials).unwrap() + users = self.stash.get_all( + context.credentials, + order_by=order_by, + sort_order=sort_order, + ).unwrap() users = [user.to(UserView) for user in users] return _paginate(users, page_size, page_index) + @service_method( + path="user.get_index", name="get_index", roles=DATA_OWNER_ROLE_LEVEL + ) + def get_index( + self, + context: AuthedServiceContext, + index: int, + ) -> UserView: + return ( + self.stash.get_index(credentials=context.credentials, index=index) + .unwrap() + .to(UserView) + ) + def signing_key_for_verify_key(self, verify_key: SyftVerifyKey) -> UserPrivateKey: user = self.stash.get_by_verify_key( - credentials=self.stash.admin_verify_key(), verify_key=verify_key + credentials=self.stash.root_verify_key, verify_key=verify_key ).unwrap() return user.to(UserPrivateKey) @@ -348,12 +361,11 @@ def get_role_for_credentials( # they could be different # TODO: This fn is cryptic -- when does each situation occur? if isinstance(credentials, SyftVerifyKey): - user = self.stash.get_by_verify_key( - credentials=credentials, verify_key=credentials - ).unwrap() + role = self.stash.get_role(credentials=credentials) + return role elif isinstance(credentials, SyftSigningKey): user = self.stash.get_by_signing_key( - credentials=credentials, + credentials=credentials.verify_key, signing_key=credentials, # type: ignore ).unwrap() else: @@ -378,7 +390,9 @@ def search( if len(kwargs) == 0: raise SyftException(public_message="Invalid search parameters") - users = self.stash.find_all(credentials=context.credentials, **kwargs).unwrap() + users = self.stash.get_all( + credentials=context.credentials, filters=kwargs + ).unwrap() users = [user.to(UserView) for user in users] if users is not None else [] return _paginate(users, page_size, page_index) @@ -506,45 +520,54 @@ def update( ).unwrap() if user.role == ServiceRole.ADMIN: - settings_stash = SettingsStash(store=self.store) - settings = settings_stash.get_all(context.credentials).unwrap() + settings_stash = SettingsStash(store=self.stash.db) + settings = settings_stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() # TODO: Chance to refactor here in settings, as we're always doing get_att[0] if len(settings) > 0: settings_data = settings[0] settings_data.admin_email = user.email settings_stash.update( - credentials=context.credentials, settings=settings_data + credentials=context.credentials, obj=settings_data ) return user.to(UserView) @service_method(path="user.delete", name="delete", roles=GUEST_ROLE_LEVEL) def delete(self, context: AuthedServiceContext, uid: UID) -> UID: - user = self.stash.get_by_uid(credentials=context.credentials, uid=uid).unwrap() + user_to_delete = self.stash.get_by_uid( + credentials=context.credentials, uid=uid + ).unwrap() - if ( + # Cannot delete root user + if user_to_delete.verify_key == self.root_verify_key: + raise UserPermissionError( + private_message=f"User {context.credentials} attempted to delete root user." + ) + + # - Admins can delete any user + # - Data Owners can delete Data Scientists and Guests + has_delete_permissions = ( context.role == ServiceRole.ADMIN or context.role == ServiceRole.DATA_OWNER - and user.role - in [ - ServiceRole.GUEST, - ServiceRole.DATA_SCIENTIST, - ] - ): - pass - else: + and user_to_delete.role in [ServiceRole.GUEST, ServiceRole.DATA_SCIENTIST] + ) + + if not has_delete_permissions: raise UserPermissionError( - f"User {context.credentials} ({context.role}) tried to delete user {uid} ({user.role})" + private_message=( + f"User {context.credentials} ({context.role}) tried to delete user " + f"{uid} ({user_to_delete.role})" + ) ) # TODO: Remove notifications for the deleted user - self.stash.delete_by_uid( - credentials=context.credentials, uid=uid, has_permission=True + return self.stash.delete_by_uid( + credentials=context.credentials, uid=uid ).unwrap() - return uid - def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: """Verify user TODO: We might want to use a SyftObject instead @@ -554,7 +577,7 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: raise SyftException(public_message="Invalid login credentials") user = self.stash.get_by_email( - credentials=self.admin_verify_key(), email=context.login_credentials.email + credentials=self.root_verify_key, email=context.login_credentials.email ).unwrap() if check_pwd(context.login_credentials.password, user.hashed_password): @@ -573,9 +596,9 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: return SyftSuccess(message="Login successful.", value=user.to(UserPrivateKey)) - def admin_verify_key(self) -> SyftVerifyKey: - # TODO: Remove passthrough method? - return self.stash.admin_verify_key() + @property + def root_verify_key(self) -> SyftVerifyKey: + return self.stash.root_verify_key def register( self, context: ServerServiceContext, new_user: UserCreate @@ -616,7 +639,7 @@ def register( success_message = f"User '{user.name}' successfully registered!" # Notification Step - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) link = None @@ -643,7 +666,7 @@ def register( @as_result(StashException) def user_verify_key(self, email: str) -> SyftVerifyKey: # we are bypassing permissions here, so dont use to return a result directly to the user - credentials = self.admin_verify_key() + credentials = self.root_verify_key user = self.stash.get_by_email(credentials=credentials, email=email).unwrap() if user.verify_key is None: raise UserError(f"User {email} has no verify key") @@ -652,7 +675,7 @@ def user_verify_key(self, email: str) -> SyftVerifyKey: @as_result(StashException) def get_by_verify_key(self, verify_key: SyftVerifyKey) -> UserView: # we are bypassing permissions here, so dont use to return a result directly to the user - credentials = self.admin_verify_key() + credentials = self.root_verify_key user = self.stash.get_by_verify_key( credentials=credentials, verify_key=verify_key ).unwrap() diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 3272f9a946c..92fb87d37b3 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -2,90 +2,63 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...types.uid import UID from .user import User from .user_roles import ServiceRole -# 🟡 TODO 27: it would be nice if these could be defined closer to the User -EmailPartitionKey = PartitionKey(key="email", type_=str) -PasswordResetTokenPartitionKey = PartitionKey(key="reset_token", type_=str) -RolePartitionKey = PartitionKey(key="role", type_=ServiceRole) -SigningKeyPartitionKey = PartitionKey(key="signing_key", type_=SyftSigningKey) -VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) - - -@serializable(canonical_name="UserStash", version=1) -class UserStash(NewBaseUIDStoreStash): - object_type = User - settings: PartitionSettings = PartitionSettings( - name=User.__canonical_name__, - object_type=User, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - def admin_verify_key(self) -> SyftVerifyKey: - return self.partition.root_verify_key +@serializable(canonical_name="UserStashSQL", version=1) +class UserStash(ObjectStash[User]): @as_result(StashException, NotFoundException) def admin_user(self) -> User: # TODO: This returns only one user, the first user with the role ADMIN - admin_credentials = self.admin_verify_key() + admin_credentials = self.root_verify_key return self.get_by_role( credentials=admin_credentials, role=ServiceRole.ADMIN ).unwrap() - @as_result(StashException, NotFoundException) - def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> User: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - try: - return self.query_one(credentials=credentials, qks=qks).unwrap() - except NotFoundException as exc: - raise NotFoundException.from_exception( - exc, public_message=f"User {uid} not found" - ) - @as_result(StashException, NotFoundException) def get_by_reset_token(self, credentials: SyftVerifyKey, token: str) -> User: - qks = QueryKeys(qks=[PasswordResetTokenPartitionKey.with_obj(token)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"reset_token": token}, + ).unwrap() @as_result(StashException, NotFoundException) def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User: - qks = QueryKeys(qks=[EmailPartitionKey.with_obj(email)]) + return self.get_one( + credentials=credentials, + filters={"email": email}, + ).unwrap() + @as_result(StashException) + def email_exists(self, email: str) -> bool: try: - return self.query_one(credentials=credentials, qks=qks).unwrap() - except NotFoundException as exc: - raise NotFoundException.from_exception( - exc, public_message=f"User {email} not found" - ) + self.get_by_email(credentials=self.root_verify_key, email=email).unwrap() + return True + except NotFoundException: + return False @as_result(StashException) - def email_exists(self, email: str) -> bool: + def verify_key_exists(self, verify_key: SyftVerifyKey) -> bool: try: - self.get_by_email(credentials=self.admin_verify_key(), email=email).unwrap() + self.get_by_verify_key( + credentials=self.root_verify_key, verify_key=verify_key + ).unwrap() return True except NotFoundException: return False @as_result(StashException, NotFoundException) def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User: - # TODO: Is this method correct? Should'nt it return a list of all member with a particular role? - qks = QueryKeys(qks=[RolePartitionKey.with_obj(role)]) - try: - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"role": role}, + ).unwrap() except NotFoundException as exc: private_msg = f"User with role {role} not found" raise NotFoundException.from_exception(exc, private_message=private_msg) @@ -94,28 +67,25 @@ def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User: def get_by_signing_key( self, credentials: SyftVerifyKey, signing_key: SyftSigningKey | str ) -> User: - if isinstance(signing_key, str): - signing_key = SyftSigningKey.from_string(signing_key) - - qks = QueryKeys(qks=[SigningKeyPartitionKey.with_obj(signing_key)]) - try: - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"signing_key": signing_key}, + ).unwrap() except NotFoundException as exc: private_msg = f"User with signing key {signing_key} not found" raise NotFoundException.from_exception(exc, private_message=private_msg) @as_result(StashException, NotFoundException) def get_by_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey | str + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> User: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - try: - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one( + credentials=credentials, + filters={"verify_key": verify_key}, + ).unwrap() + except NotFoundException as exc: private_msg = f"User with verify key {verify_key} not found" raise NotFoundException.from_exception(exc, private_message=private_msg) diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 7310decd85d..83a30bb670b 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..context import AuthedServiceContext @@ -20,11 +20,9 @@ @serializable(canonical_name="SyftImageRegistryService", version=1) class SyftImageRegistryService(AbstractService): - store: DocumentStore stash: SyftImageRegistryStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftImageRegistryStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/image_registry_stash.py b/packages/syft/src/syft/service/worker/image_registry_stash.py index 5c469da0825..cfb71b9848b 100644 --- a/packages/syft/src/syft/service/worker/image_registry_stash.py +++ b/packages/syft/src/syft/service/worker/image_registry_stash.py @@ -1,55 +1,34 @@ -# stdlib - -# third party - -# stdlib - # stdlib from typing import Literal # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result from .image_registry import SyftImageRegistry -__all__ = ["SyftImageRegistryStash"] - - -URLPartitionKey = PartitionKey(key="url", type_=str) - - -@serializable(canonical_name="SyftImageRegistryStash", version=1) -class SyftImageRegistryStash(NewBaseUIDStoreStash): - object_type = SyftImageRegistry - settings: PartitionSettings = PartitionSettings( - name=SyftImageRegistry.__canonical_name__, - object_type=SyftImageRegistry, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftImageRegistrySQLStash", version=1) +class SyftImageRegistryStash(ObjectStash[SyftImageRegistry]): @as_result(SyftException, StashException, NotFoundException) def get_by_url( self, credentials: SyftVerifyKey, url: str, - ) -> SyftImageRegistry | None: - qks = QueryKeys(qks=[URLPartitionKey.with_obj(url)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( - public_message=f"Image Registry with url {url} not found" - ) + ) -> SyftImageRegistry: + return self.get_one( + credentials=credentials, + filters={"url": url}, + ).unwrap() @as_result(SyftException, StashException) def delete_by_url(self, credentials: SyftVerifyKey, url: str) -> Literal[True]: - qk = URLPartitionKey.with_obj(url) - return super().delete(credentials=credentials, qk=qk).unwrap() + item = self.get_by_url(credentials=credentials, url=url).unwrap() + self.delete_by_uid(credentials=credentials, uid=item.id).unwrap() + + # TODO standardize delete return type + return True diff --git a/packages/syft/src/syft/service/worker/worker_image.py b/packages/syft/src/syft/service/worker/worker_image.py index e5c110a6e0e..99ca3ad6040 100644 --- a/packages/syft/src/syft/service/worker/worker_image.py +++ b/packages/syft/src/syft/service/worker/worker_image.py @@ -7,18 +7,20 @@ from ...server.credentials import SyftVerifyKey from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.uid import UID from .image_identifier import SyftWorkerImageIdentifier @serializable() -class SyftWorkerImage(SyftObject): +class SyftWorkerImageV1(SyftObject): __canonical_name__ = "SyftWorkerImage" __version__ = SYFT_OBJECT_VERSION_1 __attr_unique__ = ["config"] __attr_searchable__ = ["config", "image_hash", "created_by"] + __repr_attrs__ = [ "image_identifier", "image_hash", @@ -35,6 +37,40 @@ class SyftWorkerImage(SyftObject): image_hash: str | None = None built_at: DateTime | None = None + +@serializable() +class SyftWorkerImage(SyftObject): + __canonical_name__ = "SyftWorkerImage" + __version__ = SYFT_OBJECT_VERSION_2 + + __attr_unique__ = ["config_hash"] + __attr_searchable__ = [ + "config", + "image_hash", + "created_by", + "config_hash", + ] + + __repr_attrs__ = [ + "image_identifier", + "image_hash", + "created_at", + "built_at", + "config", + ] + + id: UID + config: WorkerConfig + created_by: SyftVerifyKey + created_at: DateTime = DateTime.now() + image_identifier: SyftWorkerImageIdentifier | None = None + image_hash: str | None = None + built_at: DateTime | None = None + + @property + def config_hash(self) -> str: + return self.config.hash() + @property def is_built(self) -> bool: """Returns True if the image has been built or is prebuilt.""" diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index 37622a36d71..a5f05f94dac 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -10,7 +10,7 @@ from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException @@ -31,11 +31,9 @@ @serializable(canonical_name="SyftWorkerImageService", version=1) class SyftWorkerImageService(AbstractService): - store: DocumentStore stash: SyftWorkerImageStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftWorkerImageStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index 983dfe9a8d6..dc220905839 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -2,16 +2,16 @@ # third party +# third party +from sqlalchemy.orm import Session + # relative from ...custom_worker.config import DockerWorkerConfig from ...custom_worker.config import WorkerConfig from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -20,32 +20,22 @@ from ..action.action_permissions import ActionPermission from .worker_image import SyftWorkerImage -WorkerConfigPK = PartitionKey(key="config", type_=WorkerConfig) - - -@serializable(canonical_name="SyftWorkerImageStash", version=1) -class SyftWorkerImageStash(NewBaseUIDStoreStash): - object_type = SyftWorkerImage - settings: PartitionSettings = PartitionSettings( - name=SyftWorkerImage.__canonical_name__, - object_type=SyftWorkerImage, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftWorkerImageSQLStash", version=1) +class SyftWorkerImageStash(ObjectStash[SyftWorkerImage]): @as_result(SyftException, StashException, NotFoundException) - def set( # type: ignore + @with_session + def set( self, credentials: SyftVerifyKey, obj: SyftWorkerImage, add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: Session = None, ) -> SyftWorkerImage: - add_permissions = [] if add_permissions is None else add_permissions - # By default syft images have all read permission + add_permissions = [] if add_permissions is None else add_permissions add_permissions.append( ActionObjectPermission(uid=obj.id, permission=ActionPermission.ALL_READ) ) @@ -67,6 +57,7 @@ def set( # type: ignore add_permissions=add_permissions, add_storage_permission=add_storage_permission, ignore_duplicates=ignore_duplicates, + session=session, ) .unwrap() ) @@ -85,7 +76,11 @@ def worker_config_exists( def get_by_worker_config( self, credentials: SyftVerifyKey, config: WorkerConfig ) -> SyftWorkerImage: - qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( + # TODO cannot search on fields containing objects + all_images = self.get_all(credentials=credentials).unwrap() + for image in all_images: + if image.config == config: + return image + raise NotFoundException( public_message=f"Worker Image with config {config} not found" ) diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 4ceced2bf26..55b103ba369 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -12,7 +12,7 @@ from ...custom_worker.k8s import IN_KUBERNETES from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -52,11 +52,9 @@ @serializable(canonical_name="SyftWorkerPoolService", version=1) class SyftWorkerPoolService(AbstractService): - store: DocumentStore stash: SyftWorkerPoolStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftWorkerPoolStash(store=store) self.image_stash = SyftWorkerImageStash(store=store) diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index 94a4a0b8fab..81a4f4741d2 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -2,14 +2,14 @@ # third party +# third party +from sqlalchemy.orm import Session + # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -18,29 +18,22 @@ from ..action.action_permissions import ActionPermission from .worker_pool import WorkerPool -PoolNamePartitionKey = PartitionKey(key="name", type_=str) -PoolImageIDPartitionKey = PartitionKey(key="image_id", type_=UID) - - -@serializable(canonical_name="SyftWorkerPoolStash", version=1) -class SyftWorkerPoolStash(NewBaseUIDStoreStash): - object_type = WorkerPool - settings: PartitionSettings = PartitionSettings( - name=WorkerPool.__canonical_name__, - object_type=WorkerPool, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftWorkerPoolSQLStash", version=1) +class SyftWorkerPoolStash(ObjectStash[WorkerPool]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, pool_name: str) -> WorkerPool: - qks = QueryKeys(qks=[PoolNamePartitionKey.with_obj(pool_name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( + result = self.get_one( + credentials=credentials, + filters={"name": pool_name}, + ) + + return result.unwrap( public_message=f"WorkerPool with name {pool_name} not found" ) @as_result(StashException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -48,6 +41,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: Session = None, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -62,6 +56,7 @@ def set( add_permissions=add_permissions, add_storage_permission=add_storage_permission, ignore_duplicates=ignore_duplicates, + session=session, ) .unwrap() ) @@ -70,5 +65,7 @@ def set( def get_by_image_uid( self, credentials: SyftVerifyKey, image_uid: UID ) -> list[WorkerPool]: - qks = QueryKeys(qks=[PoolImageIDPartitionKey.with_obj(image_uid)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all( + credentials=credentials, + filters={"image_id": image_uid}, + ).unwrap() diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 625a88a46b4..300c0b6ed3d 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -13,7 +13,7 @@ from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.db import DBManager from ...store.document_store import SyftSuccess from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -39,11 +39,9 @@ @serializable(canonical_name="WorkerService", version=1) class WorkerService(AbstractService): - store: DocumentStore stash: WorkerStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = WorkerStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index b2b059ffec5..48a192ecd19 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -2,14 +2,15 @@ # third party +# third party +from sqlalchemy.orm import Session + # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash +from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -22,17 +23,10 @@ WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str) -@serializable(canonical_name="WorkerStash", version=1) -class WorkerStash(NewBaseUIDStoreStash): - object_type = SyftWorker - settings: PartitionSettings = PartitionSettings( - name=SyftWorker.__canonical_name__, object_type=SyftWorker - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="WorkerSQLStash", version=1) +class WorkerStash(ObjectStash[SyftWorker]): @as_result(StashException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -40,6 +34,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: Session = None, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -54,17 +49,11 @@ def set( add_permissions=add_permissions, ignore_duplicates=ignore_duplicates, add_storage_permission=add_storage_permission, + session=session, ) .unwrap() ) - @as_result(StashException, NotFoundException) - def get_worker_by_name( - self, credentials: SyftVerifyKey, worker_name: str - ) -> SyftWorker: - qks = QueryKeys(qks=[WorkerContainerNamePartitionKey.with_obj(worker_name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @as_result(StashException, NotFoundException) def update_consumer_state( self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState diff --git a/packages/syft/src/syft/store/__init__.py b/packages/syft/src/syft/store/__init__.py index 9260d13f956..e69de29bb2d 100644 --- a/packages/syft/src/syft/store/__init__.py +++ b/packages/syft/src/syft/store/__init__.py @@ -1,3 +0,0 @@ -# relative -from .mongo_document_store import MongoDict -from .mongo_document_store import MongoStoreConfig diff --git a/packages/syft/tests/mongomock/py.typed b/packages/syft/src/syft/store/db/__init__.py similarity index 100% rename from packages/syft/tests/mongomock/py.typed rename to packages/syft/src/syft/store/db/__init__.py diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py new file mode 100644 index 00000000000..cc82e5a3f4e --- /dev/null +++ b/packages/syft/src/syft/store/db/db.py @@ -0,0 +1,81 @@ +# stdlib +import logging +from typing import Generic +from typing import TypeVar +from urllib.parse import urlparse + +# third party +from pydantic import BaseModel +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# relative +from ...serde.serializable import serializable +from ...server.credentials import SyftVerifyKey +from ...types.uid import UID +from .schema import PostgresBase +from .schema import SQLiteBase + +logger = logging.getLogger(__name__) + + +@serializable(canonical_name="DBConfig", version=1) +class DBConfig(BaseModel): + @property + def connection_string(self) -> str: + raise NotImplementedError("Subclasses must implement this method.") + + @classmethod + def from_connection_string(cls, conn_str: str) -> "DBConfig": + # relative + from .postgres import PostgresDBConfig + from .sqlite import SQLiteDBConfig + + parsed = urlparse(conn_str) + if parsed.scheme == "postgresql": + return PostgresDBConfig( + host=parsed.hostname, + port=parsed.port, + user=parsed.username, + password=parsed.password, + database=parsed.path.lstrip("/"), + ) + elif parsed.scheme == "sqlite": + return SQLiteDBConfig(path=parsed.path) + else: + raise ValueError(f"Unsupported database scheme {parsed.scheme}") + + +ConfigT = TypeVar("ConfigT", bound=DBConfig) + + +class DBManager(Generic[ConfigT]): + def __init__( + self, + config: ConfigT, + server_uid: UID, + root_verify_key: SyftVerifyKey, + ) -> None: + self.config = config + self.root_verify_key = root_verify_key + self.server_uid = server_uid + self.engine = create_engine( + config.connection_string, + # json_serializer=dumps, + # json_deserializer=loads, + ) + logger.info(f"Connecting to {config.connection_string}") + self.sessionmaker = sessionmaker(bind=self.engine) + self.update_settings() + logger.info(f"Successfully connected to {config.connection_string}") + + def update_settings(self) -> None: + pass + + def init_tables(self, reset: bool = False) -> None: + Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase + + with self.sessionmaker().begin() as _: + if reset: + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/errors.py b/packages/syft/src/syft/store/db/errors.py new file mode 100644 index 00000000000..8f9a4ca048a --- /dev/null +++ b/packages/syft/src/syft/store/db/errors.py @@ -0,0 +1,31 @@ +# stdlib +import logging + +# third party +from sqlalchemy.exc import DatabaseError +from typing_extensions import Self + +# relative +from ..document_store_errors import StashException + +logger = logging.getLogger(__name__) + + +class StashDBException(StashException): + """ + See https://docs.sqlalchemy.org/en/20/errors.html#databaseerror + + StashDBException converts a SQLAlchemy DatabaseError into a StashException, + DatabaseErrors are errors thrown by the database itself, for example when a + query fails because a table is missing. + """ + + public_message = "There was an error retrieving data. Contact your admin." + + @classmethod + def from_sqlalchemy_error(cls, e: DatabaseError) -> Self: + logger.exception(e) + + error_type = e.__class__.__name__ + private_message = f"{error_type}: {str(e)}" + return cls(private_message=private_message) diff --git a/packages/syft/src/syft/store/db/postgres.py b/packages/syft/src/syft/store/db/postgres.py new file mode 100644 index 00000000000..630155e29de --- /dev/null +++ b/packages/syft/src/syft/store/db/postgres.py @@ -0,0 +1,48 @@ +# third party +from sqlalchemy import URL + +# relative +from ...serde.serializable import serializable +from ...server.credentials import SyftVerifyKey +from ...types.uid import UID +from .db import DBManager +from .sqlite import DBConfig + + +@serializable(canonical_name="PostgresDBConfig", version=1) +class PostgresDBConfig(DBConfig): + host: str + port: int + user: str + password: str + database: str + + @property + def connection_string(self) -> str: + return URL.create( + "postgresql", + username=self.user, + password=self.password, + host=self.host, + port=self.port, + database=self.database, + ).render_as_string(hide_password=False) + + +class PostgresDBManager(DBManager[PostgresDBConfig]): + def update_settings(self) -> None: + return super().update_settings() + + @classmethod + def random( + cls: type, + *, + config: PostgresDBConfig, + server_uid: UID | None = None, + root_verify_key: SyftVerifyKey | None = None, + ) -> "PostgresDBManager": + root_verify_key = root_verify_key or SyftVerifyKey.generate() + server_uid = server_uid or UID() + return PostgresDBManager( + config=config, server_uid=server_uid, root_verify_key=root_verify_key + ) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py new file mode 100644 index 00000000000..04864a4a74a --- /dev/null +++ b/packages/syft/src/syft/store/db/query.py @@ -0,0 +1,383 @@ +# stdlib +from abc import ABC +from abc import abstractmethod +import enum +from typing import Any +from typing import Literal + +# third party +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Dialect +from sqlalchemy import Result +from sqlalchemy import Select +from sqlalchemy import Table +from sqlalchemy import func +from sqlalchemy.exc import DatabaseError +from sqlalchemy.orm import Session +from typing_extensions import Self + +# relative +from ...serde.json_serde import serialize_json +from ...server.credentials import SyftVerifyKey +from ...service.action.action_permissions import ActionObjectPermission +from ...service.action.action_permissions import ActionPermission +from ...service.user.user_roles import ServiceRole +from ...types.syft_object import SyftObject +from ...types.uid import UID +from .errors import StashDBException +from .schema import PostgresBase +from .schema import SQLiteBase + + +class FilterOperator(enum.Enum): + EQ = "eq" + CONTAINS = "contains" + + +class Query(ABC): + def __init__(self, object_type: type[SyftObject]) -> None: + self.object_type: type = object_type + self.table: Table = self._get_table(object_type) + self.stmt: Select = self.table.select() + + @abstractmethod + def _get_table(self, object_type: type[SyftObject]) -> Table: + raise NotImplementedError + + @staticmethod + def get_query_class(dialect: str | Dialect) -> "type[Query]": + if isinstance(dialect, Dialect): + dialect = dialect.name + + if dialect == "sqlite": + return SQLiteQuery + elif dialect == "postgresql": + return PostgresQuery + else: + raise ValueError(f"Unsupported dialect {dialect}") + + @classmethod + def create(cls, object_type: type[SyftObject], dialect: str | Dialect) -> "Query": + """Create a query object for the given object type and dialect.""" + query_class = cls.get_query_class(dialect) + return query_class(object_type) + + def execute(self, session: Session) -> Result: + """Execute the query using the given session.""" + try: + return session.execute(self.stmt) + except DatabaseError as e: + raise StashDBException.from_sqlalchemy_error(e) from e + + def with_permissions( + self, + credentials: SyftVerifyKey, + role: ServiceRole, + permission: ActionPermission = ActionPermission.READ, + ) -> Self: + """Add a permission check to the query. + + If the user has a role below DATA_OWNER, the query will be filtered to only include objects + that the user has the specified permission on. + + Args: + credentials (SyftVerifyKey): user verify key + role (ServiceRole): role of the user + permission (ActionPermission, optional): Type of permission to check for. + Defaults to ActionPermission.READ. + + Returns: + Self: The query object with the permission check applied + """ + if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): + return self + + ao_permission = ActionObjectPermission( + uid=UID(), # dummy uid, we just need the permission string + credentials=credentials, + permission=permission, + ) + + permission_clause = self._make_permissions_clause(ao_permission) + self.stmt = self.stmt.where(permission_clause) + + return self + + def filter(self, field: str, operator: str | FilterOperator, value: Any) -> Self: + """Add a filter to the query. + + example usage: + Query(User).filter("name", "eq", "Alice") + Query(User).filter("friends", "contains", "Bob") + + Args: + field (str): Field to filter on + operator (str): Operator to use for the filter + value (Any): Value to filter on + + Raises: + ValueError: If the operator is not supported + + Returns: + Self: The query object with the filter applied + """ + filter = self._create_filter_clause(self.table, field, operator, value) + self.stmt = self.stmt.where(filter) + return self + + def filter_and(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self: + """Add filters to the query using an AND clause. + + example usage: + Query(User).filter_and( + ("name", "eq", "Alice"), + ("age", "eq", 30), + ) + + Args: + field (str): Field to filter on + operator (str): Operator to use for the filter + value (Any): Value to filter on + + Raises: + ValueError: If the operator is not supported + + Returns: + Self: The query object with the filter applied + """ + filter_clauses = [ + self._create_filter_clause(self.table, field, operator, value) + for field, operator, value in filters + ] + + self.stmt = self.stmt.where(sa.and_(*filter_clauses)) + return self + + def filter_or(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self: + """Add filters to the query using an OR clause. + + example usage: + Query(User).filter_or( + ("name", "eq", "Alice"), + ("age", "eq", 30), + ) + + Args: + field (str): Field to filter on + operator (str): Operator to use for the filter + value (Any): Value to filter on + + Raises: + ValueError: If the operator is not supported + + Returns: + Self: The query object with the filter applied + """ + filter_clauses = [ + self._create_filter_clause(self.table, field, operator, value) + for field, operator, value in filters + ] + + self.stmt = self.stmt.where(sa.or_(*filter_clauses)) + return self + + def _create_filter_clause( + self, + table: Table, + field: str, + operator: str | FilterOperator, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if isinstance(operator, str): + try: + operator = FilterOperator(operator.lower()) + except ValueError: + raise ValueError(f"Filter operator {operator} not supported") + + if operator == FilterOperator.EQ: + return self._eq_filter(table, field, value) + elif operator == FilterOperator.CONTAINS: + return self._contains_filter(table, field, value) + + def order_by( + self, + field: str | None = None, + order: Literal["asc", "desc"] | None = None, + ) -> Self: + """Add an order by clause to the query, with sensible defaults if field or order is not provided. + + Args: + field (Optional[str]): field to order by. If None, uses the default field. + order (Optional[Literal["asc", "desc"]]): Order to use ("asc" or "desc"). + Defaults to 'asc' if field is provided and order is not, or the default order otherwise. + + Raises: + ValueError: If the order is not "asc" or "desc" + + Returns: + Self: The query object with the order by clause applied. + """ + # Determine the field and order defaults if not provided + if field is None: + if hasattr(self.object_type, "__order_by__"): + default_field, default_order = self.object_type.__order_by__ + else: + default_field, default_order = "_created_at", "desc" + field = default_field + else: + # If field is provided but order is not, default to 'asc' + default_order = "asc" + order = order or default_order + + column = self._get_column(field) + + if isinstance(column.type, sa.JSON): + column = sa.cast(column, sa.String) + + if order.lower() == "asc": + self.stmt = self.stmt.order_by(column.asc()) + + elif order.lower() == "desc": + self.stmt = self.stmt.order_by(column.desc()) + else: + raise ValueError(f"Invalid sort order {order}") + + return self + + def limit(self, limit: int | None) -> Self: + """Add a limit clause to the query.""" + if limit is None: + return self + + if limit < 0: + raise ValueError("Limit must be a positive integer") + self.stmt = self.stmt.limit(limit) + + return self + + def offset(self, offset: int) -> Self: + """Add an offset clause to the query.""" + if offset < 0: + raise ValueError("Offset must be a positive integer") + + self.stmt = self.stmt.offset(offset) + return self + + @abstractmethod + def _make_permissions_clause( + self, + permission: ActionObjectPermission, + ) -> sa.sql.elements.BinaryExpression: + pass + + @abstractmethod + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + pass + + def _get_column(self, column: str) -> Column: + if column == "id": + return self.table.c.id + if column == "created_date" or column == "_created_at": + return self.table.c._created_at + elif column == "updated_date" or column == "_updated_at": + return self.table.c._updated_at + elif column == "deleted_date" or column == "_deleted_at": + return self.table.c._deleted_at + + return self.table.c.fields[column] + + +class SQLiteQuery(Query): + def _make_permissions_clause( + self, + permission: ActionObjectPermission, + ) -> sa.sql.elements.BinaryExpression: + permission_string = permission.permission_string + compound_permission_string = permission.compound_permission_string + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), + ) + + def _get_table(self, object_type: type[SyftObject]) -> Table: + cname = object_type.__canonical_name__ + if cname not in SQLiteBase.metadata.tables: + raise ValueError(f"Table for {cname} not found") + return SQLiteBase.metadata.tables[cname] + + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + field_value = serialize_json(value) + return table.c.fields[field].contains(func.json_quote(field_value)) + + def _eq_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if field == "id": + return table.c.id == UID(value) + + if "." in field: + # magic! + field = field.split(".") # type: ignore + + json_value = serialize_json(value) + return table.c.fields[field] == func.json_quote(json_value) + + +class PostgresQuery(Query): + def _make_permissions_clause( + self, permission: ActionObjectPermission + ) -> sa.sql.elements.BinaryExpression: + permission_string = [permission.permission_string] + compound_permission_string = [permission.compound_permission_string] + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), + ) + + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + field_value = serialize_json(value) + col = sa.cast(table.c.fields[field], sa.Text) + val = sa.cast(field_value, sa.Text) + return col.contains(val) + + def _get_table(self, object_type: type[SyftObject]) -> Table: + cname = object_type.__canonical_name__ + if cname not in PostgresBase.metadata.tables: + raise ValueError(f"Table for {cname} not found") + return PostgresBase.metadata.tables[cname] + + def _eq_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if field == "id": + return table.c.id == UID(value) + + if "." in field: + # magic! + field = field.split(".") # type: ignore + + json_value = serialize_json(value) + # NOTE: there might be a bug with casting everything to text + return table.c.fields[field].astext == sa.cast(json_value, sa.Text) diff --git a/packages/syft/src/syft/store/db/schema.py b/packages/syft/src/syft/store/db/schema.py new file mode 100644 index 00000000000..7f81e39802e --- /dev/null +++ b/packages/syft/src/syft/store/db/schema.py @@ -0,0 +1,87 @@ +# stdlib + +# stdlib +import uuid + +# third party +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Dialect +from sqlalchemy import Table +from sqlalchemy import TypeDecorator +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.types import JSON + +# relative +from ...types.syft_object import SyftObject +from ...types.uid import UID + + +class SQLiteBase(DeclarativeBase): + pass + + +class PostgresBase(DeclarativeBase): + pass + + +class UIDTypeDecorator(TypeDecorator): + """Converts between Syft UID and UUID.""" + + impl = sa.UUID + cache_ok = True + + def process_bind_param(self, value, dialect): # type: ignore + if value is not None: + return value.value + + def process_result_value(self, value, dialect): # type: ignore + if value is not None: + return UID(value) + + +def create_table( + object_type: type[SyftObject], + dialect: Dialect, +) -> Table: + """Create a table for a given SYftObject type, and add it to the metadata. + + To create the table on the database, you must call `Base.metadata.create_all(engine)`. + + Args: + object_type (type[SyftObject]): The type of the object to create a table for. + dialect (Dialect): The dialect of the database. + + Returns: + Table: The created table. + """ + table_name = object_type.__canonical_name__ + dialect_name = dialect.name + + fields_type = JSON if dialect_name == "sqlite" else postgresql.JSON + permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB + storage_permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB + + Base = SQLiteBase if dialect_name == "sqlite" else PostgresBase + + if table_name not in Base.metadata.tables: + Table( + object_type.__canonical_name__, + Base.metadata, + Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), + Column("fields", fields_type, default={}), + Column("permissions", permissions_type, default=[]), + Column( + "storage_permissions", + storage_permissions_type, + default=[], + ), + Column( + "_created_at", sa.DateTime, server_default=sa.func.now(), index=True + ), + Column("_updated_at", sa.DateTime, server_onupdate=sa.func.now()), + Column("_deleted_at", sa.DateTime, index=True), + ) + + return Base.metadata.tables[table_name] diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py new file mode 100644 index 00000000000..fbcf87ce47b --- /dev/null +++ b/packages/syft/src/syft/store/db/sqlite.py @@ -0,0 +1,61 @@ +# stdlib +from pathlib import Path +import tempfile +import uuid + +# third party +from pydantic import Field +import sqlalchemy as sa + +# relative +from ...serde.serializable import serializable +from ...server.credentials import SyftSigningKey +from ...server.credentials import SyftVerifyKey +from ...types.uid import UID +from .db import DBConfig +from .db import DBManager + + +@serializable(canonical_name="SQLiteDBConfig", version=1) +class SQLiteDBConfig(DBConfig): + filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") + path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir())) + + @property + def connection_string(self) -> str: + """ + NOTE in-memory sqlite is not shared between connections, so: + - using 2 workers (high/low) will not share a db + - re-using a connection (e.g. for a Job worker) will not share a db + """ + if self.path == Path("."): + # Use in-memory database, only for unittests + return "sqlite://" + filepath = self.path / self.filename + return f"sqlite:///{filepath.resolve()}" + + +class SQLiteDBManager(DBManager[SQLiteDBConfig]): + def update_settings(self) -> None: + connection = self.engine.connect() + connection.execute(sa.text("PRAGMA journal_mode = WAL")) + connection.execute(sa.text("PRAGMA busy_timeout = 5000")) + connection.execute(sa.text("PRAGMA temp_store = 2")) + connection.execute(sa.text("PRAGMA synchronous = 1")) + + @classmethod + def random( + cls, + *, + config: SQLiteDBConfig | None = None, + server_uid: UID | None = None, + root_verify_key: SyftVerifyKey | None = None, + ) -> "SQLiteDBManager": + root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key + server_uid = server_uid or UID() + config = config or SQLiteDBConfig() + return SQLiteDBManager( + config=config, + server_uid=server_uid, + root_verify_key=root_verify_key, + ) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py new file mode 100644 index 00000000000..aec2a2ed9c5 --- /dev/null +++ b/packages/syft/src/syft/store/db/stash.py @@ -0,0 +1,848 @@ +# stdlib +from collections.abc import Callable +from functools import wraps +import inspect +from typing import Any +from typing import Generic +from typing import ParamSpec +from typing import Set # noqa: UP035 +from typing import cast +from typing import get_args + +# third party +from pydantic import ValidationError +import sqlalchemy as sa +from sqlalchemy import Row +from sqlalchemy import Table +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session +from typing_extensions import Self +from typing_extensions import TypeVar + +# relative +from ...serde.json_serde import deserialize_json +from ...serde.json_serde import is_json_primitive +from ...serde.json_serde import serialize_json +from ...server.credentials import SyftVerifyKey +from ...service.action.action_permissions import ActionObjectEXECUTE +from ...service.action.action_permissions import ActionObjectOWNER +from ...service.action.action_permissions import ActionObjectPermission +from ...service.action.action_permissions import ActionObjectREAD +from ...service.action.action_permissions import ActionObjectWRITE +from ...service.action.action_permissions import ActionPermission +from ...service.action.action_permissions import StoragePermission +from ...service.user.user_roles import ServiceRole +from ...types.errors import SyftException +from ...types.result import as_result +from ...types.syft_metaclass import Empty +from ...types.syft_object import PartialSyftObject +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.telemetry import instrument +from ..document_store_errors import NotFoundException +from ..document_store_errors import StashException +from .db import DBManager +from .query import Query +from .schema import PostgresBase +from .schema import SQLiteBase +from .schema import create_table +from .sqlite import SQLiteDBManager + +StashT = TypeVar("StashT", bound=SyftObject) +T = TypeVar("T") +P = ParamSpec("P") + + +def parse_filters(filter_dict: dict[str, Any] | None) -> list[tuple[str, str, Any]]: + # NOTE using django style filters, e.g. {"age__gt": 18} + if filter_dict is None: + return [] + filters = [] + for key, value in filter_dict.items(): + key_split = key.split("__") + # Operator is eq if not specified + if len(key_split) == 1: + field, operator = key, "eq" + elif len(key_split) == 2: + field, operator = key_split + filters.append((field, operator, value)) + return filters + + +def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore + """ + Decorator to inject a session into the function kwargs if it is not provided. + + TODO: This decorator is a temporary fix, we want to move to a DI approach instead: + move db connection and session to context, and pass context to all stash methods. + """ + + # inspect if the function has a session kwarg + sig = inspect.signature(func) + inject_session: bool = "session" in sig.parameters + + @wraps(func) + def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: + if inject_session and kwargs.get("session") is None: + with self.sessionmaker() as session: + kwargs["session"] = session + return func(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return wrapper # type: ignore + + +@instrument +class ObjectStash(Generic[StashT]): + allow_any_type: bool = False + + def __init__(self, store: DBManager) -> None: + self.db = store + self.object_type = self.get_object_type() + self.table = create_table(self.object_type, self.dialect) + self.sessionmaker = self.db.sessionmaker + + @property + def dialect(self) -> sa.engine.interfaces.Dialect: + return self.db.engine.dialect + + @classmethod + def get_object_type(cls) -> type[StashT]: + """ + Get the object type this stash is storing. This is the generic argument of the + ObjectStash class. + """ + generic_args = get_args(cls.__orig_bases__[0]) + if len(generic_args) != 1: + raise TypeError("ObjectStash must have a single generic argument") + elif not issubclass(generic_args[0], SyftObject): + raise TypeError( + "ObjectStash generic argument must be a subclass of SyftObject" + ) + return generic_args[0] + + @with_session + def __len__(self, session: Session = None) -> int: + return session.query(self.table).count() + + @classmethod + def random(cls, **kwargs: dict) -> Self: + """Create a random stash with a random server_uid and root_verify_key. Useful for development.""" + db_manager = SQLiteDBManager.random(**kwargs) + stash = cls(store=db_manager) + stash.db.init_tables() + return stash + + def _is_sqlite(self) -> bool: + return self.db.engine.dialect.name == "sqlite" + + @property + def server_uid(self) -> UID: + return self.db.server_uid + + @property + def root_verify_key(self) -> SyftVerifyKey: + return self.db.root_verify_key + + @property + def _data(self) -> list[StashT]: + return self.get_all(self.root_verify_key, has_permission=True).unwrap() + + def query(self, object_type: type[SyftObject] | None = None) -> Query: + """Creates a query for this stash's object type and SQL dialect.""" + object_type = object_type or self.object_type + return Query.create(object_type, self.dialect) + + @as_result(StashException) + def check_type(self, obj: T, type_: type) -> T: + if not isinstance(obj, type_): + raise StashException(f"{type(obj)} does not match required type: {type_}") + return cast(T, obj) + + @property + def session(self) -> Session: + return self.db.session + + def _print_query(self, stmt: sa.sql.select) -> None: + print( + stmt.compile( + compile_kwargs={"literal_binds": True}, + dialect=self.db.engine.dialect, + ) + ) + + @property + def unique_fields(self) -> list[str]: + return getattr(self.object_type, "__attr_unique__", []) + + @with_session + def is_unique(self, obj: StashT, session: Session = None) -> bool: + unique_fields = self.unique_fields + if not unique_fields: + return True + + filters = [] + for field_name in unique_fields: + field_value = getattr(obj, field_name, None) + if not is_json_primitive(field_value): + raise StashException( + f"Cannot check uniqueness of non-primitive field {field_name}" + ) + if field_value is None: + continue + filters.append((field_name, "eq", field_value)) + + query = self.query() + query = query.filter_or( + *filters, + ) + + results = query.execute(session).all() + + if len(results) > 1: + return False + elif len(results) == 1: + result = results[0] + return result.id == obj.id + return True + + @with_session + def exists( + self, credentials: SyftVerifyKey, uid: UID, session: Session = None + ) -> bool: + # TODO should be @as_result + # TODO needs credentials check? + # TODO use COUNT(*) instead of SELECT + query = self.query().filter("id", "eq", uid) + result = query.execute(session).first() + return result is not None + + @as_result(SyftException, StashException, NotFoundException) + @with_session + def get_by_uid( + self, + credentials: SyftVerifyKey, + uid: UID, + has_permission: bool = False, + session: Session = None, + ) -> StashT: + return self.get_one( + credentials=credentials, + filters={"id": uid}, + has_permission=has_permission, + session=session, + ).unwrap() + + def _get_field_filter( + self, + field_name: str, + field_value: Any, + table: Table | None = None, + ) -> sa.sql.elements.BinaryExpression: + table = table if table is not None else self.table + if field_name == "id": + uid_field_value = UID(field_value) + return table.c.id == uid_field_value + + json_value = serialize_json(field_value) + if self.db.engine.dialect.name == "sqlite": + return table.c.fields[field_name] == func.json_quote(json_value) + elif self.db.engine.dialect.name == "postgresql": + return table.c.fields[field_name].astext == cast(json_value, sa.String) + + @as_result(SyftException, StashException, NotFoundException) + def get_index( + self, credentials: SyftVerifyKey, index: int, has_permission: bool = False + ) -> StashT: + order_by, sort_order = self.object_type.__order_by__ + if index < 0: + index = -1 - index + sort_order = "desc" if sort_order == "asc" else "asc" + + items = self.get_all( + credentials, + has_permission=has_permission, + limit=1, + offset=index, + order_by=order_by, + sort_order=sort_order, + ).unwrap() + + if len(items) == 0: + raise NotFoundException(f"No item found at index {index}") + return items[0] + + def row_as_obj(self, row: Row) -> StashT: + # TODO make unwrappable serde + return deserialize_json(row.fields) + + @with_session + def get_role( + self, credentials: SyftVerifyKey, session: Session = None + ) -> ServiceRole: + # relative + from ...service.user.user import User + + Base = SQLiteBase if self._is_sqlite() else PostgresBase + + # TODO error handling + if Base.metadata.tables.get("User") is None: + # if User table does not exist, we assume the user is a guest + # this happens when we create stashes in tests + return ServiceRole.GUEST + + try: + query = self.query(User).filter("verify_key", "eq", credentials) + except Exception as e: + print("Error getting role", e) + raise e + + user = query.execute(session).first() + if user is None: + return ServiceRole.GUEST + + return self.row_as_obj(user).role + + def _get_permission_filter_from_permisson( + self, + permission: ActionObjectPermission, + ) -> sa.sql.elements.BinaryExpression: + permission_string = permission.permission_string + compound_permission_string = permission.compound_permission_string + + if self.db.engine.dialect.name == "postgresql": + permission_string = [permission_string] # type: ignore + compound_permission_string = [compound_permission_string] # type: ignore + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), + ) + + @with_session + def _apply_permission_filter( + self, + stmt: T, + *, + credentials: SyftVerifyKey, + permission: ActionPermission = ActionPermission.READ, + has_permission: bool = False, + session: Session = None, + ) -> T: + if has_permission: + # ignoring permissions + return stmt + role = self.get_role(credentials, session=session) + if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): + # admins and data owners have all permissions + return stmt + + action_object_permission = ActionObjectPermission( + uid=UID(), # dummy uid, we just need the permission string + credentials=credentials, + permission=permission, + ) + + stmt = stmt.where( + self._get_permission_filter_from_permisson( + permission=action_object_permission + ) + ) + return stmt + + @as_result(SyftException, StashException) + @with_session + def set( + self, + credentials: SyftVerifyKey, + obj: StashT, + add_permissions: list[ActionObjectPermission] | None = None, + add_storage_permission: bool = True, # TODO: check the default value + ignore_duplicates: bool = False, + session: Session = None, + ) -> StashT: + if not self.allow_any_type: + self.check_type(obj, self.object_type).unwrap() + uid = obj.id + + # check if the object already exists + if self.exists(credentials, uid) or not self.is_unique(obj): + if ignore_duplicates: + return obj + unique_fields_str = ", ".join(self.unique_fields) + raise StashException( + public_message=f"Duplication Key Error for {obj}.\n" + f"The fields that should be unique are {unique_fields_str}." + ) + + permissions = self.get_ownership_permissions(uid, credentials) + if add_permissions is not None: + add_permission_strings = [p.permission_string for p in add_permissions] + permissions.extend(add_permission_strings) + + storage_permissions = [] + if add_storage_permission: + storage_permissions.append( + self.server_uid.no_dash, + ) + + fields = serialize_json(obj) + try: + # check if the fields are deserializable + # TODO: Ideally, we want to make sure we don't serialize what we cannot deserialize + # and remove this check. + deserialize_json(fields) + except Exception as e: + raise StashException( + f"Error serializing object: {e}. Some fields are invalid." + ) + + # create the object with the permissions + stmt = self.table.insert().values( + id=uid, + fields=fields, + permissions=permissions, + storage_permissions=storage_permissions, + ) + session.execute(stmt) + session.commit() + return self.get_by_uid(credentials, uid, session=session).unwrap() + + @as_result(ValidationError, AttributeError) + def apply_partial_update( + self, original_obj: StashT, update_obj: SyftObject + ) -> StashT: + for key, value in update_obj.__dict__.items(): + if value is Empty: + continue + + if key in original_obj.__dict__: + setattr(original_obj, key, value) + else: + raise AttributeError( + f"{type(update_obj).__name__}.{key} not found in {type(original_obj).__name__}" + ) + + # validate the new fields + self.object_type.model_validate(original_obj) + return original_obj + + @as_result(StashException, NotFoundException, AttributeError, ValidationError) + @with_session + def update( + self, + credentials: SyftVerifyKey, + obj: StashT, + has_permission: bool = False, + session: Session = None, + ) -> StashT: + """ + NOTE: We cannot do partial updates on the database, + because we are using computed fields that are not known to the DB: + - serialize_json will add computed fields to the JSON stored in the database + - If we update a single field in the JSON, the computed fields can get out of sync. + - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed. + """ + + if issubclass(type(obj), PartialSyftObject): + original_obj = self.get_by_uid( + credentials, obj.id, session=session + ).unwrap() + obj = self.apply_partial_update( + original_obj=original_obj, update_obj=obj + ).unwrap() + + # TODO has_permission is not used + if not self.is_unique(obj): + raise StashException(f"Some fields are not unique for {type(obj).__name__}") + + stmt = self.table.update().where(self._get_field_filter("id", obj.id)) + stmt = self._apply_permission_filter( + stmt, + credentials=credentials, + permission=ActionPermission.WRITE, + has_permission=has_permission, + session=session, + ) + fields = serialize_json(obj) + try: + deserialize_json(fields) + except Exception as e: + raise StashException( + f"Error serializing object: {e}. Some fields are invalid." + ) + stmt = stmt.values(fields=fields) + + result = session.execute(stmt) + session.commit() + if result.rowcount == 0: + raise NotFoundException( + f"{self.object_type.__name__}: {obj.id} not found or no permission to update." + ) + return self.get_by_uid(credentials, obj.id).unwrap() + + @as_result(StashException, NotFoundException) + @with_session + def delete_by_uid( + self, + credentials: SyftVerifyKey, + uid: UID, + has_permission: bool = False, + session: Session = None, + ) -> UID: + stmt = self.table.delete().where(self._get_field_filter("id", uid)) + stmt = self._apply_permission_filter( + stmt, + credentials=credentials, + permission=ActionPermission.WRITE, + has_permission=has_permission, + session=session, + ) + result = session.execute(stmt) + session.commit() + if result.rowcount == 0: + raise NotFoundException( + f"{self.object_type.__name__}: {uid} not found or no permission to delete." + ) + return uid + + @as_result(StashException) + @with_session + def get_one( + self, + credentials: SyftVerifyKey, + filters: dict[str, Any] | None = None, + has_permission: bool = False, + order_by: str | None = None, + sort_order: str | None = None, + offset: int = 0, + session: Session = None, + ) -> StashT: + """ + Get first objects from the stash, optionally filtered. + + Args: + credentials (SyftVerifyKey): credentials of the user + filters (dict[str, Any] | None, optional): dictionary of filters, + where the key is the field name and the value is the filter value. + Operators other than equals can be used in the key, + e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None. + has_permission (bool, optional): If True, overrides the permission check. + Defaults to False. + order_by (str | None, optional): If provided, the results will be ordered by this field. + If not provided, the default order and field defined on the SyftObject.__order_by__ are used. + Defaults to None. + sort_order (str | None, optional): "asc" or "desc" If not defined, + the default order defined on the SyftObject.__order_by__ is used. + Defaults to None. + offset (int, optional): offset the results. Defaults to 0. + + Returns: + list[StashT]: list of objects. + """ + query = self.query() + + if not has_permission: + role = self.get_role(credentials, session=session) + query = query.with_permissions(credentials, role) + + for field_name, operator, field_value in parse_filters(filters): + query = query.filter(field_name, operator, field_value) + + query = query.order_by(order_by, sort_order).offset(offset).limit(1) + result = query.execute(session).first() + if result is None: + raise NotFoundException(f"{self.object_type.__name__}: not found") + + return self.row_as_obj(result) + + @as_result(StashException) + @with_session + def get_all( + self, + credentials: SyftVerifyKey, + filters: dict[str, Any] | None = None, + has_permission: bool = False, + order_by: str | None = None, + sort_order: str | None = None, + limit: int | None = None, + offset: int = 0, + session: Session = None, + ) -> list[StashT]: + """ + Get all objects from the stash, optionally filtered. + + Args: + credentials (SyftVerifyKey): credentials of the user + filters (dict[str, Any] | None, optional): dictionary of filters, + where the key is the field name and the value is the filter value. + Operators other than equals can be used in the key, + e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None. + has_permission (bool, optional): If True, overrides the permission check. + Defaults to False. + order_by (str | None, optional): If provided, the results will be ordered by this field. + If not provided, the default order and field defined on the SyftObject.__order_by__ are used. + Defaults to None. + sort_order (str | None, optional): "asc" or "desc" If not defined, + the default order defined on the SyftObject.__order_by__ is used. + Defaults to None. + limit (int | None, optional): limit the number of results. Defaults to None. + offset (int, optional): offset the results. Defaults to 0. + + Returns: + list[StashT]: list of objects. + """ + query = self.query() + + if not has_permission: + role = self.get_role(credentials, session=session) + query = query.with_permissions(credentials, role) + + for field_name, operator, field_value in parse_filters(filters): + query = query.filter(field_name, operator, field_value) + + query = query.order_by(order_by, sort_order).limit(limit).offset(offset) + result = query.execute(session).all() + return [self.row_as_obj(row) for row in result] + + # PERMISSIONS + def get_ownership_permissions( + self, uid: UID, credentials: SyftVerifyKey + ) -> list[str]: + return [ + ActionObjectOWNER(uid=uid, credentials=credentials).permission_string, + ActionObjectWRITE(uid=uid, credentials=credentials).permission_string, + ActionObjectREAD(uid=uid, credentials=credentials).permission_string, + ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string, + ] + + @as_result(NotFoundException) + @with_session + def add_permission( + self, + permission: ActionObjectPermission, + session: Session = None, + ignore_missing: bool = False, + ) -> None: + try: + existing_permissions = self._get_permissions_for_uid( + permission.uid, session=session + ).unwrap() + except NotFoundException: + if ignore_missing: + return None + raise + + existing_permissions.add(permission.permission_string) + + stmt = self.table.update().where(self.table.c.id == permission.uid) + stmt = stmt.values(permissions=list(existing_permissions)) + session.execute(stmt) + session.commit() + + return None + + @as_result(NotFoundException) + @with_session + def add_permissions( + self, + permissions: list[ActionObjectPermission], + ignore_missing: bool = False, + session: Session = None, + ) -> None: + for permission in permissions: + self.add_permission( + permission, session=session, ignore_missing=ignore_missing + ).unwrap() + return None + + @with_session + def remove_permission( + self, permission: ActionObjectPermission, session: Session = None + ) -> None: + # TODO not threadsafe + try: + permissions = self._get_permissions_for_uid(permission.uid).unwrap() + permissions.remove(permission.permission_string) + except (NotFoundException, KeyError): + # TODO add error handling to permissions + return None + + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(permissions=list(permissions)) + ) + session.execute(stmt) + session.commit() + return None + + @with_session + def has_permission( + self, permission: ActionObjectPermission, session: Session = None + ) -> bool: + if self.get_role(permission.credentials, session=session) in ( + ServiceRole.ADMIN, + ServiceRole.DATA_OWNER, + ): + return True + return self.has_permissions([permission], session=session) + + @with_session + def has_permissions( + self, permissions: list[ActionObjectPermission], session: Session = None + ) -> bool: + # TODO: we should use a permissions table to check all permissions at once + + permission_filters = [ + sa.and_( + self._get_field_filter("id", p.uid), + self._get_permission_filter_from_permisson(permission=p), + ) + for p in permissions + ] + + stmt = self.table.select().where( + sa.and_( + *permission_filters, + ), + ) + result = session.execute(stmt).first() + return result is not None + + @as_result(StashException) + @with_session + def _get_permissions_for_uid(self, uid: UID, session: Session = None) -> Set[str]: # noqa: UP006 + stmt = select(self.table.c.permissions).where(self.table.c.id == uid) + result = session.execute(stmt).scalar_one_or_none() + if result is None: + raise NotFoundException(f"No permissions found for uid: {uid}") + return set(result) + + @as_result(StashException) + @with_session + def get_all_permissions(self, session: Session = None) -> dict[UID, Set[str]]: # noqa: UP006 + stmt = select(self.table.c.id, self.table.c.permissions) + results = session.execute(stmt).all() + return {UID(row.id): set(row.permissions) for row in results} + + # STORAGE PERMISSIONS + @with_session + def has_storage_permission( + self, permission: StoragePermission, session: Session = None + ) -> bool: + return self.has_storage_permissions([permission], session=session) + + @with_session + def has_storage_permissions( + self, permissions: list[StoragePermission], session: Session = None + ) -> bool: + permission_filters = [ + sa.and_( + self._get_field_filter("id", p.uid), + self.table.c.storage_permissions.contains( + p.server_uid.no_dash + if self._is_sqlite() + else [p.server_uid.no_dash] + ), + ) + for p in permissions + ] + + stmt = self.table.select().where( + sa.and_( + *permission_filters, + ) + ) + result = session.execute(stmt).first() + return result is not None + + @as_result(StashException) + @with_session + def get_all_storage_permissions( + self, session: Session = None + ) -> dict[UID, Set[UID]]: # noqa: UP006 + stmt = select(self.table.c.id, self.table.c.storage_permissions) + results = session.execute(stmt).all() + + return { + UID(row.id): {UID(uid) for uid in row.storage_permissions} + for row in results + } + + @as_result(NotFoundException) + @with_session + def add_storage_permissions( + self, + permissions: list[StoragePermission], + session: Session = None, + ignore_missing: bool = False, + ) -> None: + for permission in permissions: + self.add_storage_permission( + permission, session=session, ignore_missing=ignore_missing + ).unwrap() + + return None + + @as_result(NotFoundException) + @with_session + def add_storage_permission( + self, + permission: StoragePermission, + session: Session = None, + ignore_missing: bool = False, + ) -> None: + try: + existing_permissions = self._get_storage_permissions_for_uid( + permission.uid, session=session + ).unwrap() + except NotFoundException: + if ignore_missing: + return None + raise + + existing_permissions.add(permission.server_uid) + + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(storage_permissions=[str(uid) for uid in existing_permissions]) + ) + + session.execute(stmt) + + @with_session + def remove_storage_permission( + self, permission: StoragePermission, session: Session = None + ) -> None: + try: + permissions = self._get_storage_permissions_for_uid( + permission.uid, session=session + ).unwrap() + permissions.discard(permission.server_uid) + except NotFoundException: + # TODO add error handling to permissions + return None + + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(storage_permissions=[str(uid) for uid in permissions]) + ) + session.execute(stmt) + session.commit() + return None + + @as_result(StashException) + @with_session + def _get_storage_permissions_for_uid( + self, uid: UID, session: Session = None + ) -> Set[UID]: # noqa: UP006 + stmt = select(self.table.c.id, self.table.c.storage_permissions).where( + self.table.c.id == uid + ) + result = session.execute(stmt).first() + if result is None: + raise NotFoundException(f"No storage permissions found for uid: {uid}") + return {UID(uid) for uid in result.storage_permissions} diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py deleted file mode 100644 index ca0f3e1f33a..00000000000 --- a/packages/syft/src/syft/store/dict_document_store.py +++ /dev/null @@ -1,106 +0,0 @@ -# future -from __future__ import annotations - -# stdlib -from typing import Any - -# third party -from pydantic import Field - -# relative -from ..serde.serializable import serializable -from ..server.credentials import SyftVerifyKey -from ..types import uid -from .document_store import DocumentStore -from .document_store import StoreConfig -from .kv_document_store import KeyValueBackingStore -from .kv_document_store import KeyValueStorePartition -from .locks import LockingConfig -from .locks import ThreadingLockingConfig - - -@serializable(canonical_name="DictBackingStore", version=1) -class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] - # TODO: fix the mypy issue - """Dictionary-based Store core logic""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._ddtype = kwargs.get("ddtype", None) - - def __getitem__(self, key: Any) -> Any: - try: - value = super().__getitem__(key) - return value - except KeyError as e: - if self._ddtype: - return self._ddtype() - raise e - - -@serializable(canonical_name="DictStorePartition", version=1) -class DictStorePartition(KeyValueStorePartition): - """Dictionary-based StorePartition - - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: DictStoreConfig - DictStore specific configuration - """ - - def prune(self) -> None: - self.init_store().unwrap() - - -# the base document store is already a dict but we can change it later -@serializable(canonical_name="DictDocumentStore", version=1) -class DictDocumentStore(DocumentStore): - """Dictionary-based Document Store - - Parameters: - `store_config`: DictStoreConfig - Dictionary Store specific configuration, containing the store type and the backing store type - """ - - partition_type = DictStorePartition - - def __init__( - self, - server_uid: uid, - root_verify_key: SyftVerifyKey | None, - store_config: DictStoreConfig | None = None, - ) -> None: - if store_config is None: - store_config = DictStoreConfig() - super().__init__( - server_uid=server_uid, - root_verify_key=root_verify_key, - store_config=store_config, - ) - - def reset(self) -> None: - for partition in self.partitions.values(): - partition.prune() - - -@serializable() -class DictStoreConfig(StoreConfig): - __canonical_name__ = "DictStoreConfig" - """Dictionary-based configuration - - Parameters: - `store_type`: Type[DocumentStore] - The Document type used. Default: DictDocumentStore - `backing_store`: Type[KeyValueBackingStore] - The backend type used. Default: DictBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to ThreadingLockingConfig. - """ - - store_type: type[DocumentStore] = DictDocumentStore - backing_store: type[KeyValueBackingStore] = DictBackingStore - locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig) diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index cc97802a08b..98cccb82568 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -188,16 +188,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: def as_dict(self) -> dict[str, Any]: return {self.key: self.value} - @property - def as_dict_mongo(self) -> dict[str, Any]: - key = self.key - if key == "id": - key = "_id" - if self.type_list: - # We want to search inside the list of values - return {key: {"$in": self.value}} - return {key: self.value} - @serializable(canonical_name="PartitionKeysWithUID", version=1) class PartitionKeysWithUID(PartitionKeys): @@ -273,21 +263,6 @@ def as_dict(self) -> dict: qk_dict[qk_key] = qk_value return qk_dict - @property - def as_dict_mongo(self) -> dict: - qk_dict = {} - for qk in self.all: - qk_key = qk.key - qk_value = qk.value - if qk_key == "id": - qk_key = "_id" - if qk.type_list: - # We want to search inside the list of values - qk_dict[qk_key] = {"$in": qk_value} - else: - qk_dict[qk_key] = qk_value - return qk_dict - UIDPartitionKey = PartitionKey(key="id", type_=UID) @@ -627,30 +602,8 @@ def __init__( def __has_admin_permissions( self, settings: PartitionSettings ) -> Callable[[SyftVerifyKey], bool]: - # relative - from ..service.user.user import User - from ..service.user.user_roles import ServiceRole - from ..service.user.user_stash import UserStash - - # leave out UserStash to avoid recursion - # TODO: pass the callback from BaseStash instead of DocumentStore - # so that this works with UserStash after the sqlite thread fix is merged - if settings.object_type is User: - return lambda credentials: False - - user_stash = UserStash(store=self) - def has_admin_permissions(credentials: SyftVerifyKey) -> bool: - res = user_stash.get_by_verify_key( - credentials=credentials, - verify_key=credentials, - ) - - return ( - res.is_ok() - and (user := res.ok()) is not None - and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - ) + return credentials == self.root_verify_key return has_admin_permissions @@ -702,7 +655,6 @@ class NewBaseStash: partition: StorePartition def __init__(self, store: DocumentStore) -> None: - self.store = store self.partition = store.partition(type(self).settings) @as_result(StashException) diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 5d9c29c9d9d..d3e40372842 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -76,7 +76,7 @@ def update_with_context( raise SyftException(public_message=f"context {context}'s server is None") service = context.server.get_service(self.service_type) if hasattr(service, "stash"): - result = service.stash.update(credentials, obj) + result = service.stash.update(credentials, obj).unwrap() else: raise SyftException( public_message=f"service {service} does not have a stash" diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py deleted file mode 100644 index 9767059a1cb..00000000000 --- a/packages/syft/src/syft/store/mongo_client.py +++ /dev/null @@ -1,275 +0,0 @@ -# stdlib -import logging -from threading import Lock -from typing import Any - -# third party -from pymongo.collection import Collection as MongoCollection -from pymongo.database import Database as MongoDatabase -from pymongo.errors import ConnectionFailure -from pymongo.mongo_client import MongoClient as PyMongoClient - -# relative -from ..serde.serializable import serializable -from ..types.errors import SyftException -from ..types.result import as_result -from ..util.telemetry import TRACING_ENABLED -from .document_store import PartitionSettings -from .document_store import StoreClientConfig -from .document_store import StoreConfig -from .mongo_codecs import SYFT_CODEC_OPTIONS - -if TRACING_ENABLED: - try: - # third party - from opentelemetry.instrumentation.pymongo import PymongoInstrumentor - - PymongoInstrumentor().instrument() - message = "> Added OTEL PymongoInstrumentor" - print(message) - logger = logging.getLogger(__name__) - logger.info(message) - except Exception: # nosec - pass - - -@serializable(canonical_name="MongoStoreClientConfig", version=1) -class MongoStoreClientConfig(StoreClientConfig): - """ - Paramaters: - `hostname`: optional string - hostname or IP address or Unix domain socket path of a single mongod or mongos - instance to connect to, or a mongodb URI, or a list of hostnames (but no more - than one mongodb URI). If `host` is an IPv6 literal it must be enclosed in '[' - and ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for localhost). - Multihomed and round robin DNS addresses are **not** supported. - `port` : optional int - port number on which to connect - `directConnection`: bool - if ``True``, forces this client to connect directly to the specified MongoDB host - as a standalone. If ``false``, the client connects to the entire replica set of which - the given MongoDB host(s) is a part. If this is ``True`` and a mongodb+srv:// URI - or a URI containing multiple seeds is provided, an exception will be raised. - `maxPoolSize`: int. Default 100 - The maximum allowable number of concurrent connections to each connected server. - Requests to a server will block if there are `maxPoolSize` outstanding connections - to the requested server. Defaults to 100. Can be either 0 or None, in which case - there is no limit on the number of concurrent connections. - `minPoolSize` : int. Default 0 - The minimum required number of concurrent connections that the pool will maintain - to each connected server. Default is 0. - `maxIdleTimeMS`: int - The maximum number of milliseconds that a connection can remain idle in the pool - before being removed and replaced. Defaults to `None` (no limit). - `appname`: string - The name of the application that created this MongoClient instance. The server will - log this value upon establishing each connection. It is also recorded in the slow - query log and profile collections. - `maxConnecting`: optional int - The maximum number of connections that each pool can establish concurrently. - Defaults to `2`. - `timeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait when executing an operation - (including retry attempts) before raising a timeout error. ``0`` or ``None`` means - no timeout. - `socketTimeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait for a response after sending - an ordinary (non-monitoring) database operation before concluding that a network error - has occurred. ``0`` or ``None`` means no timeout. Defaults to ``None`` (no timeout). - `connectTimeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait during server monitoring when - connecting a new socket to a server before concluding the server is unavailable. - ``0`` or ``None`` means no timeout. Defaults to ``20000`` (20 seconds). - `serverSelectionTimeoutMS`: (integer) - Controls how long (in milliseconds) the driver will wait to find an available, appropriate - server to carry out a database operation; while it is waiting, multiple server monitoring - operations may be carried out, each controlled by `connectTimeoutMS`. - Defaults to ``120000`` (120 seconds). - `waitQueueTimeoutMS`: (integer or None) - How long (in milliseconds) a thread will wait for a socket from the pool if the pool - has no free sockets. Defaults to ``None`` (no timeout). - `heartbeatFrequencyMS`: (optional) - The number of milliseconds between periodic server checks, or None to accept the default - frequency of 10 seconds. - # Auth - username: str - Database username - password: str - Database pass - authSource: str - The database to authenticate on. - Defaults to the database specified in the URI, if provided, or to “admin”. - tls: bool - If True, create the connection to the server using transport layer security. - Defaults to False. - # Testing and connection reuse - client: Optional[PyMongoClient] - If provided, this client is reused. Default = None - - """ - - # Connection - hostname: str | None = "127.0.0.1" - port: int | None = None - directConnection: bool = False - maxPoolSize: int = 200 - minPoolSize: int = 0 - maxIdleTimeMS: int | None = None - maxConnecting: int = 3 - timeoutMS: int = 0 - socketTimeoutMS: int = 0 - connectTimeoutMS: int = 20000 - serverSelectionTimeoutMS: int = 120000 - waitQueueTimeoutMS: int | None = None - heartbeatFrequencyMS: int = 10000 - appname: str = "pysyft" - # Auth - username: str | None = None - password: str | None = None - authSource: str = "admin" - tls: bool | None = False - # Testing and connection reuse - client: Any = None - - # this allows us to have one connection per `Server` object - # in the MongoClientCache - server_obj_python_id: int | None = None - - -class MongoClientCache: - __client_cache__: dict[int, type["MongoClient"] | None] = {} - _lock: Lock = Lock() - - @classmethod - def from_cache(cls, config: MongoStoreClientConfig) -> PyMongoClient | None: - return cls.__client_cache__.get(hash(str(config)), None) - - @classmethod - def set_cache(cls, config: MongoStoreClientConfig, client: PyMongoClient) -> None: - with cls._lock: - cls.__client_cache__[hash(str(config))] = client - - -class MongoClient: - client: PyMongoClient = None - - def __init__(self, config: MongoStoreClientConfig, cache: bool = True) -> None: - self.config = config - if config.client is not None: - self.client = config.client - elif cache: - self.client = MongoClientCache.from_cache(config=config) - - if not cache or self.client is None: - self.connect(config=config).unwrap() - - @as_result(SyftException) - def connect(self, config: MongoStoreClientConfig) -> bool: - self.client = PyMongoClient( - # Connection - host=config.hostname, - port=config.port, - directConnection=config.directConnection, - maxPoolSize=config.maxPoolSize, - minPoolSize=config.minPoolSize, - maxIdleTimeMS=config.maxIdleTimeMS, - maxConnecting=config.maxConnecting, - timeoutMS=config.timeoutMS, - socketTimeoutMS=config.socketTimeoutMS, - connectTimeoutMS=config.connectTimeoutMS, - serverSelectionTimeoutMS=config.serverSelectionTimeoutMS, - waitQueueTimeoutMS=config.waitQueueTimeoutMS, - heartbeatFrequencyMS=config.heartbeatFrequencyMS, - appname=config.appname, - # Auth - username=config.username, - password=config.password, - authSource=config.authSource, - tls=config.tls, - uuidRepresentation="standard", - ) - MongoClientCache.set_cache(config=config, client=self.client) - try: - # Check if mongo connection is still up - self.client.admin.command("ping") - except ConnectionFailure as e: - self.client = None - raise SyftException.from_exception(e) - - return True - - @as_result(SyftException) - def with_db(self, db_name: str) -> MongoDatabase: - try: - return self.client[db_name] - except BaseException as e: - raise SyftException.from_exception(e) - - @as_result(SyftException) - def with_collection( - self, - collection_settings: PartitionSettings, - store_config: StoreConfig, - collection_name: str | None = None, - ) -> MongoCollection: - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_name = ( - collection_name - if collection_name is not None - else collection_settings.name - ) - collection = db.get_collection( - name=collection_name, codec_options=SYFT_CODEC_OPTIONS - ) - except BaseException as e: - raise SyftException.from_exception(e) - - return collection - - @as_result(SyftException) - def with_collection_permissions( - self, collection_settings: PartitionSettings, store_config: StoreConfig - ) -> MongoCollection: - """ - For each collection, create a corresponding collection - that store the permissions to the data in that collection - """ - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_permissions_name: str = collection_settings.name + "_permissions" - collection_permissions = db.get_collection( - name=collection_permissions_name, codec_options=SYFT_CODEC_OPTIONS - ) - except BaseException as e: - raise SyftException.from_exception(e) - return collection_permissions - - @as_result(SyftException) - def with_collection_storage_permissions( - self, collection_settings: PartitionSettings, store_config: StoreConfig - ) -> MongoCollection: - """ - For each collection, create a corresponding collection - that store the permissions to the data in that collection - """ - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_storage_permissions_name: str = ( - collection_settings.name + "_storage_permissions" - ) - storage_permissons_collection = db.get_collection( - name=collection_storage_permissions_name, - codec_options=SYFT_CODEC_OPTIONS, - ) - except BaseException as e: - raise SyftException.from_exception(e) - - return storage_permissons_collection - - def close(self) -> None: - self.client.close() - MongoClientCache.__client_cache__.pop(hash(str(self.config)), None) diff --git a/packages/syft/src/syft/store/mongo_codecs.py b/packages/syft/src/syft/store/mongo_codecs.py deleted file mode 100644 index 08b7fa63562..00000000000 --- a/packages/syft/src/syft/store/mongo_codecs.py +++ /dev/null @@ -1,31 +0,0 @@ -# stdlib -from typing import Any - -# third party -from bson import CodecOptions -from bson.binary import Binary -from bson.binary import USER_DEFINED_SUBTYPE -from bson.codec_options import TypeDecoder -from bson.codec_options import TypeRegistry - -# relative -from ..serde.deserialize import _deserialize -from ..serde.serialize import _serialize - - -def fallback_syft_encoder(value: object) -> Binary: - return Binary(_serialize(value, to_bytes=True), USER_DEFINED_SUBTYPE) - - -class SyftMongoBinaryDecoder(TypeDecoder): - bson_type = Binary - - def transform_bson(self, value: Any) -> Any: - if value.subtype == USER_DEFINED_SUBTYPE: - return _deserialize(value, from_bytes=True) - return value - - -syft_codecs = [SyftMongoBinaryDecoder()] -syft_type_registry = TypeRegistry(syft_codecs, fallback_encoder=fallback_syft_encoder) -SYFT_CODEC_OPTIONS = CodecOptions(type_registry=syft_type_registry) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py deleted file mode 100644 index 805cc042cdf..00000000000 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ /dev/null @@ -1,963 +0,0 @@ -# stdlib -from collections.abc import Callable -from typing import Any -from typing import Set # noqa: UP035 - -# third party -from pydantic import Field -from pymongo import ASCENDING -from pymongo.collection import Collection as MongoCollection -from typing_extensions import Self - -# relative -from ..serde.deserialize import _deserialize -from ..serde.serializable import serializable -from ..serde.serialize import _serialize -from ..server.credentials import SyftVerifyKey -from ..service.action.action_permissions import ActionObjectEXECUTE -from ..service.action.action_permissions import ActionObjectOWNER -from ..service.action.action_permissions import ActionObjectPermission -from ..service.action.action_permissions import ActionObjectREAD -from ..service.action.action_permissions import ActionObjectWRITE -from ..service.action.action_permissions import ActionPermission -from ..service.action.action_permissions import StoragePermission -from ..service.context import AuthedServiceContext -from ..service.response import SyftSuccess -from ..types.errors import SyftException -from ..types.result import as_result -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import StorableObjectType -from ..types.syft_object import SyftBaseObject -from ..types.syft_object import SyftObject -from ..types.transforms import TransformContext -from ..types.transforms import transform -from ..types.transforms import transform_method -from ..types.uid import UID -from .document_store import DocumentStore -from .document_store import PartitionKey -from .document_store import PartitionSettings -from .document_store import QueryKey -from .document_store import QueryKeys -from .document_store import StoreConfig -from .document_store import StorePartition -from .document_store_errors import NotFoundException -from .kv_document_store import KeyValueBackingStore -from .locks import LockingConfig -from .locks import NoLockingConfig -from .mongo_client import MongoClient -from .mongo_client import MongoStoreClientConfig - - -@serializable() -class MongoDict(SyftBaseObject): - __canonical_name__ = "MongoDict" - __version__ = SYFT_OBJECT_VERSION_1 - - keys: list[Any] - values: list[Any] - - @property - def dict(self) -> dict[Any, Any]: - return dict(zip(self.keys, self.values)) - - @classmethod - def from_dict(cls, input: dict) -> Self: - return cls(keys=list(input.keys()), values=list(input.values())) - - def __repr__(self) -> str: - return self.dict.__repr__() - - -class MongoBsonObject(StorableObjectType, dict): - pass - - -def _repr_debug_(value: Any) -> str: - if hasattr(value, "_repr_debug_"): - return value._repr_debug_() - return repr(value) - - -def to_mongo(context: TransformContext) -> TransformContext: - output = {} - if context.obj: - unique_keys_dict = context.obj._syft_unique_keys_dict() - search_keys_dict = context.obj._syft_searchable_keys_dict() - all_dict = unique_keys_dict - all_dict.update(search_keys_dict) - for k in all_dict: - value = getattr(context.obj, k, "") - # if the value is a method, store its value - if callable(value): - output[k] = value() - else: - output[k] = value - - output["__canonical_name__"] = context.obj.__canonical_name__ - output["__version__"] = context.obj.__version__ - output["__blob__"] = _serialize(context.obj, to_bytes=True) - output["__arepr__"] = _repr_debug_(context.obj) # a comes first in alphabet - - if context.output and "id" in context.output: - output["_id"] = context.output["id"] - - context.output = output - - return context - - -@transform(SyftObject, MongoBsonObject) -def syft_obj_to_mongo() -> list[Callable]: - return [to_mongo] - - -@transform_method(MongoBsonObject, SyftObject) -def from_mongo( - storage_obj: dict, context: TransformContext | None = None -) -> SyftObject: - return _deserialize(storage_obj["__blob__"], from_bytes=True) - - -@serializable(attrs=["storage_type"], canonical_name="MongoStorePartition", version=1) -class MongoStorePartition(StorePartition): - """Mongo StorePartition - - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for partitioning and indexing. - `store_config`: MongoStoreConfig - Mongo specific configuration - """ - - storage_type: type[StorableObjectType] = MongoBsonObject - - @as_result(SyftException) - def init_store(self) -> bool: - super().init_store().unwrap() - client = MongoClient(config=self.store_config.client_config) - self._collection = client.with_collection( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - self._permissions = client.with_collection_permissions( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - self._storage_permissions = client.with_collection_storage_permissions( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - return self._create_update_index().unwrap() - - # Potentially thread-unsafe methods. - # - # CAUTION: - # * Don't use self.lock here. - # * Do not call the public thread-safe methods here(with locking). - # These methods are called from the public thread-safe API, and will hang the process. - - @as_result(SyftException) - def _create_update_index(self) -> bool: - """Create or update mongo database indexes""" - collection: MongoCollection = self.collection.unwrap() - - def check_index_keys( - current_keys: list[tuple[str, int]], new_index_keys: list[tuple[str, int]] - ) -> bool: - current_keys.sort() - new_index_keys.sort() - return current_keys == new_index_keys - - syft_obj = self.settings.object_type - - unique_attrs = getattr(syft_obj, "__attr_unique__", []) - object_name = syft_obj.__canonical_name__ - - new_index_keys = [(attr, ASCENDING) for attr in unique_attrs] - - try: - current_indexes = collection.index_information() - except BaseException as e: - raise SyftException.from_exception(e) - index_name = f"{object_name}_index_name" - - current_index_keys = current_indexes.get(index_name, None) - - if current_index_keys is not None: - keys_same = check_index_keys(current_index_keys["key"], new_index_keys) - if keys_same: - return True - - # Drop current index, since incompatible with current object - try: - collection.drop_index(index_or_name=index_name) - except Exception: - raise SyftException( - public_message=( - f"Failed to drop index for object: {object_name}" - f" with index keys: {current_index_keys}" - ) - ) - - # If no new indexes, then skip index creation - if len(new_index_keys) == 0: - return True - - try: - collection.create_index(new_index_keys, unique=True, name=index_name) - except Exception: - raise SyftException( - public_message=f"Failed to create index for {object_name} with index keys: {new_index_keys}" - ) - - return True - - @property - @as_result(SyftException) - def collection(self) -> MongoCollection: - if not hasattr(self, "_collection"): - self.init_store().unwrap() - return self._collection - - @property - @as_result(SyftException) - def permissions(self) -> MongoCollection: - if not hasattr(self, "_permissions"): - self.init_store().unwrap() - return self._permissions - - @property - @as_result(SyftException) - def storage_permissions(self) -> MongoCollection: - if not hasattr(self, "_storage_permissions"): - self.init_store().unwrap() - return self._storage_permissions - - @as_result(SyftException) - def set(self, *args: Any, **kwargs: Any) -> SyftObject: - return self._set(*args, **kwargs).unwrap() - - @as_result(SyftException) - def _set( - self, - credentials: SyftVerifyKey, - obj: SyftObject, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> SyftObject: - # TODO: Refactor this function since now it's doing both set and - # update at the same time - write_permission = ActionObjectWRITE(uid=obj.id, credentials=credentials) - can_write: bool = self.has_permission(write_permission) - - store_query_key: QueryKey = self.settings.store_key.with_obj(obj) - collection: MongoCollection = self.collection.unwrap() - - store_key_exists = ( - collection.find_one(store_query_key.as_dict_mongo) is not None - ) - if (not store_key_exists) and (not self.item_keys_exist(obj, collection)): - # attempt to claim ownership for writing - can_write = self.take_ownership( - uid=obj.id, credentials=credentials - ).unwrap() - elif not ignore_duplicates: - unique_query_keys: QueryKeys = self.settings.unique_keys.with_obj(obj) - keys = ", ".join(f"`{key.key}`" for key in unique_query_keys.all) - raise SyftException( - public_message=f"Duplication Key Error for {obj}.\nThe fields that should be unique are {keys}." - ) - else: - # we are not throwing an error, because we are ignoring duplicates - # we are also not writing though - return obj - - if not can_write: - raise SyftException( - public_message=f"No permission to write object with id {obj.id}" - ) - - storage_obj = obj.to(self.storage_type) - - collection.insert_one(storage_obj) - - # adding permissions - read_permission = ActionObjectPermission( - uid=obj.id, - credentials=credentials, - permission=ActionPermission.READ, - ) - self.add_permission(read_permission) - - if add_permissions is not None: - self.add_permissions(add_permissions) - - if add_storage_permission: - self.add_storage_permission( - StoragePermission( - uid=obj.id, - server_uid=self.server_uid, - ) - ) - - return obj - - def item_keys_exist(self, obj: SyftObject, collection: MongoCollection) -> bool: - qks: QueryKeys = self.settings.unique_keys.with_obj(obj) - query = {"$or": [{k: v} for k, v in qks.as_dict_mongo.items()]} - res = collection.find_one(query) - return res is not None - - @as_result(SyftException) - def _update( - self, - credentials: SyftVerifyKey, - qk: QueryKey, - obj: SyftObject, - has_permission: bool = False, - overwrite: bool = False, - allow_missing_keys: bool = False, - ) -> SyftObject: - collection: MongoCollection = self.collection.unwrap() - - # TODO: optimize the update. The ID should not be overwritten, - # but the qk doesn't necessarily have to include the `id` field either. - - prev_obj = self._get_all_from_store(credentials, QueryKeys(qks=[qk])).unwrap() - if len(prev_obj) == 0: - raise SyftException( - public_message=f"Failed to update missing values for query key: {qk} for type {type(obj)}" - ) - - prev_obj = prev_obj[0] - if has_permission or self.has_permission( - ActionObjectWRITE(uid=prev_obj.id, credentials=credentials) - ): - for key, value in obj.to_dict(exclude_empty=True).items(): - # we don't want to overwrite Mongo's "id_" or Syft's "id" on update - if key == "id": - # protected field - continue - - # Overwrite the value if the key is already present - setattr(prev_obj, key, value) - - # Create the Mongo object - storage_obj = prev_obj.to(self.storage_type) - - try: - collection.update_one( - filter=qk.as_dict_mongo, update={"$set": storage_obj} - ) - except Exception: - raise SyftException(f"Failed to update obj: {obj} with qk: {qk}") - - return prev_obj - else: - raise SyftException(f"Failed to update obj {obj}, you have no permission") - - @as_result(SyftException) - def _find_index_or_search_keys( - self, - credentials: SyftVerifyKey, - index_qks: QueryKeys, - search_qks: QueryKeys, - order_by: PartitionKey | None = None, - ) -> list[SyftObject]: - # TODO: pass index as hint to find method - qks = QueryKeys(qks=(list(index_qks.all) + list(search_qks.all))) - return self._get_all_from_store( - credentials=credentials, qks=qks, order_by=order_by - ).unwrap() - - @property - def data(self) -> dict: - values: list = self._all(credentials=None, has_permission=True).unwrap() - return {v.id: v for v in values} - - @as_result(SyftException) - def _get( - self, - uid: UID, - credentials: SyftVerifyKey, - has_permission: bool | None = False, - ) -> SyftObject: - qks = QueryKeys.from_dict({"id": uid}) - res = self._get_all_from_store( - credentials, qks, order_by=None, has_permission=has_permission - ).unwrap() - if len(res) == 0: - raise NotFoundException - else: - return res[0] - - @as_result(SyftException) - def _get_all_from_store( - self, - credentials: SyftVerifyKey, - qks: QueryKeys, - order_by: PartitionKey | None = None, - has_permission: bool | None = False, - ) -> list[SyftObject]: - collection = self.collection.unwrap() - - if order_by is not None: - storage_objs = collection.find(filter=qks.as_dict_mongo).sort(order_by.key) - else: - _default_key = "_id" - storage_objs = collection.find(filter=qks.as_dict_mongo).sort(_default_key) - - syft_objs = [] - for storage_obj in storage_objs: - obj = self.storage_type(storage_obj) - transform_context = TransformContext(output={}, obj=obj) - - syft_obj = obj.to(self.settings.object_type, transform_context) - if has_permission or self.has_permission( - ActionObjectREAD(uid=syft_obj.id, credentials=credentials) - ): - syft_objs.append(syft_obj) - - return syft_objs - - @as_result(SyftException) - def _delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False - ) -> SyftSuccess: - if not ( - has_permission - or self.has_permission( - ActionObjectWRITE(uid=qk.value, credentials=credentials) - ) - ): - raise SyftException( - public_message=f"You don't have permission to delete object with qk: {qk}" - ) - - collection = self.collection.unwrap() - collection_permissions: MongoCollection = self.permissions.unwrap() - - qks = QueryKeys(qks=qk) - # delete the object - result = collection.delete_one(filter=qks.as_dict_mongo) - # delete the object's permission - result_permission = collection_permissions.delete_one(filter=qks.as_dict_mongo) - if result.deleted_count == 1 and result_permission.deleted_count == 1: - return SyftSuccess(message="Object and its permission are deleted") - elif result.deleted_count == 0: - raise SyftException(public_message=f"Failed to delete object with qk: {qk}") - else: - raise SyftException( - public_message=f"Object with qk: {qk} was deleted, but failed to delete its corresponding permission" - ) - - def has_permission(self, permission: ActionObjectPermission) -> bool: - """Check if the permission is inside the permission collection""" - collection_permissions_status = self.permissions - if collection_permissions_status.is_err(): - return False - collection_permissions: MongoCollection = collection_permissions_status.ok() - - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - - if permissions is None: - return False - - if ( - permission.credentials - and self.root_verify_key.verify == permission.credentials.verify - ): - return True - - if ( - permission.credentials - and self.has_admin_permissions is not None - and self.has_admin_permissions(permission.credentials) - ): - return True - - if permission.permission_string in permissions["permissions"]: - return True - - # check ALL_READ permission - if ( - permission.permission == ActionPermission.READ - and ActionObjectPermission( - permission.uid, ActionPermission.ALL_READ - ).permission_string - in permissions["permissions"] - ): - return True - - return False - - @as_result(SyftException) - def _get_permissions_for_uid(self, uid: UID) -> Set[str]: # noqa: UP006 - collection_permissions = self.permissions.unwrap() - permissions: dict | None = collection_permissions.find_one({"_id": uid}) - if permissions is None: - raise SyftException( - public_message=f"Permissions for object with UID {uid} not found!" - ) - return set(permissions["permissions"]) - - @as_result(SyftException) - def get_all_permissions(self) -> dict[UID, Set[str]]: # noqa: UP006 - # Returns a dictionary of all permissions {object_uid: {*permissions}} - collection_permissions: MongoCollection = self.permissions.unwrap() - permissions = collection_permissions.find({}) - permissions_dict = {} - for permission in permissions: - permissions_dict[permission["_id"]] = permission["permissions"] - return permissions_dict - - def add_permission(self, permission: ActionObjectPermission) -> None: - collection_permissions = self.permissions.unwrap() - - # find the permissions for the given permission.uid - # e.g. permissions = {"_id": "7b88fdef6bff42a8991d294c3d66f757", - # "permissions": set(["permission_str_1", "permission_str_2"]}} - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - if permissions is None: - # Permission doesn't exist, add a new one - collection_permissions.insert_one( - { - "_id": permission.uid, - "permissions": {permission.permission_string}, - } - ) - else: - # update the permissions with the new permission string - permission_strings: set = permissions["permissions"] - permission_strings.add(permission.permission_string) - collection_permissions.update_one( - {"_id": permission.uid}, {"$set": {"permissions": permission_strings}} - ) - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - for permission in permissions: - self.add_permission(permission) - - def remove_permission(self, permission: ActionObjectPermission) -> None: - collection_permissions = self.permissions.unwrap() - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - if permissions is None: - raise SyftException( - public_message=f"permission with UID {permission.uid} not found!" - ) - permissions_strings: set = permissions["permissions"] - if permission.permission_string in permissions_strings: - permissions_strings.remove(permission.permission_string) - if len(permissions_strings) > 0: - collection_permissions.update_one( - {"_id": permission.uid}, - {"$set": {"permissions": permissions_strings}}, - ) - else: - collection_permissions.delete_one({"_id": permission.uid}) - else: - raise SyftException( - public_message=f"the permission {permission.permission_string} does not exist!" - ) - - def add_storage_permission(self, storage_permission: StoragePermission) -> None: - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": storage_permission.uid} - ) - if storage_permissions is None: - # Permission doesn't exist, add a new one - storage_permissions_collection.insert_one( - { - "_id": storage_permission.uid, - "server_uids": {storage_permission.server_uid}, - } - ) - else: - # update the permissions with the new permission string - server_uids: set = storage_permissions["server_uids"] - server_uids.add(storage_permission.server_uid) - storage_permissions_collection.update_one( - {"_id": storage_permission.uid}, - {"$set": {"server_uids": server_uids}}, - ) - - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: - for permission in permissions: - self.add_storage_permission(permission) - - def has_storage_permission(self, permission: StoragePermission) -> bool: # type: ignore - """Check if the storage_permission is inside the storage_permission collection""" - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": permission.uid} - ) - if storage_permissions is None or "server_uids" not in storage_permissions: - return False - return permission.server_uid in storage_permissions["server_uids"] - - def remove_storage_permission(self, storage_permission: StoragePermission) -> None: - storage_permissions_collection = self.storage_permissions.unwrap() - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": storage_permission.uid} - ) - if storage_permissions is None: - raise SyftException( - public_message=f"storage permission with UID {storage_permission.uid} not found!" - ) - server_uids: set = storage_permissions["server_uids"] - if storage_permission.server_uid in server_uids: - server_uids.remove(storage_permission.server_uid) - storage_permissions_collection.update_one( - {"_id": storage_permission.uid}, - {"$set": {"server_uids": server_uids}}, - ) - else: - raise SyftException( - public_message=( - f"the server_uid {storage_permission.server_uid} does not exist in the storage permission!" - ) - ) - - @as_result(SyftException) - def _get_storage_permissions_for_uid(self, uid: UID) -> Set[UID]: # noqa: UP006 - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": uid} - ) - if storage_permissions is None: - raise SyftException( - public_message=f"Storage permissions for object with UID {uid} not found!" - ) - return set(storage_permissions["server_uids"]) - - @as_result(SyftException) - def get_all_storage_permissions( - self, - ) -> dict[UID, Set[UID]]: # noqa: UP006 - # Returns a dictionary of all storage permissions {object_uid: {*server_uids}} - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions = storage_permissions_collection.find({}) - storage_permissions_dict = {} - for storage_permission in storage_permissions: - storage_permissions_dict[storage_permission["_id"]] = storage_permission[ - "server_uids" - ] - return storage_permissions_dict - - @as_result(SyftException) - def take_ownership(self, uid: UID, credentials: SyftVerifyKey) -> bool: - collection_permissions: MongoCollection = self.permissions.unwrap() - collection: MongoCollection = self.collection.unwrap() - data: list[UID] | None = collection.find_one({"_id": uid}) - permissions: list[UID] | None = collection_permissions.find_one({"_id": uid}) - - if permissions is not None or data is not None: - raise SyftException(public_message=f"UID: {uid} already owned.") - - # first person using this UID can claim ownership - self.add_permissions( - [ - ActionObjectOWNER(uid=uid, credentials=credentials), - ActionObjectWRITE(uid=uid, credentials=credentials), - ActionObjectREAD(uid=uid, credentials=credentials), - ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] - ) - - return True - - @as_result(SyftException) - def _all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool | None = False, - ) -> list[SyftObject]: - qks = QueryKeys(qks=()) - return self._get_all_from_store( - credentials=credentials, - qks=qks, - order_by=order_by, - has_permission=has_permission, - ).unwrap() - - def __len__(self) -> int: - collection_status = self.collection - if collection_status.is_err(): - return 0 - collection: MongoCollection = collection_status.ok() - return collection.count_documents(filter={}) - - @as_result(SyftException) - def _migrate_data( - self, to_klass: SyftObject, context: AuthedServiceContext, has_permission: bool - ) -> bool: - credentials = context.credentials - has_permission = (credentials == self.root_verify_key) or has_permission - collection: MongoCollection = self.collection.unwrap() - - if has_permission: - storage_objs = collection.find({}) - for storage_obj in storage_objs: - obj = self.storage_type(storage_obj) - transform_context = TransformContext(output={}, obj=obj) - value = obj.to(self.settings.object_type, transform_context) - key = obj.get("_id") - try: - migrated_value = value.migrate_to(to_klass.__version__, context) - except Exception: - raise SyftException( - public_message=f"Failed to migrate data to {to_klass} for qk: {key}" - ) - qk = self.settings.store_key.with_obj(key) - self._update( - credentials, - qk=qk, - obj=migrated_value, - has_permission=has_permission, - ).unwrap() - return True - raise SyftException( - public_message="You don't have permissions to migrate data." - ) - - -@serializable(canonical_name="MongoDocumentStore", version=1) -class MongoDocumentStore(DocumentStore): - """Mongo Document Store - - Parameters: - `store_config`: MongoStoreConfig - Mongo specific configuration, including connection configuration, database name, or client class type. - """ - - partition_type = MongoStorePartition - - -@serializable( - attrs=["index_name", "settings", "store_config"], - canonical_name="MongoBackingStore", - version=1, -) -class MongoBackingStore(KeyValueBackingStore): - """ - Core logic for the MongoDB key-value store - - Parameters: - `index_name`: str - Index name (can be either 'data' or 'permissions') - `settings`: PartitionSettings - Syft specific settings - `store_config`: StoreConfig - Connection Configuration - `ddtype`: Type - Optional and should be None - Used to make a consistent interface with SQLiteBackingStore - """ - - def __init__( - self, - index_name: str, - settings: PartitionSettings, - store_config: StoreConfig, - ddtype: type | None = None, - ) -> None: - self.index_name = index_name - self.settings = settings - self.store_config = store_config - self.client: MongoClient - self.ddtype = ddtype - self.init_client() - - @as_result(SyftException) - def init_client(self) -> None: - self.client = MongoClient(config=self.store_config.client_config) - self._collection: MongoCollection = self.client.with_collection( - collection_settings=self.settings, - store_config=self.store_config, - collection_name=f"{self.settings.name}_{self.index_name}", - ).unwrap() - - @property - @as_result(SyftException) - def collection(self) -> MongoCollection: - if not hasattr(self, "_collection"): - self.init_client().unwrap() - return self._collection - - def _exist(self, key: UID) -> bool: - collection: MongoCollection = self.collection.unwrap() - result: dict | None = collection.find_one({"_id": key}) - if result is not None: - return True - return False - - def _set(self, key: UID, value: Any) -> None: - if self._exist(key): - self._update(key, value) - else: - collection: MongoCollection = self.collection.unwrap() - try: - bson_data = { - "_id": key, - f"{key}": _serialize(value, to_bytes=True), - "_repr_debug_": _repr_debug_(value), - } - collection.insert_one(bson_data) - except Exception: - raise SyftException(public_message="Cannot insert data.") - - def _update(self, key: UID, value: Any) -> None: - collection: MongoCollection = self.collection.unwrap() - try: - collection.update_one( - {"_id": key}, - { - "$set": { - f"{key}": _serialize(value, to_bytes=True), - "_repr_debug_": _repr_debug_(value), - } - }, - ) - except Exception as e: - raise SyftException( - public_message=f"Failed to update obj: {key} with value: {value}. Error: {e}" - ) - - def __setitem__(self, key: Any, value: Any) -> None: - self._set(key, value) - - def _get(self, key: UID) -> Any: - collection: MongoCollection = self.collection.unwrap() - result: dict | None = collection.find_one({"_id": key}) - if result is not None: - return _deserialize(result[f"{key}"], from_bytes=True) - else: - raise KeyError(f"{key} does not exist") - - def __getitem__(self, key: Any) -> Self: - try: - return self._get(key) - except KeyError as e: - if self.ddtype is not None: - return self.ddtype() - raise e - - def _len(self) -> int: - collection: MongoCollection = self.collection.unwrap() - return collection.count_documents(filter={}) - - def __len__(self) -> int: - return self._len() - - def _delete(self, key: UID) -> SyftSuccess: - collection: MongoCollection = self.collection.unwrap() - result = collection.delete_one({"_id": key}) - if result.deleted_count != 1: - raise SyftException(public_message=f"{key} does not exist") - return SyftSuccess(message="Deleted") - - def __delitem__(self, key: str) -> None: - self._delete(key) - - def _delete_all(self) -> None: - collection: MongoCollection = self.collection.unwrap() - collection.delete_many({}) - - def clear(self) -> None: - self._delete_all() - - def _get_all(self) -> Any: - collection_status = self.collection - if collection_status.is_err(): - return collection_status - collection: MongoCollection = collection_status.ok() - result = collection.find() - keys, values = [], [] - for row in result: - keys.append(row["_id"]) - values.append(_deserialize(row[f"{row['_id']}"], from_bytes=True)) - return dict(zip(keys, values)) - - def keys(self) -> Any: - return self._get_all().keys() - - def values(self) -> Any: - return self._get_all().values() - - def items(self) -> Any: - return self._get_all().items() - - def pop(self, key: Any) -> Self: - value = self._get(key) - self._delete(key) - return value - - def __contains__(self, key: Any) -> bool: - return self._exist(key) - - def __iter__(self) -> Any: - return iter(self.keys()) - - def __repr__(self) -> str: - return repr(self._get_all()) - - def copy(self) -> Self: - # 🟡 TODO - raise NotImplementedError - - def update(self, *args: Any, **kwargs: Any) -> None: - """ - Inserts the specified items to the dictionary. - """ - # 🟡 TODO - raise NotImplementedError - - def __del__(self) -> None: - """ - Close the mongo client connection: - - Cleanup client resources and disconnect from MongoDB - - End all server sessions created by this client - - Close all sockets in the connection pools and stop the monitor threads - """ - self.client.close() - - -@serializable() -class MongoStoreConfig(StoreConfig): - __canonical_name__ = "MongoStoreConfig" - """Mongo Store configuration - - Parameters: - `client_config`: MongoStoreClientConfig - Mongo connection details: hostname, port, user, password etc. - `store_type`: Type[DocumentStore] - The type of the DocumentStore. Default: MongoDocumentStore - `db_name`: str - Database name - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to NoLockingConfig. - """ - - client_config: MongoStoreClientConfig - store_type: type[DocumentStore] = MongoDocumentStore - db_name: str = "app" - backing_store: type[KeyValueBackingStore] = MongoBackingStore - # TODO: should use a distributed lock, with RedisLockingConfig - locking_config: LockingConfig = Field(default_factory=NoLockingConfig) diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 2cbef952862..82d75d68e6b 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -7,6 +7,8 @@ import logging from pathlib import Path import sqlite3 +from sqlite3 import Connection +from sqlite3 import Cursor import tempfile from typing import Any @@ -40,8 +42,8 @@ # by its filename and optionally the thread that its running in # we keep track of each SQLiteBackingStore init in REF_COUNTS # when it hits 0 we can close the connection and release the file descriptor -SQLITE_CONNECTION_POOL_DB: dict[str, sqlite3.Connection] = {} -SQLITE_CONNECTION_POOL_CUR: dict[str, sqlite3.Cursor] = {} +SQLITE_CONNECTION_POOL_DB: dict[str, Connection] = {} +SQLITE_CONNECTION_POOL_CUR: dict[str, Cursor] = {} REF_COUNTS: dict[str, int] = defaultdict(int) @@ -114,6 +116,7 @@ def __init__( self.lock = SyftLock(NoLockingConfig()) self.create_table() REF_COUNTS[cache_key(self.db_filename)] += 1 + self.subs_char = r"?" @property def table_name(self) -> str: @@ -134,7 +137,7 @@ def _connect(self) -> None: connection = sqlite3.connect( self.file_path, timeout=self.store_config.client_config.timeout, - check_same_thread=False, # do we need this if we use the lock? + check_same_thread=False, # do we need this if we use the lock # check_same_thread=self.store_config.client_config.check_same_thread, ) # Set journal mode to WAL. @@ -159,13 +162,13 @@ def create_table(self) -> None: raise SyftException.from_exception(e, public_message=public_message) @property - def db(self) -> sqlite3.Connection: + def db(self) -> Connection: if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_DB: self._connect() return SQLITE_CONNECTION_POOL_DB[cache_key(self.db_filename)] @property - def cur(self) -> sqlite3.Cursor: + def cur(self) -> Cursor: if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_CUR: SQLITE_CONNECTION_POOL_CUR[cache_key(self.db_filename)] = self.db.cursor() @@ -191,49 +194,74 @@ def _close(self) -> None: def _commit(self) -> None: self.db.commit() + @staticmethod @as_result(SyftException) - def _execute(self, sql: str, *args: list[Any] | None) -> sqlite3.Cursor: - with self.lock: - cursor: sqlite3.Cursor | None = None - # err = None + def _execute( + lock: SyftLock, + cursor: Cursor, + db: Connection, + table_name: str, + sql: str, + args: list[Any] | None, + ) -> Cursor: + with lock: + cur: Cursor | None = None try: - cursor = self.cur.execute(sql, *args) + cur = cursor.execute(sql, args) except Exception as e: - public_message = special_exception_public_message(self.table_name, e) + public_message = special_exception_public_message(table_name, e) raise SyftException.from_exception(e, public_message=public_message) - # TODO: Which exception is safe to rollback on? + # TODO: Which exception is safe to rollback on # we should map out some more clear exceptions that can be returned # rather than halting the program like disk I/O error etc # self.db.rollback() # Roll back all changes if an exception occurs. # err = Err(str(e)) - self.db.commit() # Commit if everything went ok - return cursor + db.commit() # Commit if everything went ok + return cur def _set(self, key: UID, value: Any) -> None: if self._exists(key): self._update(key, value) else: insert_sql = ( - f"insert into {self.table_name} (uid, repr, value) VALUES (?, ?, ?)" # nosec + f"insert into {self.table_name} (uid, repr, value) VALUES " # nosec + f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec ) data = _serialize(value, to_bytes=True) - self._execute(insert_sql, [str(key), _repr_debug_(value), data]).unwrap() + self._execute( + self.lock, + self.cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data], + ).unwrap() def _update(self, key: UID, value: Any) -> None: insert_sql = ( - f"update {self.table_name} set uid = ?, repr = ?, value = ? where uid = ?" # nosec + f"update {self.table_name} set uid = {self.subs_char}, " # nosec + f"repr = {self.subs_char}, value = {self.subs_char} " # nosec + f"where uid = {self.subs_char}" # nosec ) data = _serialize(value, to_bytes=True) self._execute( - insert_sql, [str(key), _repr_debug_(value), data, str(key)] + self.lock, + self.cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data, str(key)], ).unwrap() def _get(self, key: UID) -> Any: - select_sql = f"select * from {self.table_name} where uid = ? order by sqltime" # nosec - cursor = self._execute(select_sql, [str(key)]).unwrap( - public_message=f"Query {select_sql} failed" + select_sql = ( + f"select * from {self.table_name} where uid = {self.subs_char} " # nosec + "order by sqltime" ) + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap(public_message=f"Query {select_sql} failed") row = cursor.fetchone() if row is None or len(row) == 0: raise KeyError(f"{key} not in {type(self)}") @@ -241,13 +269,10 @@ def _get(self, key: UID) -> Any: return _deserialize(data, from_bytes=True) def _exists(self, key: UID) -> bool: - select_sql = f"select uid from {self.table_name} where uid = ?" # nosec - - res = self._execute(select_sql, [str(key)]) - if res.is_err(): - return False - cursor = res.ok() - + select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() row = cursor.fetchone() # type: ignore if row is None: return False @@ -259,13 +284,11 @@ def _get_all(self) -> Any: keys = [] data = [] - res = self._execute(select_sql) - if res.is_err(): - return {} - cursor = res.ok() - + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() rows = cursor.fetchall() # type: ignore - if rows is None: + if not rows: return {} for row in rows: @@ -276,29 +299,33 @@ def _get_all(self) -> Any: def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # nosec - res = self._execute(select_sql) - if res.is_err(): - return [] - cursor = res.ok() - + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() rows = cursor.fetchall() # type: ignore - if rows is None: + if not rows: return [] keys = [UID(row[0]) for row in rows] return keys def _delete(self, key: UID) -> None: - select_sql = f"delete from {self.table_name} where uid = ?" # nosec - self._execute(select_sql, [str(key)]).unwrap() + select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec + self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() def _delete_all(self) -> None: select_sql = f"delete from {self.table_name}" # nosec - self._execute(select_sql).unwrap() + self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() def _len(self) -> int: select_sql = f"select count(uid) from {self.table_name}" # nosec - cursor = self._execute(select_sql).unwrap() + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() cnt = cursor.fetchone()[0] return cnt @@ -369,7 +396,7 @@ class SQLiteStorePartition(KeyValueStorePartition): def close(self) -> None: self.lock.acquire() try: - # I think we don't want these now, because of the REF_COUNT? + # I think we don't want these now, because of the REF_COUNT # self.data._close() # self.unique_keys._close() # self.searchable_keys._close() diff --git a/packages/syft/src/syft/types/datetime.py b/packages/syft/src/syft/types/datetime.py index c78dc04f5a2..93dd4ffc65e 100644 --- a/packages/syft/src/syft/types/datetime.py +++ b/packages/syft/src/syft/types/datetime.py @@ -67,6 +67,15 @@ def timedelta(self, other: "DateTime") -> timedelta: utc_timestamp_delta = self.utc_timestamp - other.utc_timestamp return timedelta(seconds=utc_timestamp_delta) + @classmethod + def from_timestamp(cls, ts: float) -> datetime: + return cls(utc_timestamp=ts) + + @classmethod + def from_datetime(cls, dt: datetime) -> "DateTime": + utc_datetime = dt.astimezone(timezone.utc) + return cls(utc_timestamp=utc_datetime.timestamp()) + def format_timedelta(local_timedelta: timedelta) -> str: total_seconds = int(local_timedelta.total_seconds()) diff --git a/packages/syft/src/syft/types/result.py b/packages/syft/src/syft/types/result.py index 52d392e48cc..bba6e80777b 100644 --- a/packages/syft/src/syft/types/result.py +++ b/packages/syft/src/syft/types/result.py @@ -114,7 +114,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, BE]: if isinstance(output, Ok) or isinstance(output, Err): raise _AsResultError( f"Functions decorated with `as_result` should not return Result.\n" - f"Did you forget to unwrap() the result?\n" + f"Did you forget to unwrap() the result in {func.__name__}?\n" f"result: {output}" ) return Ok(output) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 20dedae88c6..f6a4d3233cb 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -396,7 +396,6 @@ class SyftObject(SyftObjectVersioned): # all objects have a UID id: UID - created_date: BaseDateTime | None = None updated_date: BaseDateTime | None = None deleted_date: BaseDateTime | None = None @@ -411,6 +410,7 @@ def make_id(cls, values: Any) -> Any: values["id"] = id_field.annotation() return values + __order_by__: ClassVar[tuple[str, str]] = ("_created_at", "asc") __attr_searchable__: ClassVar[ list[str] ] = [] # keys which can be searched in the ORM diff --git a/packages/syft/src/syft/types/uid.py b/packages/syft/src/syft/types/uid.py index 96bc5af31b8..de364d7b10a 100644 --- a/packages/syft/src/syft/types/uid.py +++ b/packages/syft/src/syft/types/uid.py @@ -154,6 +154,10 @@ def is_valid_uuid(value: Any) -> bool: def no_dash(self) -> str: return str(self.value).replace("-", "") + @property + def hex(self) -> str: + return self.value.hex + def __repr__(self) -> str: """Returns a human-readable version of the ID diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 4e93e82c45a..538614b4cb8 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -24,7 +24,7 @@ def make_links(text: str) -> str: file_pattern = re.compile(r"([\w/.-]+\.py)\", line (\d+)") - return file_pattern.sub(r'\1, line \2', text) + return file_pattern.sub(r'\1, line \2', text) DEFAULT_ID_WIDTH = 110 diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index eca68d13b12..56506b43fad 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -7,6 +7,7 @@ import sys from tempfile import gettempdir from unittest import mock +from uuid import uuid4 # third party from faker import Faker @@ -22,24 +23,10 @@ from syft.protocol.data_protocol import protocol_release_dir from syft.protocol.data_protocol import stage_protocol_changes from syft.server.worker import Worker +from syft.service.queue.queue_stash import QueueStash from syft.service.user import user -# relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support -from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import dict_action_store -from .syft.stores.store_fixtures_test import dict_document_store -from .syft.stores.store_fixtures_test import dict_queue_stash -from .syft.stores.store_fixtures_test import dict_store_partition -from .syft.stores.store_fixtures_test import mongo_action_store -from .syft.stores.store_fixtures_test import mongo_document_store -from .syft.stores.store_fixtures_test import mongo_queue_stash -from .syft.stores.store_fixtures_test import mongo_store_partition -from .syft.stores.store_fixtures_test import sqlite_action_store -from .syft.stores.store_fixtures_test import sqlite_document_store -from .syft.stores.store_fixtures_test import sqlite_queue_stash -from .syft.stores.store_fixtures_test import sqlite_store_partition -from .syft.stores.store_fixtures_test import sqlite_workspace def patch_protocol_file(filepath: Path): @@ -129,7 +116,12 @@ def faker(): @pytest.fixture(scope="function") def worker() -> Worker: - worker = sy.Worker.named(name=token_hex(8)) + """ + NOTE in-memory sqlite is not shared between connections, so: + - using 2 workers (high/low) will not share a db + - re-using a connection (e.g. for a Job worker) will not share a db + """ + worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://") yield worker worker.cleanup() del worker @@ -138,7 +130,7 @@ def worker() -> Worker: @pytest.fixture(scope="function") def second_worker() -> Worker: # Used in server syncing tests - worker = sy.Worker.named(name=token_hex(8)) + worker = sy.Worker.named(name=uuid4().hex, db_url="sqlite://") yield worker worker.cleanup() del worker @@ -147,7 +139,7 @@ def second_worker() -> Worker: @pytest.fixture(scope="function") def high_worker() -> Worker: worker = sy.Worker.named( - name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE + name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE, db_url="sqlite://" ) yield worker worker.cleanup() @@ -157,7 +149,10 @@ def high_worker() -> Worker: @pytest.fixture(scope="function") def low_worker() -> Worker: worker = sy.Worker.named( - name=token_hex(8), server_side_type=ServerSideType.LOW_SIDE, dev_mode=True + name=token_hex(8), + server_side_type=ServerSideType.LOW_SIDE, + dev_mode=True, + db_url="sqlite://", ) yield worker worker.cleanup() @@ -212,8 +207,7 @@ def ds_verify_key(ds_client: DatasiteClient): @pytest.fixture def document_store(worker): - yield worker.document_store - worker.document_store.reset() + yield worker.db @pytest.fixture @@ -221,31 +215,6 @@ def action_store(worker): yield worker.action_store -@pytest.fixture(scope="session") -def mongo_client(testrun_uid): - """ - A race-free fixture that starts a MongoDB server for an entire pytest session. - Cleans up the server when the session ends, or when the last client disconnects. - """ - db_name = f"pytest_mongo_{testrun_uid}" - - # rand conn str - conn_str = f"mongodb://localhost:27017/{db_name}" - - # create a client, and test the connection - client = MongoClient(conn_str) - assert client.server_info().get("ok") == 1.0 - - yield client - - # stop_mongo_server(db_name) - - -@pytest.fixture(autouse=True) -def patched_mongo_client(monkeypatch): - monkeypatch.setattr("pymongo.mongo_client.MongoClient", MongoClient) - - @pytest.fixture(autouse=True) def patched_session_cache(monkeypatch): # patching compute heavy hashing to speed up tests @@ -307,21 +276,18 @@ def big_dataset() -> Dataset: yield dataset -__all__ = [ - "mongo_store_partition", - "mongo_document_store", - "mongo_queue_stash", - "mongo_action_store", - "sqlite_store_partition", - "sqlite_workspace", - "sqlite_document_store", - "sqlite_queue_stash", - "sqlite_action_store", - "dict_store_partition", - "dict_action_store", - "dict_document_store", - "dict_queue_stash", -] +@pytest.fixture( + scope="function", + params=[ + "tODOsqlite_address", + # "TODOpostgres_address", # will be used when we have a postgres CI tests + ], +) +def queue_stash(request): + _ = request.param + stash = QueueStash.random() + yield stash + pytest_plugins = [ "tests.syft.users.fixtures", diff --git a/packages/syft/tests/mongomock/__init__.py b/packages/syft/tests/mongomock/__init__.py deleted file mode 100644 index 6ce7670902b..00000000000 --- a/packages/syft/tests/mongomock/__init__.py +++ /dev/null @@ -1,138 +0,0 @@ -# stdlib -import os - -try: - # third party - from pymongo.errors import PyMongoError -except ImportError: - - class PyMongoError(Exception): - pass - - -try: - # third party - from pymongo.errors import OperationFailure -except ImportError: - - class OperationFailure(PyMongoError): - def __init__(self, message, code=None, details=None): - super(OperationFailure, self).__init__() - self._message = message - self._code = code - self._details = details - - code = property(lambda self: self._code) - details = property(lambda self: self._details) - - def __str__(self): - return self._message - - -try: - # third party - from pymongo.errors import WriteError -except ImportError: - - class WriteError(OperationFailure): - pass - - -try: - # third party - from pymongo.errors import DuplicateKeyError -except ImportError: - - class DuplicateKeyError(WriteError): - pass - - -try: - # third party - from pymongo.errors import BulkWriteError -except ImportError: - - class BulkWriteError(OperationFailure): - def __init__(self, results): - super(BulkWriteError, self).__init__( - "batch op errors occurred", 65, results - ) - - -try: - # third party - from pymongo.errors import CollectionInvalid -except ImportError: - - class CollectionInvalid(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidName -except ImportError: - - class InvalidName(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidOperation -except ImportError: - - class InvalidOperation(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import ConfigurationError -except ImportError: - - class ConfigurationError(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidURI -except ImportError: - - class InvalidURI(ConfigurationError): - pass - - -from .helpers import ObjectId, utcnow # noqa - - -__all__ = [ - "Database", - "DuplicateKeyError", - "Collection", - "CollectionInvalid", - "InvalidName", - "MongoClient", - "ObjectId", - "OperationFailure", - "WriteConcern", - "ignore_feature", - "patch", - "warn_on_feature", - "SERVER_VERSION", -] - -# relative -from .collection import Collection -from .database import Database -from .mongo_client import MongoClient -from .not_implemented import ignore_feature -from .not_implemented import warn_on_feature -from .patch import patch -from .write_concern import WriteConcern - -# The version of the server faked by mongomock. Callers may patch it before creating connections to -# update the behavior of mongomock. -# Keep the default version in sync with docker-compose.yml and travis.yml. -SERVER_VERSION = os.getenv("MONGODB", "5.0.5") diff --git a/packages/syft/tests/mongomock/__init__.pyi b/packages/syft/tests/mongomock/__init__.pyi deleted file mode 100644 index b7ba5e4c03c..00000000000 --- a/packages/syft/tests/mongomock/__init__.pyi +++ /dev/null @@ -1,30 +0,0 @@ -# stdlib -from typing import Any -from typing import Callable -from typing import Literal -from typing import Sequence -from typing import Tuple -from typing import Union -from unittest import mock - -# third party -from bson.objectid import ObjectId -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.database import Database -from pymongo.errors import CollectionInvalid -from pymongo.errors import DuplicateKeyError -from pymongo.errors import InvalidName -from pymongo.errors import OperationFailure - -def patch( - servers: Union[str, Tuple[str, int], Sequence[Union[str, Tuple[str, int]]]] = ..., - on_new: Literal["error", "create", "timeout", "pymongo"] = ..., -) -> mock._patch: ... - -_FeatureName = Literal["collation", "session"] - -def ignore_feature(feature: _FeatureName) -> None: ... -def warn_on_feature(feature: _FeatureName) -> None: ... - -SERVER_VERSION: str = ... diff --git a/packages/syft/tests/mongomock/__version__.py b/packages/syft/tests/mongomock/__version__.py deleted file mode 100644 index 14863a3db29..00000000000 --- a/packages/syft/tests/mongomock/__version__.py +++ /dev/null @@ -1,15 +0,0 @@ -# stdlib -from platform import python_version_tuple - -python_version = python_version_tuple() - -if (int(python_version[0]), int(python_version[1])) >= (3, 8): - # stdlib - from importlib.metadata import version - - __version__ = version("mongomock") -else: - # third party - import pkg_resources - - __version__ = pkg_resources.get_distribution("mongomock").version diff --git a/packages/syft/tests/mongomock/aggregate.py b/packages/syft/tests/mongomock/aggregate.py deleted file mode 100644 index 243720b690f..00000000000 --- a/packages/syft/tests/mongomock/aggregate.py +++ /dev/null @@ -1,1811 +0,0 @@ -"""Module to handle the operations within the aggregate pipeline.""" - -# stdlib -import bisect -import collections -import copy -import datetime -import decimal -import functools -import itertools -import math -import numbers -import random -import re -import sys -import warnings - -# third party -from packaging import version -import pytz - -# relative -from . import OperationFailure -from . import command_cursor -from . import filtering -from . import helpers - -try: - # third party - from bson import Regex - from bson import decimal128 - from bson.errors import InvalidDocument - - decimal_support = True - _RE_TYPES = (helpers.RE_TYPE, Regex) -except ImportError: - InvalidDocument = OperationFailure - decimal_support = False - _RE_TYPES = helpers.RE_TYPE - -_random = random.Random() - - -group_operators = [ - "$addToSet", - "$avg", - "$first", - "$last", - "$max", - "$mergeObjects", - "$min", - "$push", - "$stdDevPop", - "$stdDevSamp", - "$sum", -] -unary_arithmetic_operators = { - "$abs", - "$ceil", - "$exp", - "$floor", - "$ln", - "$log10", - "$sqrt", - "$trunc", -} -binary_arithmetic_operators = { - "$divide", - "$log", - "$mod", - "$pow", - "$subtract", -} -arithmetic_operators = ( - unary_arithmetic_operators - | binary_arithmetic_operators - | { - "$add", - "$multiply", - } -) -project_operators = [ - "$max", - "$min", - "$avg", - "$sum", - "$stdDevPop", - "$stdDevSamp", - "$arrayElemAt", - "$first", - "$last", -] -control_flow_operators = [ - "$switch", -] -projection_operators = [ - "$let", - "$literal", -] -date_operators = [ - "$dateFromString", - "$dateToString", - "$dateFromParts", - "$dayOfMonth", - "$dayOfWeek", - "$dayOfYear", - "$hour", - "$isoDayOfWeek", - "$isoWeek", - "$isoWeekYear", - "$millisecond", - "$minute", - "$month", - "$second", - "$week", - "$year", -] -conditional_operators = ["$cond", "$ifNull"] -array_operators = [ - "$concatArrays", - "$filter", - "$indexOfArray", - "$map", - "$range", - "$reduce", - "$reverseArray", - "$size", - "$slice", - "$zip", -] -object_operators = [ - "$mergeObjects", -] -text_search_operators = ["$meta"] -string_operators = [ - "$concat", - "$indexOfBytes", - "$indexOfCP", - "$regexMatch", - "$split", - "$strcasecmp", - "$strLenBytes", - "$strLenCP", - "$substr", - "$substrBytes", - "$substrCP", - "$toLower", - "$toUpper", - "$trim", -] -comparison_operators = [ - "$cmp", - "$eq", - "$ne", -] + list(filtering.SORTING_OPERATOR_MAP.keys()) -boolean_operators = ["$and", "$or", "$not"] -set_operators = [ - "$in", - "$setEquals", - "$setIntersection", - "$setDifference", - "$setUnion", - "$setIsSubset", - "$anyElementTrue", - "$allElementsTrue", -] - -type_convertion_operators = [ - "$convert", - "$toString", - "$toInt", - "$toDecimal", - "$toLong", - "$arrayToObject", - "$objectToArray", -] -type_operators = [ - "$isNumber", - "$isArray", -] - - -def _avg_operation(values): - values_list = list(v for v in values if isinstance(v, numbers.Number)) - if not values_list: - return None - return sum(values_list) / float(len(list(values_list))) - - -def _group_operation(values, operator): - values_list = list(v for v in values if v is not None) - if not values_list: - return None - return operator(values_list) - - -def _sum_operation(values): - values_list = list() - if decimal_support: - for v in values: - if isinstance(v, numbers.Number): - values_list.append(v) - elif isinstance(v, decimal128.Decimal128): - values_list.append(v.to_decimal()) - else: - values_list = list(v for v in values if isinstance(v, numbers.Number)) - sum_value = sum(values_list) - return ( - decimal128.Decimal128(sum_value) - if isinstance(sum_value, decimal.Decimal) - else sum_value - ) - - -def _merge_objects_operation(values): - merged_doc = dict() - for v in values: - if isinstance(v, dict): - merged_doc.update(v) - return merged_doc - - -_GROUPING_OPERATOR_MAP = { - "$sum": _sum_operation, - "$avg": _avg_operation, - "$mergeObjects": _merge_objects_operation, - "$min": lambda values: _group_operation(values, min), - "$max": lambda values: _group_operation(values, max), - "$first": lambda values: values[0] if values else None, - "$last": lambda values: values[-1] if values else None, -} - - -class _Parser(object): - """Helper to parse expressions within the aggregate pipeline.""" - - def __init__(self, doc_dict, user_vars=None, ignore_missing_keys=False): - self._doc_dict = doc_dict - self._ignore_missing_keys = ignore_missing_keys - self._user_vars = user_vars or {} - - def parse(self, expression): - """Parse a MongoDB expression.""" - if not isinstance(expression, dict): - # May raise a KeyError despite the ignore missing key. - return self._parse_basic_expression(expression) - - if len(expression) > 1 and any(key.startswith("$") for key in expression): - raise OperationFailure( - "an expression specification must contain exactly one field, " - "the name of the expression. Found %d fields in %s" - % (len(expression), expression) - ) - - value_dict = {} - for k, v in expression.items(): - if k in arithmetic_operators: - return self._handle_arithmetic_operator(k, v) - if k in project_operators: - return self._handle_project_operator(k, v) - if k in projection_operators: - return self._handle_projection_operator(k, v) - if k in comparison_operators: - return self._handle_comparison_operator(k, v) - if k in date_operators: - return self._handle_date_operator(k, v) - if k in array_operators: - return self._handle_array_operator(k, v) - if k in conditional_operators: - return self._handle_conditional_operator(k, v) - if k in control_flow_operators: - return self._handle_control_flow_operator(k, v) - if k in set_operators: - return self._handle_set_operator(k, v) - if k in string_operators: - return self._handle_string_operator(k, v) - if k in type_convertion_operators: - return self._handle_type_convertion_operator(k, v) - if k in type_operators: - return self._handle_type_operator(k, v) - if k in boolean_operators: - return self._handle_boolean_operator(k, v) - if k in text_search_operators + projection_operators + object_operators: - raise NotImplementedError( - "'%s' is a valid operation but it is not supported by Mongomock yet." - % k - ) - if k.startswith("$"): - raise OperationFailure("Unrecognized expression '%s'" % k) - try: - value = self.parse(v) - except KeyError: - if self._ignore_missing_keys: - continue - raise - value_dict[k] = value - - return value_dict - - def parse_many(self, values): - for value in values: - try: - yield self.parse(value) - except KeyError: - if self._ignore_missing_keys: - yield None - else: - raise - - def _parse_to_bool(self, expression): - """Parse a MongoDB expression and then convert it to bool""" - # handles converting `undefined` (in form of KeyError) to False - try: - return helpers.mongodb_to_bool(self.parse(expression)) - except KeyError: - return False - - def _parse_or_None(self, expression): - try: - return self.parse(expression) - except KeyError: - return None - - def _parse_basic_expression(self, expression): - if isinstance(expression, str) and expression.startswith("$"): - if expression.startswith("$$"): - return helpers.get_value_by_dot( - dict( - { - "ROOT": self._doc_dict, - "CURRENT": self._doc_dict, - }, - **self._user_vars, - ), - expression[2:], - can_generate_array=True, - ) - return helpers.get_value_by_dot( - self._doc_dict, expression[1:], can_generate_array=True - ) - return expression - - def _handle_boolean_operator(self, operator, values): - if operator == "$and": - return all([self._parse_to_bool(value) for value in values]) - if operator == "$or": - return any(self._parse_to_bool(value) for value in values) - if operator == "$not": - return not self._parse_to_bool(values) - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid boolean operator for the " - "aggregation pipeline, it is currently not implemented" - " in Mongomock." % operator - ) - - def _handle_arithmetic_operator(self, operator, values): - if operator in unary_arithmetic_operators: - try: - number = self.parse(values) - except KeyError: - return None - if number is None: - return None - if not isinstance(number, numbers.Number): - raise OperationFailure( - "Parameter to %s must evaluate to a number, got '%s'" - % (operator, type(number)) - ) - - if operator == "$abs": - return abs(number) - if operator == "$ceil": - return math.ceil(number) - if operator == "$exp": - return math.exp(number) - if operator == "$floor": - return math.floor(number) - if operator == "$ln": - return math.log(number) - if operator == "$log10": - return math.log10(number) - if operator == "$sqrt": - return math.sqrt(number) - if operator == "$trunc": - return math.trunc(number) - - if operator in binary_arithmetic_operators: - if not isinstance(values, (tuple, list)): - raise OperationFailure( - "Parameter to %s must evaluate to a list, got '%s'" - % (operator, type(values)) - ) - - if len(values) != 2: - raise OperationFailure("%s must have only 2 parameters" % operator) - number_0, number_1 = self.parse_many(values) - if number_0 is None or number_1 is None: - return None - - if operator == "$divide": - return number_0 / number_1 - if operator == "$log": - return math.log(number_0, number_1) - if operator == "$mod": - return math.fmod(number_0, number_1) - if operator == "$pow": - return math.pow(number_0, number_1) - if operator == "$subtract": - if isinstance(number_0, datetime.datetime) and isinstance( - number_1, (int, float) - ): - number_1 = datetime.timedelta(milliseconds=number_1) - res = number_0 - number_1 - if isinstance(res, datetime.timedelta): - return round(res.total_seconds() * 1000) - return res - - assert isinstance(values, (tuple, list)), ( - "Parameter to %s must evaluate to a list, got '%s'" - % ( - operator, - type(values), - ) - ) - - parsed_values = list(self.parse_many(values)) - assert parsed_values, "%s must have at least one parameter" % operator - for value in parsed_values: - if value is None: - return None - assert isinstance(value, numbers.Number), "%s only uses numbers" % operator - if operator == "$add": - return sum(parsed_values) - if operator == "$multiply": - return functools.reduce(lambda x, y: x * y, parsed_values) - - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid aritmetic operator for the aggregation " - "pipeline, it is currently not implemented in Mongomock." % operator - ) - - def _handle_project_operator(self, operator, values): - if operator in _GROUPING_OPERATOR_MAP: - values = ( - self.parse(values) - if isinstance(values, str) - else self.parse_many(values) - ) - return _GROUPING_OPERATOR_MAP[operator](values) - if operator == "$arrayElemAt": - key, value = values - array = self.parse(key) - index = self.parse(value) - try: - return array[index] - except IndexError as error: - raise KeyError("Array have length less than index value") from error - - raise NotImplementedError( - "Although '%s' is a valid project operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_projection_operator(self, operator, value): - if operator == "$literal": - return value - if operator == "$let": - if not isinstance(value, dict): - raise InvalidDocument("$let only supports an object as its argument") - for field in ("vars", "in"): - if field not in value: - raise OperationFailure( - "Missing '{}' parameter to $let".format(field) - ) - if not isinstance(value["vars"], dict): - raise OperationFailure("invalid parameter: expected an object (vars)") - user_vars = { - var_key: self.parse(var_value) - for var_key, var_value in value["vars"].items() - } - return _Parser( - self._doc_dict, - dict(self._user_vars, **user_vars), - ignore_missing_keys=self._ignore_missing_keys, - ).parse(value["in"]) - raise NotImplementedError( - "Although '%s' is a valid project operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_comparison_operator(self, operator, values): - assert len(values) == 2, "Comparison requires two expressions" - a = self.parse(values[0]) - b = self.parse(values[1]) - if operator == "$eq": - return a == b - if operator == "$ne": - return a != b - if operator in filtering.SORTING_OPERATOR_MAP: - return filtering.bson_compare( - filtering.SORTING_OPERATOR_MAP[operator], a, b - ) - raise NotImplementedError( - "Although '%s' is a valid comparison operator for the " - "aggregation pipeline, it is currently not implemented " - " in Mongomock." % operator - ) - - def _handle_string_operator(self, operator, values): - if operator == "$toLower": - parsed = self.parse(values) - return str(parsed).lower() if parsed is not None else "" - if operator == "$toUpper": - parsed = self.parse(values) - return str(parsed).upper() if parsed is not None else "" - if operator == "$concat": - parsed_list = list(self.parse_many(values)) - return ( - None if None in parsed_list else "".join([str(x) for x in parsed_list]) - ) - if operator == "$split": - if len(values) != 2: - raise OperationFailure("split must have 2 items") - try: - string = self.parse(values[0]) - delimiter = self.parse(values[1]) - except KeyError: - return None - - if string is None or delimiter is None: - return None - if not isinstance(string, str): - raise TypeError("split first argument must evaluate to string") - if not isinstance(delimiter, str): - raise TypeError("split second argument must evaluate to string") - return string.split(delimiter) - if operator == "$substr": - if len(values) != 3: - raise OperationFailure("substr must have 3 items") - string = str(self.parse(values[0])) - first = self.parse(values[1]) - length = self.parse(values[2]) - if string is None: - return "" - if first < 0: - warnings.warn( - "Negative starting point given to $substr is accepted only until " - "MongoDB 3.7. This behavior will change in the future." - ) - return "" - if length < 0: - warnings.warn( - "Negative length given to $substr is accepted only until " - "MongoDB 3.7. This behavior will change in the future." - ) - second = len(string) if length < 0 else first + length - return string[first:second] - if operator == "$strcasecmp": - if len(values) != 2: - raise OperationFailure("strcasecmp must have 2 items") - a, b = str(self.parse(values[0])), str(self.parse(values[1])) - return 0 if a == b else -1 if a < b else 1 - if operator == "$regexMatch": - if not isinstance(values, dict): - raise OperationFailure( - "$regexMatch expects an object of named arguments but found: %s" - % type(values) - ) - for field in ("input", "regex"): - if field not in values: - raise OperationFailure( - "$regexMatch requires '%s' parameter" % field - ) - unknown_args = set(values) - {"input", "regex", "options"} - if unknown_args: - raise OperationFailure( - "$regexMatch found an unknown argument: %s" % list(unknown_args)[0] - ) - - try: - input_value = self.parse(values["input"]) - except KeyError: - return False - if not isinstance(input_value, str): - raise OperationFailure("$regexMatch needs 'input' to be of type string") - - try: - regex_val = self.parse(values["regex"]) - except KeyError: - return False - options = None - for option in values.get("options", ""): - if option not in "imxs": - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" % option - ) - re_option = getattr(re, option.upper()) - if options is None: - options = re_option - else: - options |= re_option - if isinstance(regex_val, str): - if options is None: - regex = re.compile(regex_val) - else: - regex = re.compile(regex_val, options) - elif "options" in values and regex_val.flags: - raise OperationFailure( - "$regexMatch: regex option(s) specified in both 'regex' and 'option' fields" - ) - elif isinstance(regex_val, helpers.RE_TYPE): - if options and not regex_val.flags: - regex = re.compile(regex_val.pattern, options) - elif regex_val.flags & ~(re.I | re.M | re.X | re.S): - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" - % regex_val.flags - ) - else: - regex = regex_val - elif isinstance(regex_val, _RE_TYPES): - # bson.Regex - if regex_val.flags & ~(re.I | re.M | re.X | re.S): - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" - % regex_val.flags - ) - regex = re.compile(regex_val.pattern, regex_val.flags or options) - else: - raise OperationFailure( - "$regexMatch needs 'regex' to be of type string or regex" - ) - - return bool(regex.search(input_value)) - - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid string operator for the aggregation " - "pipeline, it is currently not implemented in Mongomock." % operator - ) - - def _handle_date_operator(self, operator, values): - if isinstance(values, dict) and values.keys() == {"date", "timezone"}: - value = self.parse(values["date"]) - target_tz = pytz.timezone(values["timezone"]) - out_value = value.replace(tzinfo=pytz.utc).astimezone(target_tz) - else: - out_value = self.parse(values) - - if operator == "$dayOfYear": - return out_value.timetuple().tm_yday - if operator == "$dayOfMonth": - return out_value.day - if operator == "$dayOfWeek": - return (out_value.isoweekday() % 7) + 1 - if operator == "$year": - return out_value.year - if operator == "$month": - return out_value.month - if operator == "$week": - return int(out_value.strftime("%U")) - if operator == "$hour": - return out_value.hour - if operator == "$minute": - return out_value.minute - if operator == "$second": - return out_value.second - if operator == "$millisecond": - return int(out_value.microsecond / 1000) - if operator == "$dateToString": - if not isinstance(values, dict): - raise OperationFailure( - "$dateToString operator must correspond a dict" - 'that has "format" and "date" field.' - ) - if not isinstance(values, dict) or not {"format", "date"} <= set(values): - raise OperationFailure( - "$dateToString operator must correspond a dict" - 'that has "format" and "date" field.' - ) - if "%L" in out_value["format"]: - raise NotImplementedError( - "Although %L is a valid date format for the " - "$dateToString operator, it is currently not implemented " - " in Mongomock." - ) - if "onNull" in values: - raise NotImplementedError( - "Although onNull is a valid field for the " - "$dateToString operator, it is currently not implemented " - " in Mongomock." - ) - if "timezone" in values.keys(): - raise NotImplementedError( - "Although timezone is a valid field for the " - "$dateToString operator, it is currently not implemented " - " in Mongomock." - ) - return out_value["date"].strftime(out_value["format"]) - if operator == "$dateFromParts": - if not isinstance(out_value, dict): - raise OperationFailure( - f"{operator} operator must correspond a dict " - 'that has "year" or "isoWeekYear" field.' - ) - if len(set(out_value) & {"year", "isoWeekYear"}) != 1: - raise OperationFailure( - f"{operator} operator must correspond a dict " - 'that has "year" or "isoWeekYear" field.' - ) - for field in ("isoWeekYear", "isoWeek", "isoDayOfWeek", "timezone"): - if field in out_value: - raise NotImplementedError( - f"Although {field} is a valid field for the " - f"{operator} operator, it is currently not implemented " - "in Mongomock." - ) - - year = out_value["year"] - month = out_value.get("month", 1) or 1 - day = out_value.get("day", 1) or 1 - hour = out_value.get("hour", 0) or 0 - minute = out_value.get("minute", 0) or 0 - second = out_value.get("second", 0) or 0 - millisecond = out_value.get("millisecond", 0) or 0 - - return datetime.datetime( - year=year, - month=month, - day=day, - hour=hour, - minute=minute, - second=second, - microsecond=millisecond, - ) - - raise NotImplementedError( - "Although '%s' is a valid date operator for the " - "aggregation pipeline, it is currently not implemented " - " in Mongomock." % operator - ) - - def _handle_array_operator(self, operator, value): - if operator == "$concatArrays": - if not isinstance(value, (list, tuple)): - value = [value] - - parsed_list = list(self.parse_many(value)) - for parsed_item in parsed_list: - if parsed_item is not None and not isinstance( - parsed_item, (list, tuple) - ): - raise OperationFailure( - "$concatArrays only supports arrays, not {}".format( - type(parsed_item) - ) - ) - - return ( - None - if None in parsed_list - else list(itertools.chain.from_iterable(parsed_list)) - ) - - if operator == "$map": - if not isinstance(value, dict): - raise OperationFailure("$map only supports an object as its argument") - - # NOTE: while the two validations below could be achieved with - # one-liner set operations (e.g. set(value) - {'input', 'as', - # 'in'}), we prefer the iteration-based approaches in order to - # mimic MongoDB's behavior regarding the order of evaluation. For - # example, MongoDB complains about 'input' parameter missing before - # 'in'. - for k in ("input", "in"): - if k not in value: - raise OperationFailure("Missing '%s' parameter to $map" % k) - - for k in value: - if k not in {"input", "as", "in"}: - raise OperationFailure("Unrecognized parameter to $map: %s" % k) - - input_array = self._parse_or_None(value["input"]) - - if input_array is None or input_array is None: - return None - - if not isinstance(input_array, (list, tuple)): - raise OperationFailure( - "input to $map must be an array not %s" % type(input_array) - ) - - fieldname = value.get("as", "this") - in_expr = value["in"] - return [ - _Parser( - self._doc_dict, - dict(self._user_vars, **{fieldname: item}), - ignore_missing_keys=self._ignore_missing_keys, - ).parse(in_expr) - for item in input_array - ] - - if operator == "$size": - if isinstance(value, list): - if len(value) != 1: - raise OperationFailure( - "Expression $size takes exactly 1 arguments. " - "%d were passed in." % len(value) - ) - value = value[0] - array_value = self._parse_or_None(value) - if not isinstance(array_value, (list, tuple)): - raise OperationFailure( - "The argument to $size must be an array, but was of type: %s" - % ("missing" if array_value is None else type(array_value)) - ) - return len(array_value) - - if operator == "$filter": - if not isinstance(value, dict): - raise OperationFailure( - "$filter only supports an object as its argument" - ) - extra_params = set(value) - {"input", "cond", "as"} - if extra_params: - raise OperationFailure( - "Unrecognized parameter to $filter: %s" % extra_params.pop() - ) - missing_params = {"input", "cond"} - set(value) - if missing_params: - raise OperationFailure( - "Missing '%s' parameter to $filter" % missing_params.pop() - ) - - input_array = self.parse(value["input"]) - fieldname = value.get("as", "this") - cond = value["cond"] - return [ - item - for item in input_array - if _Parser( - self._doc_dict, - dict(self._user_vars, **{fieldname: item}), - ignore_missing_keys=self._ignore_missing_keys, - ).parse(cond) - ] - if operator == "$slice": - if not isinstance(value, list): - raise OperationFailure("$slice only supports a list as its argument") - if len(value) < 2 or len(value) > 3: - raise OperationFailure( - "Expression $slice takes at least 2 arguments, and at most " - "3, but {} were passed in".format(len(value)) - ) - array_value = self.parse(value[0]) - if not isinstance(array_value, list): - raise OperationFailure( - "First argument to $slice must be an array, but is of type: {}".format( - type(array_value) - ) - ) - for num, v in zip(("Second", "Third"), value[1:]): - if not isinstance(v, int): - raise OperationFailure( - "{} argument to $slice must be numeric, but is of type: {}".format( - num, type(v) - ) - ) - if len(value) > 2 and value[2] <= 0: - raise OperationFailure( - "Third argument to $slice must be " "positive: {}".format(value[2]) - ) - - start = value[1] - if start < 0: - if len(value) > 2: - stop = len(array_value) + start + value[2] - else: - stop = None - elif len(value) > 2: - stop = start + value[2] - else: - stop = start - start = 0 - return array_value[start:stop] - - raise NotImplementedError( - "Although '%s' is a valid array operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_type_convertion_operator(self, operator, values): - if operator == "$toString": - try: - parsed = self.parse(values) - except KeyError: - return None - if isinstance(parsed, bool): - return str(parsed).lower() - if isinstance(parsed, datetime.datetime): - return parsed.isoformat()[:-3] + "Z" - return str(parsed) - - if operator == "$toInt": - try: - parsed = self.parse(values) - except KeyError: - return None - if decimal_support: - if isinstance(parsed, decimal128.Decimal128): - return int(parsed.to_decimal()) - return int(parsed) - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." - ) - - if operator == "$toLong": - try: - parsed = self.parse(values) - except KeyError: - return None - if decimal_support: - if isinstance(parsed, decimal128.Decimal128): - return int(parsed.to_decimal()) - return int(parsed) - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." - ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/toDecimal/ - if operator == "$toDecimal": - if not decimal_support: - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." - ) - try: - parsed = self.parse(values) - except KeyError: - return None - if isinstance(parsed, bool): - parsed = "1" if parsed is True else "0" - decimal_value = decimal128.Decimal128(parsed) - elif isinstance(parsed, int): - decimal_value = decimal128.Decimal128(str(parsed)) - elif isinstance(parsed, float): - exp = decimal.Decimal(".00000000000000") - decimal_value = decimal.Decimal(str(parsed)).quantize(exp) - decimal_value = decimal128.Decimal128(decimal_value) - elif isinstance(parsed, decimal128.Decimal128): - decimal_value = parsed - elif isinstance(parsed, str): - try: - decimal_value = decimal128.Decimal128(parsed) - except decimal.InvalidOperation as err: - raise OperationFailure( - "Failed to parse number '%s' in $convert with no onError value:" - "Failed to parse string to decimal" % parsed - ) from err - elif isinstance(parsed, datetime.datetime): - epoch = datetime.datetime.utcfromtimestamp(0) - string_micro_seconds = str( - (parsed - epoch).total_seconds() * 1000 - ).split(".", 1)[0] - decimal_value = decimal128.Decimal128(string_micro_seconds) - else: - raise TypeError("'%s' type is not supported" % type(parsed)) - return decimal_value - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/arrayToObject/ - if operator == "$arrayToObject": - try: - parsed = self.parse(values) - except KeyError: - return None - - if parsed is None: - return None - - if not isinstance(parsed, (list, tuple)): - raise OperationFailure( - "$arrayToObject requires an array input, found: {}".format( - type(parsed) - ) - ) - - if all(isinstance(x, dict) and set(x.keys()) == {"k", "v"} for x in parsed): - return {d["k"]: d["v"] for d in parsed} - - if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in parsed): - return dict(parsed) - - raise OperationFailure( - "arrays used with $arrayToObject must contain documents " - "with k and v fields or two-element arrays" - ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/objectToArray/ - if operator == "$objectToArray": - try: - parsed = self.parse(values) - except KeyError: - return None - - if parsed is None: - return None - - if not isinstance(parsed, (dict, collections.OrderedDict)): - raise OperationFailure( - "$objectToArray requires an object input, found: {}".format( - type(parsed) - ) - ) - - if len(parsed) > 1 and sys.version_info < (3, 6): - raise NotImplementedError( - "Although '%s' is a valid type conversion, it is not implemented for Python 2 " - "and Python 3.5 in Mongomock yet." % operator - ) - - return [{"k": k, "v": v} for k, v in parsed.items()] - - raise NotImplementedError( - "Although '%s' is a valid type conversion operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_type_operator(self, operator, values): - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isNumber/ - if operator == "$isNumber": - try: - parsed = self.parse(values) - except KeyError: - return False - return ( - False - if isinstance(parsed, bool) - else isinstance(parsed, numbers.Number) - ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isArray/ - if operator == "$isArray": - try: - parsed = self.parse(values) - except KeyError: - return False - return isinstance(parsed, (tuple, list)) - - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid type operator for the aggregation pipeline, it is currently " - "not implemented in Mongomock." % operator - ) - - def _handle_conditional_operator(self, operator, values): - # relative - from . import SERVER_VERSION - - if operator == "$ifNull": - fields = values[:-1] - if len(fields) > 1 and version.parse(SERVER_VERSION) <= version.parse( - "4.4" - ): - raise OperationFailure( - "$ifNull supports only one input expression " - " in MongoDB v4.4 and lower" - ) - fallback = values[-1] - for field in fields: - try: - out_value = self.parse(field) - if out_value is not None: - return out_value - except KeyError: - pass - return self.parse(fallback) - if operator == "$cond": - if isinstance(values, list): - condition, true_case, false_case = values - elif isinstance(values, dict): - condition = values["if"] - true_case = values["then"] - false_case = values["else"] - condition_value = self._parse_to_bool(condition) - expression = true_case if condition_value else false_case - return self.parse(expression) - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid conditional operator for the " - "aggregation pipeline, it is currently not implemented " - " in Mongomock." % operator - ) - - def _handle_control_flow_operator(self, operator, values): - if operator == "$switch": - if not isinstance(values, dict): - raise OperationFailure( - "$switch requires an object as an argument, " - "found: %s" % type(values) - ) - - branches = values.get("branches", []) - if not isinstance(branches, (list, tuple)): - raise OperationFailure( - "$switch expected an array for 'branches', " - "found: %s" % type(branches) - ) - if not branches: - raise OperationFailure("$switch requires at least one branch.") - - for branch in branches: - if not isinstance(branch, dict): - raise OperationFailure( - "$switch expected each branch to be an object, " - "found: %s" % type(branch) - ) - if "case" not in branch: - raise OperationFailure( - "$switch requires each branch have a 'case' expression" - ) - if "then" not in branch: - raise OperationFailure( - "$switch requires each branch have a 'then' expression." - ) - - for branch in branches: - if self._parse_to_bool(branch["case"]): - return self.parse(branch["then"]) - - if "default" not in values: - raise OperationFailure( - "$switch could not find a matching branch for an input, " - "and no default was specified." - ) - return self.parse(values["default"]) - - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid control flow operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_set_operator(self, operator, values): - if operator == "$in": - expression, array = values - return self.parse(expression) in self.parse(array) - if operator == "$setUnion": - result = [] - for set_value in values: - for value in self.parse(set_value): - if value not in result: - result.append(value) - return result - if operator == "$setEquals": - set_values = [set(self.parse(value)) for value in values] - for set1, set2 in itertools.combinations(set_values, 2): - if set1 != set2: - return False - return True - raise NotImplementedError( - "Although '%s' is a valid set operator for the aggregation " - "pipeline, it is currently not implemented in Mongomock." % operator - ) - - -def _parse_expression(expression, doc_dict, ignore_missing_keys=False): - """Parse an expression. - - Args: - expression: an Aggregate Expression, see - https://docs.mongodb.com/manual/meta/aggregation-quick-reference/#aggregation-expressions. - doc_dict: the document on which to evaluate the expression. - ignore_missing_keys: if True, missing keys evaluated by the expression are ignored silently - if it is possible. - """ - return _Parser(doc_dict, ignore_missing_keys=ignore_missing_keys).parse(expression) - - -filtering.register_parse_expression(_parse_expression) - - -def _accumulate_group(output_fields, group_list): - doc_dict = {} - for field, value in output_fields.items(): - if field == "_id": - continue - for operator, key in value.items(): - values = [] - for doc in group_list: - try: - values.append(_parse_expression(key, doc)) - except KeyError: - continue - if operator in _GROUPING_OPERATOR_MAP: - doc_dict[field] = _GROUPING_OPERATOR_MAP[operator](values) - elif operator == "$addToSet": - value = [] - val_it = (val or None for val in values) - # Don't use set in case elt in not hashable (like dicts). - for elt in val_it: - if elt not in value: - value.append(elt) - doc_dict[field] = value - elif operator == "$push": - if field not in doc_dict: - doc_dict[field] = values - else: - doc_dict[field].extend(values) - elif operator in group_operators: - raise NotImplementedError( - "Although %s is a valid group operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - else: - raise NotImplementedError( - "%s is not a valid group operator for the aggregation " - "pipeline. See http://docs.mongodb.org/manual/meta/" - "aggregation-quick-reference/ for a complete list of " - "valid operators." % operator - ) - return doc_dict - - -def _fix_sort_key(key_getter): - def fixed_getter(doc): - key = key_getter(doc) - # Convert dictionaries to make sorted() work in Python 3. - if isinstance(key, dict): - return [(k, v) for (k, v) in sorted(key.items())] - return key - - return fixed_getter - - -def _handle_lookup_stage(in_collection, database, options): - for operator in ("let", "pipeline"): - if operator in options: - raise NotImplementedError( - "Although '%s' is a valid lookup operator for the " - "aggregation pipeline, it is currently not " - "implemented in Mongomock." % operator - ) - for operator in ("from", "localField", "foreignField", "as"): - if operator not in options: - raise OperationFailure("Must specify '%s' field for a $lookup" % operator) - if not isinstance(options[operator], str): - raise OperationFailure("Arguments to $lookup must be strings") - if operator in ("as", "localField", "foreignField") and options[ - operator - ].startswith("$"): - raise OperationFailure("FieldPath field names may not start with '$'") - if operator == "as" and "." in options[operator]: - raise NotImplementedError( - "Although '.' is valid in the 'as' " - "parameters for the lookup stage of the aggregation " - "pipeline, it is currently not implemented in Mongomock." - ) - - foreign_name = options["from"] - local_field = options["localField"] - foreign_field = options["foreignField"] - local_name = options["as"] - foreign_collection = database.get_collection(foreign_name) - for doc in in_collection: - try: - query = helpers.get_value_by_dot(doc, local_field) - except KeyError: - query = None - if isinstance(query, list): - query = {"$in": query} - matches = foreign_collection.find({foreign_field: query}) - doc[local_name] = [foreign_doc for foreign_doc in matches] - - return in_collection - - -def _recursive_get(match, nested_fields): - head = match.get(nested_fields[0]) - remaining_fields = nested_fields[1:] - if not remaining_fields: - # Final/last field reached. - yield head - return - # More fields to go, must be list, tuple, or dict. - if isinstance(head, (list, tuple)): - for m in head: - # Yield from _recursive_get(m, remaining_fields). - for answer in _recursive_get(m, remaining_fields): - yield answer - elif isinstance(head, dict): - # Yield from _recursive_get(head, remaining_fields). - for answer in _recursive_get(head, remaining_fields): - yield answer - - -def _handle_graph_lookup_stage(in_collection, database, options): - if not isinstance(options.get("maxDepth", 0), int): - raise OperationFailure("Argument 'maxDepth' to $graphLookup must be a number") - if not isinstance(options.get("restrictSearchWithMatch", {}), dict): - raise OperationFailure( - "Argument 'restrictSearchWithMatch' to $graphLookup must be a Dictionary" - ) - if not isinstance(options.get("depthField", ""), str): - raise OperationFailure("Argument 'depthField' to $graphlookup must be a string") - if "startWith" not in options: - raise OperationFailure("Must specify 'startWith' field for a $graphLookup") - for operator in ("as", "connectFromField", "connectToField", "from"): - if operator not in options: - raise OperationFailure( - "Must specify '%s' field for a $graphLookup" % operator - ) - if not isinstance(options[operator], str): - raise OperationFailure( - "Argument '%s' to $graphLookup must be string" % operator - ) - if options[operator].startswith("$"): - raise OperationFailure("FieldPath field names may not start with '$'") - if operator == "as" and "." in options[operator]: - raise NotImplementedError( - "Although '.' is valid in the '%s' " - "parameter for the $graphLookup stage of the aggregation " - "pipeline, it is currently not implemented in Mongomock." % operator - ) - - foreign_name = options["from"] - start_with = options["startWith"] - connect_from_field = options["connectFromField"] - connect_to_field = options["connectToField"] - local_name = options["as"] - max_depth = options.get("maxDepth", None) - depth_field = options.get("depthField", None) - restrict_search_with_match = options.get("restrictSearchWithMatch", {}) - foreign_collection = database.get_collection(foreign_name) - out_doc = copy.deepcopy(in_collection) # TODO(pascal): speed the deep copy - - def _find_matches_for_depth(query): - if isinstance(query, list): - query = {"$in": query} - matches = foreign_collection.find({connect_to_field: query}) - new_matches = [] - for new_match in matches: - if ( - filtering.filter_applies(restrict_search_with_match, new_match) - and new_match["_id"] not in found_items - ): - if depth_field is not None: - new_match = collections.OrderedDict( - new_match, **{depth_field: depth} - ) - new_matches.append(new_match) - found_items.add(new_match["_id"]) - return new_matches - - for doc in out_doc: - found_items = set() - depth = 0 - try: - result = _parse_expression(start_with, doc) - except KeyError: - continue - origin_matches = doc[local_name] = _find_matches_for_depth(result) - while origin_matches and (max_depth is None or depth < max_depth): - depth += 1 - newly_discovered_matches = [] - for match in origin_matches: - nested_fields = connect_from_field.split(".") - for match_target in _recursive_get(match, nested_fields): - newly_discovered_matches += _find_matches_for_depth(match_target) - doc[local_name] += newly_discovered_matches - origin_matches = newly_discovered_matches - return out_doc - - -def _handle_group_stage(in_collection, unused_database, options): - grouped_collection = [] - _id = options["_id"] - if _id: - - def _key_getter(doc): - try: - return _parse_expression(_id, doc, ignore_missing_keys=True) - except KeyError: - return None - - def _sort_key_getter(doc): - return filtering.BsonComparable(_key_getter(doc)) - - # Sort the collection only for the itertools.groupby. - # $group does not order its output document. - sorted_collection = sorted(in_collection, key=_sort_key_getter) - grouped = itertools.groupby(sorted_collection, _key_getter) - else: - grouped = [(None, in_collection)] - - for doc_id, group in grouped: - group_list = [x for x in group] - doc_dict = _accumulate_group(options, group_list) - doc_dict["_id"] = doc_id - grouped_collection.append(doc_dict) - - return grouped_collection - - -def _handle_bucket_stage(in_collection, unused_database, options): - unknown_options = set(options) - {"groupBy", "boundaries", "output", "default"} - if unknown_options: - raise OperationFailure( - "Unrecognized option to $bucket: %s." % unknown_options.pop() - ) - if "groupBy" not in options or "boundaries" not in options: - raise OperationFailure( - "$bucket requires 'groupBy' and 'boundaries' to be specified." - ) - group_by = options["groupBy"] - boundaries = options["boundaries"] - if not isinstance(boundaries, list): - raise OperationFailure( - "The $bucket 'boundaries' field must be an array, but found type: %s" - % type(boundaries) - ) - if len(boundaries) < 2: - raise OperationFailure( - "The $bucket 'boundaries' field must have at least 2 values, but " - "found %d value(s)." % len(boundaries) - ) - if sorted(boundaries) != boundaries: - raise OperationFailure( - "The 'boundaries' option to $bucket must be sorted in ascending order" - ) - output_fields = options.get("output", {"count": {"$sum": 1}}) - default_value = options.get("default", None) - try: - is_default_last = default_value >= boundaries[-1] - except TypeError: - is_default_last = True - - def _get_default_bucket(): - try: - return options["default"] - except KeyError as err: - raise OperationFailure( - "$bucket could not find a matching branch for " - "an input, and no default was specified." - ) from err - - def _get_bucket_id(doc): - """Get the bucket ID for a document. - - Note that it actually returns a tuple with the first - param being a sort key to sort the default bucket even - if it's not the same type as the boundaries. - """ - try: - value = _parse_expression(group_by, doc) - except KeyError: - return (is_default_last, _get_default_bucket()) - index = bisect.bisect_right(boundaries, value) - if index and index < len(boundaries): - return (False, boundaries[index - 1]) - return (is_default_last, _get_default_bucket()) - - in_collection = ((_get_bucket_id(doc), doc) for doc in in_collection) - out_collection = sorted(in_collection, key=lambda kv: kv[0]) - grouped = itertools.groupby(out_collection, lambda kv: kv[0]) - - out_collection = [] - for (unused_key, doc_id), group in grouped: - group_list = [kv[1] for kv in group] - doc_dict = _accumulate_group(output_fields, group_list) - doc_dict["_id"] = doc_id - out_collection.append(doc_dict) - return out_collection - - -def _handle_sample_stage(in_collection, unused_database, options): - if not isinstance(options, dict): - raise OperationFailure("the $sample stage specification must be an object") - size = options.pop("size", None) - if size is None: - raise OperationFailure("$sample stage must specify a size") - if options: - raise OperationFailure( - "unrecognized option to $sample: %s" % set(options).pop() - ) - shuffled = list(in_collection) - _random.shuffle(shuffled) - return shuffled[:size] - - -def _handle_sort_stage(in_collection, unused_database, options): - sort_array = reversed([{x: y} for x, y in options.items()]) - sorted_collection = in_collection - for sort_pair in sort_array: - for sortKey, sortDirection in sort_pair.items(): - sorted_collection = sorted( - sorted_collection, - key=lambda x: filtering.resolve_sort_key(sortKey, x), - reverse=sortDirection < 0, - ) - return sorted_collection - - -def _handle_unwind_stage(in_collection, unused_database, options): - if not isinstance(options, dict): - options = {"path": options} - path = options["path"] - if not isinstance(path, str) or path[0] != "$": - raise ValueError( - "$unwind failed: exception: field path references must be prefixed " - "with a '$' '%s'" % path - ) - path = path[1:] - should_preserve_null_and_empty = options.get("preserveNullAndEmptyArrays") - include_array_index = options.get("includeArrayIndex") - unwound_collection = [] - for doc in in_collection: - try: - array_value = helpers.get_value_by_dot(doc, path) - except KeyError: - if should_preserve_null_and_empty: - unwound_collection.append(doc) - continue - if array_value is None: - if should_preserve_null_and_empty: - unwound_collection.append(doc) - continue - if array_value == []: - if should_preserve_null_and_empty: - new_doc = copy.deepcopy(doc) - # We just ran a get_value_by_dot so we know the value exists. - helpers.delete_value_by_dot(new_doc, path) - unwound_collection.append(new_doc) - continue - if isinstance(array_value, list): - iter_array = enumerate(array_value) - else: - iter_array = [(None, array_value)] - for index, field_item in iter_array: - new_doc = copy.deepcopy(doc) - new_doc = helpers.set_value_by_dot(new_doc, path, field_item) - if include_array_index: - new_doc = helpers.set_value_by_dot(new_doc, include_array_index, index) - unwound_collection.append(new_doc) - - return unwound_collection - - -# TODO(pascal): Combine with the equivalent function in collection but check -# what are the allowed overriding. -def _combine_projection_spec(filter_list, original_filter, prefix=""): - """Re-format a projection fields spec into a nested dictionary. - - e.g: ['a', 'b.c', 'b.d'] => {'a': 1, 'b': {'c': 1, 'd': 1}} - """ - if not isinstance(filter_list, list): - return filter_list - - filter_dict = collections.OrderedDict() - - for key in filter_list: - field, separator, subkey = key.partition(".") - if not separator: - if isinstance(filter_dict.get(field), list): - other_key = field + "." + filter_dict[field][0] - raise OperationFailure( - "Invalid $project :: caused by :: specification contains two conflicting paths." - " Cannot specify both %s and %s: %s" - % (repr(prefix + field), repr(prefix + other_key), original_filter) - ) - filter_dict[field] = 1 - continue - if not isinstance(filter_dict.get(field, []), list): - raise OperationFailure( - "Invalid $project :: caused by :: specification contains two conflicting paths." - " Cannot specify both %s and %s: %s" - % (repr(prefix + field), repr(prefix + key), original_filter) - ) - filter_dict[field] = filter_dict.get(field, []) + [subkey] - - return collections.OrderedDict( - (k, _combine_projection_spec(v, original_filter, prefix="%s%s." % (prefix, k))) - for k, v in filter_dict.items() - ) - - -def _project_by_spec(doc, proj_spec, is_include): - output = {} - for key, value in doc.items(): - if key not in proj_spec: - if not is_include: - output[key] = value - continue - - if not isinstance(proj_spec[key], dict): - if is_include: - output[key] = value - continue - - if isinstance(value, dict): - output[key] = _project_by_spec(value, proj_spec[key], is_include) - elif isinstance(value, list): - output[key] = [ - _project_by_spec(array_value, proj_spec[key], is_include) - for array_value in value - if isinstance(array_value, dict) - ] - elif not is_include: - output[key] = value - - return output - - -def _handle_replace_root_stage(in_collection, unused_database, options): - if "newRoot" not in options: - raise OperationFailure( - "Parameter 'newRoot' is missing for $replaceRoot operation." - ) - new_root = options["newRoot"] - out_collection = [] - for doc in in_collection: - try: - new_doc = _parse_expression(new_root, doc, ignore_missing_keys=True) - except KeyError: - new_doc = None - if not isinstance(new_doc, dict): - raise OperationFailure( - "'newRoot' expression must evaluate to an object, but resulting value was: {}".format( - new_doc - ) - ) - out_collection.append(new_doc) - return out_collection - - -def _handle_project_stage(in_collection, unused_database, options): - filter_list = [] - method = None - include_id = options.get("_id") - # Compute new values for each field, except inclusion/exclusions that are - # handled in one final step. - new_fields_collection = None - for field, value in options.items(): - if method is None and (field != "_id" or value): - method = "include" if value else "exclude" - elif method == "include" and not value and field != "_id": - raise OperationFailure( - "Bad projection specification, cannot exclude fields " - "other than '_id' in an inclusion projection: %s" % options - ) - elif method == "exclude" and value: - raise OperationFailure( - "Bad projection specification, cannot include fields " - "or add computed fields during an exclusion projection: %s" % options - ) - if value in (0, 1, True, False): - if field != "_id": - filter_list.append(field) - continue - if not new_fields_collection: - new_fields_collection = [{} for unused_doc in in_collection] - - for in_doc, out_doc in zip(in_collection, new_fields_collection): - try: - out_doc[field] = _parse_expression( - value, in_doc, ignore_missing_keys=True - ) - except KeyError: - # Ignore missing key. - pass - if (method == "include") == (include_id is not False and include_id != 0): - filter_list.append("_id") - - if not filter_list: - return new_fields_collection - - # Final steps: include or exclude fields and merge with newly created fields. - projection_spec = _combine_projection_spec(filter_list, original_filter=options) - out_collection = [ - _project_by_spec(doc, projection_spec, is_include=(method == "include")) - for doc in in_collection - ] - if new_fields_collection: - return [dict(a, **b) for a, b in zip(out_collection, new_fields_collection)] - return out_collection - - -def _handle_add_fields_stage(in_collection, unused_database, options): - if not options: - raise OperationFailure( - "Invalid $addFields :: caused by :: specification must have at least one field" - ) - out_collection = [dict(doc) for doc in in_collection] - for field, value in options.items(): - for in_doc, out_doc in zip(in_collection, out_collection): - try: - out_value = _parse_expression(value, in_doc, ignore_missing_keys=True) - except KeyError: - continue - parts = field.split(".") - for subfield in parts[:-1]: - out_doc[subfield] = out_doc.get(subfield, {}) - if not isinstance(out_doc[subfield], dict): - out_doc[subfield] = {} - out_doc = out_doc[subfield] - out_doc[parts[-1]] = out_value - return out_collection - - -def _handle_out_stage(in_collection, database, options): - # TODO(MetrodataTeam): should leave the origin collection unchanged - out_collection = database.get_collection(options) - if out_collection.find_one(): - out_collection.drop() - if in_collection: - out_collection.insert_many(in_collection) - return in_collection - - -def _handle_count_stage(in_collection, database, options): - if not isinstance(options, str) or options == "": - raise OperationFailure("the count field must be a non-empty string") - elif options.startswith("$"): - raise OperationFailure("the count field cannot be a $-prefixed path") - elif "." in options: - raise OperationFailure("the count field cannot contain '.'") - return [{options: len(in_collection)}] - - -def _handle_facet_stage(in_collection, database, options): - out_collection_by_pipeline = {} - for pipeline_title, pipeline in options.items(): - out_collection_by_pipeline[pipeline_title] = list( - process_pipeline(in_collection, database, pipeline, None) - ) - return [out_collection_by_pipeline] - - -def _handle_match_stage(in_collection, database, options): - spec = helpers.patch_datetime_awareness_in_document(options) - return [ - doc - for doc in in_collection - if filtering.filter_applies( - spec, helpers.patch_datetime_awareness_in_document(doc) - ) - ] - - -_PIPELINE_HANDLERS = { - "$addFields": _handle_add_fields_stage, - "$bucket": _handle_bucket_stage, - "$bucketAuto": None, - "$collStats": None, - "$count": _handle_count_stage, - "$currentOp": None, - "$facet": _handle_facet_stage, - "$geoNear": None, - "$graphLookup": _handle_graph_lookup_stage, - "$group": _handle_group_stage, - "$indexStats": None, - "$limit": lambda c, d, o: c[:o], - "$listLocalSessions": None, - "$listSessions": None, - "$lookup": _handle_lookup_stage, - "$match": _handle_match_stage, - "$merge": None, - "$out": _handle_out_stage, - "$planCacheStats": None, - "$project": _handle_project_stage, - "$redact": None, - "$replaceRoot": _handle_replace_root_stage, - "$replaceWith": None, - "$sample": _handle_sample_stage, - "$set": _handle_add_fields_stage, - "$skip": lambda c, d, o: c[o:], - "$sort": _handle_sort_stage, - "$sortByCount": None, - "$unset": None, - "$unwind": _handle_unwind_stage, -} - - -def process_pipeline(collection, database, pipeline, session): - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - - for stage in pipeline: - for operator, options in stage.items(): - try: - handler = _PIPELINE_HANDLERS[operator] - except KeyError as err: - raise NotImplementedError( - "%s is not a valid operator for the aggregation pipeline. " - "See http://docs.mongodb.org/manual/meta/aggregation-quick-reference/ " - "for a complete list of valid operators." % operator - ) from err - if not handler: - raise NotImplementedError( - "Although '%s' is a valid operator for the aggregation pipeline, it is " - "currently not implemented in Mongomock." % operator - ) - collection = handler(collection, database, options) - - return command_cursor.CommandCursor(collection) diff --git a/packages/syft/tests/mongomock/codec_options.py b/packages/syft/tests/mongomock/codec_options.py deleted file mode 100644 index e71eb41d672..00000000000 --- a/packages/syft/tests/mongomock/codec_options.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Tools for specifying BSON codec options.""" - -# stdlib -import collections - -# third party -from packaging import version - -# relative -from . import helpers - -try: - # third party - from bson import codec_options - from pymongo.common import _UUID_REPRESENTATIONS -except ImportError: - codec_options = None - _UUID_REPRESENTATIONS = None - - -class TypeRegistry(object): - pass - - -_FIELDS = ( - "document_class", - "tz_aware", - "uuid_representation", - "unicode_decode_error_handler", - "tzinfo", -) - -if codec_options and helpers.PYMONGO_VERSION >= version.parse("3.8"): - _DEFAULT_TYPE_REGISTRY = codec_options.TypeRegistry() - _FIELDS = _FIELDS + ("type_registry",) -else: - _DEFAULT_TYPE_REGISTRY = TypeRegistry() - -if codec_options and helpers.PYMONGO_VERSION >= version.parse("4.3.0"): - _DATETIME_CONVERSION_VALUES = codec_options.DatetimeConversion._value2member_map_ - _DATETIME_CONVERSION_DEFAULT_VALUE = codec_options.DatetimeConversion.DATETIME - _FIELDS = _FIELDS + ("datetime_conversion",) -else: - _DATETIME_CONVERSION_VALUES = () - _DATETIME_CONVERSION_DEFAULT_VALUE = None - -# New default in Pymongo v4: -# https://pymongo.readthedocs.io/en/stable/examples/uuid.html#unspecified -if helpers.PYMONGO_VERSION >= version.parse("4.0"): - _DEFAULT_UUID_REPRESENTATION = 0 -else: - _DEFAULT_UUID_REPRESENTATION = 3 - - -class CodecOptions(collections.namedtuple("CodecOptions", _FIELDS)): - def __new__( - cls, - document_class=dict, - tz_aware=False, - uuid_representation=None, - unicode_decode_error_handler="strict", - tzinfo=None, - type_registry=None, - datetime_conversion=_DATETIME_CONVERSION_DEFAULT_VALUE, - ): - if document_class != dict: - raise NotImplementedError( - "Mongomock does not implement custom document_class yet: %r" - % document_class - ) - - if not isinstance(tz_aware, bool): - raise TypeError("tz_aware must be True or False") - - if uuid_representation is None: - uuid_representation = _DEFAULT_UUID_REPRESENTATION - - if unicode_decode_error_handler not in ("strict", None): - raise NotImplementedError( - "Mongomock does not handle custom unicode_decode_error_handler yet" - ) - - if tzinfo: - raise NotImplementedError("Mongomock does not handle custom tzinfo yet") - - values = ( - document_class, - tz_aware, - uuid_representation, - unicode_decode_error_handler, - tzinfo, - ) - - if "type_registry" in _FIELDS: - if not type_registry: - type_registry = _DEFAULT_TYPE_REGISTRY - values = values + (type_registry,) - - if "datetime_conversion" in _FIELDS: - if ( - datetime_conversion - and datetime_conversion not in _DATETIME_CONVERSION_VALUES - ): - raise TypeError( - "datetime_conversion must be member of DatetimeConversion" - ) - values = values + (datetime_conversion,) - - return tuple.__new__(cls, values) - - def with_options(self, **kwargs): - opts = self._asdict() - opts.update(kwargs) - return CodecOptions(**opts) - - def to_pymongo(self): - if not codec_options: - return None - - uuid_representation = self.uuid_representation - if _UUID_REPRESENTATIONS and isinstance(self.uuid_representation, str): - uuid_representation = _UUID_REPRESENTATIONS[uuid_representation] - - return codec_options.CodecOptions( - uuid_representation=uuid_representation, - unicode_decode_error_handler=self.unicode_decode_error_handler, - type_registry=self.type_registry, - ) - - -def is_supported(custom_codec_options): - if not custom_codec_options: - return None - - return CodecOptions(**custom_codec_options._asdict()) diff --git a/packages/syft/tests/mongomock/collection.py b/packages/syft/tests/mongomock/collection.py deleted file mode 100644 index 8a677300355..00000000000 --- a/packages/syft/tests/mongomock/collection.py +++ /dev/null @@ -1,2596 +0,0 @@ -# future -from __future__ import division - -# stdlib -import collections -from collections import OrderedDict -from collections.abc import Iterable -from collections.abc import Mapping -from collections.abc import MutableMapping -import copy -import functools -import itertools -import json -import math -import time -import warnings - -# third party -from packaging import version - -try: - # third party - from bson import BSON - from bson import SON - from bson import json_util - from bson.codec_options import CodecOptions - from bson.errors import InvalidDocument -except ImportError: - json_utils = SON = BSON = None - CodecOptions = None -try: - # third party - import execjs -except ImportError: - execjs = None - -try: - # third party - from pymongo import ReadPreference - from pymongo import ReturnDocument - from pymongo.operations import IndexModel - - _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY -except ImportError: - - class IndexModel(object): - pass - - class ReturnDocument(object): - BEFORE = False - AFTER = True - - from .read_preferences import PRIMARY as _READ_PREFERENCE_PRIMARY - -# relative -from . import BulkWriteError -from . import ConfigurationError -from . import DuplicateKeyError -from . import InvalidOperation -from . import ObjectId -from . import OperationFailure -from . import WriteError -from . import aggregate -from . import codec_options as mongomock_codec_options -from . import filtering -from . import helpers -from . import utcnow -from .filtering import filter_applies -from .not_implemented import raise_for_feature as raise_not_implemented -from .results import BulkWriteResult -from .results import DeleteResult -from .results import InsertManyResult -from .results import InsertOneResult -from .results import UpdateResult -from .write_concern import WriteConcern - -try: - # third party - from pymongo.read_concern import ReadConcern -except ImportError: - # relative - from .read_concern import ReadConcern - -_KwargOption = collections.namedtuple("KwargOption", ["typename", "default", "attrs"]) - -_WITH_OPTIONS_KWARGS = { - "read_preference": _KwargOption( - "pymongo.read_preference.ReadPreference", - _READ_PREFERENCE_PRIMARY, - ("document", "mode", "mongos_mode", "max_staleness"), - ), - "write_concern": _KwargOption( - "pymongo.write_concern.WriteConcern", - WriteConcern(), - ("acknowledged", "document"), - ), -} - - -def _bson_encode(document, codec_options): - if CodecOptions: - if isinstance(codec_options, mongomock_codec_options.CodecOptions): - codec_options = codec_options.to_pymongo() - if isinstance(codec_options, CodecOptions): - BSON.encode(document, check_keys=True, codec_options=codec_options) - else: - BSON.encode(document, check_keys=True) - - -def validate_is_mapping(option, value): - if not isinstance(value, Mapping): - raise TypeError( - "%s must be an instance of dict, bson.son.SON, or " - "other type that inherits from " - "collections.Mapping" % (option,) - ) - - -def validate_is_mutable_mapping(option, value): - if not isinstance(value, MutableMapping): - raise TypeError( - "%s must be an instance of dict, bson.son.SON, or " - "other type that inherits from " - "collections.MutableMapping" % (option,) - ) - - -def validate_ok_for_replace(replacement): - validate_is_mapping("replacement", replacement) - if replacement: - first = next(iter(replacement)) - if first.startswith("$"): - raise ValueError("replacement can not include $ operators") - - -def validate_ok_for_update(update): - validate_is_mapping("update", update) - if not update: - raise ValueError("update only works with $ operators") - first = next(iter(update)) - if not first.startswith("$"): - raise ValueError("update only works with $ operators") - - -def validate_write_concern_params(**params): - if params: - WriteConcern(**params) - - -class BulkWriteOperation(object): - def __init__(self, builder, selector, is_upsert=False): - self.builder = builder - self.selector = selector - self.is_upsert = is_upsert - - def upsert(self): - assert not self.is_upsert - return BulkWriteOperation(self.builder, self.selector, is_upsert=True) - - def register_remove_op(self, multi, hint=None): - collection = self.builder.collection - selector = self.selector - - def exec_remove(): - if multi: - op_result = collection.delete_many(selector, hint=hint).raw_result - else: - op_result = collection.delete_one(selector, hint=hint).raw_result - if op_result.get("ok"): - return {"nRemoved": op_result.get("n")} - err = op_result.get("err") - if err: - return {"writeErrors": [err]} - return {} - - self.builder.executors.append(exec_remove) - - def remove(self): - assert not self.is_upsert - self.register_remove_op(multi=True) - - def remove_one( - self, - ): - assert not self.is_upsert - self.register_remove_op(multi=False) - - def register_update_op(self, document, multi, **extra_args): - if not extra_args.get("remove"): - validate_ok_for_update(document) - - collection = self.builder.collection - selector = self.selector - - def exec_update(): - result = collection._update( - spec=selector, - document=document, - multi=multi, - upsert=self.is_upsert, - **extra_args, - ) - ret_val = {} - if result.get("upserted"): - ret_val["upserted"] = result.get("upserted") - ret_val["nUpserted"] = result.get("n") - else: - matched = result.get("n") - if matched is not None: - ret_val["nMatched"] = matched - modified = result.get("nModified") - if modified is not None: - ret_val["nModified"] = modified - if result.get("err"): - ret_val["err"] = result.get("err") - return ret_val - - self.builder.executors.append(exec_update) - - def update(self, document, hint=None): - self.register_update_op(document, multi=True, hint=hint) - - def update_one(self, document, hint=None): - self.register_update_op(document, multi=False, hint=hint) - - def replace_one(self, document, hint=None): - self.register_update_op(document, multi=False, remove=True, hint=hint) - - -def _combine_projection_spec(projection_fields_spec): - """Re-format a projection fields spec into a nested dictionary. - - e.g: {'a': 1, 'b.c': 1, 'b.d': 1} => {'a': 1, 'b': {'c': 1, 'd': 1}} - """ - - tmp_spec = OrderedDict() - for f, v in projection_fields_spec.items(): - if "." not in f: - if isinstance(tmp_spec.get(f), dict): - if not v: - raise NotImplementedError( - "Mongomock does not support overriding excluding projection: %s" - % projection_fields_spec - ) - raise OperationFailure("Path collision at %s" % f) - tmp_spec[f] = v - else: - split_field = f.split(".", 1) - base_field, new_field = tuple(split_field) - if not isinstance(tmp_spec.get(base_field), dict): - if base_field in tmp_spec: - raise OperationFailure( - "Path collision at %s remaining portion %s" % (f, new_field) - ) - tmp_spec[base_field] = OrderedDict() - tmp_spec[base_field][new_field] = v - - combined_spec = OrderedDict() - for f, v in tmp_spec.items(): - if isinstance(v, dict): - combined_spec[f] = _combine_projection_spec(v) - else: - combined_spec[f] = v - - return combined_spec - - -def _project_by_spec(doc, combined_projection_spec, is_include, container): - if "$" in combined_projection_spec: - if is_include: - raise NotImplementedError( - "Positional projection is not implemented in mongomock" - ) - raise OperationFailure( - "Cannot exclude array elements with the positional operator" - ) - - doc_copy = container() - - for key, val in doc.items(): - spec = combined_projection_spec.get(key, None) - if isinstance(spec, dict): - if isinstance(val, (list, tuple)): - doc_copy[key] = [ - _project_by_spec(sub_doc, spec, is_include, container) - for sub_doc in val - ] - elif isinstance(val, dict): - doc_copy[key] = _project_by_spec(val, spec, is_include, container) - elif (is_include and spec is not None) or (not is_include and spec is None): - doc_copy[key] = _copy_field(val, container) - - return doc_copy - - -def _copy_field(obj, container): - if isinstance(obj, list): - new = [] - for item in obj: - new.append(_copy_field(item, container)) - return new - if isinstance(obj, dict): - new = container() - for key, value in obj.items(): - new[key] = _copy_field(value, container) - return new - return copy.copy(obj) - - -def _recursive_key_check_null_character(data): - for key, value in data.items(): - if "\0" in key: - raise InvalidDocument( - f"Field names cannot contain the null character (found: {key})" - ) - if isinstance(value, Mapping): - _recursive_key_check_null_character(value) - - -def _validate_data_fields(data): - _recursive_key_check_null_character(data) - for key in data.keys(): - if key.startswith("$"): - raise InvalidDocument( - f'Top-level field names cannot start with the "$" sign ' - f"(found: {key})" - ) - - -class BulkOperationBuilder(object): - def __init__(self, collection, ordered=False, bypass_document_validation=False): - self.collection = collection - self.ordered = ordered - self.results = {} - self.executors = [] - self.done = False - self._insert_returns_nModified = True - self._update_returns_nModified = True - self._bypass_document_validation = bypass_document_validation - - def find(self, selector): - return BulkWriteOperation(self, selector) - - def insert(self, doc): - def exec_insert(): - self.collection.insert_one( - doc, bypass_document_validation=self._bypass_document_validation - ) - return {"nInserted": 1} - - self.executors.append(exec_insert) - - def __aggregate_operation_result(self, total_result, key, value): - agg_val = total_result.get(key) - assert agg_val is not None, ( - "Unknow operation result %s=%s" " (unrecognized key)" % (key, value) - ) - if isinstance(agg_val, int): - total_result[key] += value - elif isinstance(agg_val, list): - if key == "upserted": - new_element = {"index": len(agg_val), "_id": value} - agg_val.append(new_element) - else: - agg_val.append(value) - else: - assert False, ( - "Fixme: missed aggreation rule for type: %s for" - " key {%s=%s}" - % ( - type(agg_val), - key, - agg_val, - ) - ) - - def _set_nModified_policy(self, insert, update): - self._insert_returns_nModified = insert - self._update_returns_nModified = update - - def execute(self, write_concern=None): - if not self.executors: - raise InvalidOperation("Bulk operation empty!") - if self.done: - raise InvalidOperation("Bulk operation already executed!") - self.done = True - result = { - "nModified": 0, - "nUpserted": 0, - "nMatched": 0, - "writeErrors": [], - "upserted": [], - "writeConcernErrors": [], - "nRemoved": 0, - "nInserted": 0, - } - - has_update = False - has_insert = False - broken_nModified_info = False - for index, execute_func in enumerate(self.executors): - exec_name = execute_func.__name__ - try: - op_result = execute_func() - except WriteError as error: - result["writeErrors"].append( - { - "index": index, - "code": error.code, - "errmsg": str(error), - } - ) - if self.ordered: - break - continue - for key, value in op_result.items(): - self.__aggregate_operation_result(result, key, value) - if exec_name == "exec_update": - has_update = True - if "nModified" not in op_result: - broken_nModified_info = True - has_insert |= exec_name == "exec_insert" - - if broken_nModified_info: - result.pop("nModified") - elif has_insert and self._insert_returns_nModified: - pass - elif has_update and self._update_returns_nModified: - pass - elif self._update_returns_nModified and self._insert_returns_nModified: - pass - else: - result.pop("nModified") - - if result.get("writeErrors"): - raise BulkWriteError(result) - - return result - - def add_insert(self, doc): - self.insert(doc) - - def add_update( - self, - selector, - doc, - multi=False, - upsert=False, - collation=None, - array_filters=None, - hint=None, - ): - if array_filters: - raise_not_implemented( - "array_filters", "Array filters are not implemented in mongomock yet." - ) - write_operation = BulkWriteOperation(self, selector, is_upsert=upsert) - write_operation.register_update_op(doc, multi, hint=hint) - - def add_replace(self, selector, doc, upsert, collation=None, hint=None): - write_operation = BulkWriteOperation(self, selector, is_upsert=upsert) - write_operation.replace_one(doc, hint=hint) - - def add_delete(self, selector, just_one, collation=None, hint=None): - write_operation = BulkWriteOperation(self, selector, is_upsert=False) - write_operation.register_remove_op(not just_one, hint=hint) - - -class Collection(object): - def __init__( - self, - database, - name, - _db_store, - write_concern=None, - read_concern=None, - read_preference=None, - codec_options=None, - ): - self.database = database - self._name = name - self._db_store = _db_store - self._write_concern = write_concern or WriteConcern() - if read_concern and not isinstance(read_concern, ReadConcern): - raise TypeError( - "read_concern must be an instance of pymongo.read_concern.ReadConcern" - ) - self._read_concern = read_concern or ReadConcern() - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - self._codec_options = codec_options or mongomock_codec_options.CodecOptions() - - def __repr__(self): - return "Collection({0}, '{1}')".format(self.database, self.name) - - def __getitem__(self, name): - return self.database[self.name + "." + name] - - def __getattr__(self, attr): - if attr.startswith("_"): - raise AttributeError( - "%s has no attribute '%s'. To access the %s.%s collection, use database['%s.%s']." - % (self.__class__.__name__, attr, self.name, attr, self.name, attr) - ) - return self.__getitem__(attr) - - def __call__(self, *args, **kwargs): - name = self._name if "." not in self._name else self._name.split(".")[-1] - raise TypeError( - "'Collection' object is not callable. If you meant to call the '%s' method on a " - "'Collection' object it is failing because no such method exists." % name - ) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.database == other.database and self.name == other.name - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash((self.database, self.name)) - - @property - def full_name(self): - return "{0}.{1}".format(self.database.name, self._name) - - @property - def name(self): - return self._name - - @property - def write_concern(self): - return self._write_concern - - @property - def read_concern(self): - return self._read_concern - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - def initialize_unordered_bulk_op(self, bypass_document_validation=False): - return BulkOperationBuilder( - self, ordered=False, bypass_document_validation=bypass_document_validation - ) - - def initialize_ordered_bulk_op(self, bypass_document_validation=False): - return BulkOperationBuilder( - self, ordered=True, bypass_document_validation=bypass_document_validation - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def insert( - self, - data, - manipulate=True, - check_keys=True, - continue_on_error=False, - **kwargs, - ): - warnings.warn( - "insert is deprecated. Use insert_one or insert_many " "instead.", - DeprecationWarning, - stacklevel=2, - ) - validate_write_concern_params(**kwargs) - return self._insert(data) - - def insert_one(self, document, bypass_document_validation=False, session=None): - if not bypass_document_validation: - validate_is_mutable_mapping("document", document) - return InsertOneResult(self._insert(document, session), acknowledged=True) - - def insert_many( - self, documents, ordered=True, bypass_document_validation=False, session=None - ): - if not isinstance(documents, Iterable) or not documents: - raise TypeError("documents must be a non-empty list") - documents = list(documents) - if not bypass_document_validation: - for document in documents: - validate_is_mutable_mapping("document", document) - return InsertManyResult( - self._insert(documents, session, ordered=ordered), acknowledged=True - ) - - @property - def _store(self): - return self._db_store[self._name] - - def _insert(self, data, session=None, ordered=True): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if not isinstance(data, Mapping): - results = [] - write_errors = [] - num_inserted = 0 - for index, item in enumerate(data): - try: - results.append(self._insert(item)) - except WriteError as error: - write_errors.append( - { - "index": index, - "code": error.code, - "errmsg": str(error), - "op": item, - } - ) - if ordered: - break - else: - continue - num_inserted += 1 - if write_errors: - raise BulkWriteError( - { - "writeErrors": write_errors, - "nInserted": num_inserted, - } - ) - return results - - if not all(isinstance(k, str) for k in data): - raise ValueError("Document keys must be strings") - - if BSON: - # bson validation - check_keys = helpers.PYMONGO_VERSION < version.parse("3.6") - if not check_keys: - _validate_data_fields(data) - - _bson_encode(data, self._codec_options) - - # Like pymongo, we should fill the _id in the inserted dict (odd behavior, - # but we need to stick to it), so we must patch in-place the data dict - if "_id" not in data: - data["_id"] = ObjectId() - - object_id = data["_id"] - if isinstance(object_id, dict): - object_id = helpers.hashdict(object_id) - if object_id in self._store: - raise DuplicateKeyError("E11000 Duplicate Key Error", 11000) - - data = helpers.patch_datetime_awareness_in_document(data) - - self._store[object_id] = data - try: - self._ensure_uniques(data) - except DuplicateKeyError: - # Rollback - del self._store[object_id] - raise - return data["_id"] - - def _ensure_uniques(self, new_data): - # Note we consider new_data is already inserted in db - for index in self._store.indexes.values(): - if not index.get("unique"): - continue - unique = index.get("key") - is_sparse = index.get("sparse") - partial_filter_expression = index.get("partialFilterExpression") - find_kwargs = {} - for key, _ in unique: - try: - find_kwargs[key] = helpers.get_value_by_dot(new_data, key) - except KeyError: - find_kwargs[key] = None - if is_sparse and set(find_kwargs.values()) == {None}: - continue - if partial_filter_expression is not None: - find_kwargs = {"$and": [partial_filter_expression, find_kwargs]} - answer_count = len(list(self._iter_documents(find_kwargs))) - if answer_count > 1: - raise DuplicateKeyError("E11000 Duplicate Key Error", 11000) - - def _internalize_dict(self, d): - return {k: copy.deepcopy(v) for k, v in d.items()} - - def _has_key(self, doc, key): - key_parts = key.split(".") - sub_doc = doc - for part in key_parts: - if part not in sub_doc: - return False - sub_doc = sub_doc[part] - return True - - def update_one( - self, - filter, - update, - upsert=False, - bypass_document_validation=False, - collation=None, - array_filters=None, - hint=None, - session=None, - let=None, - ): - if not bypass_document_validation: - validate_ok_for_update(update) - return UpdateResult( - self._update( - filter, - update, - upsert=upsert, - hint=hint, - session=session, - collation=collation, - array_filters=array_filters, - let=let, - ), - acknowledged=True, - ) - - def update_many( - self, - filter, - update, - upsert=False, - array_filters=None, - bypass_document_validation=False, - collation=None, - hint=None, - session=None, - let=None, - ): - if not bypass_document_validation: - validate_ok_for_update(update) - return UpdateResult( - self._update( - filter, - update, - upsert=upsert, - multi=True, - hint=hint, - session=session, - collation=collation, - array_filters=array_filters, - let=let, - ), - acknowledged=True, - ) - - def replace_one( - self, - filter, - replacement, - upsert=False, - bypass_document_validation=False, - session=None, - hint=None, - ): - if not bypass_document_validation: - validate_ok_for_replace(replacement) - return UpdateResult( - self._update( - filter, replacement, upsert=upsert, hint=hint, session=session - ), - acknowledged=True, - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def update( - self, - spec, - document, - upsert=False, - manipulate=False, - multi=False, - check_keys=False, - **kwargs, - ): - warnings.warn( - "update is deprecated. Use replace_one, update_one or " - "update_many instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._update( - spec, document, upsert, manipulate, multi, check_keys, **kwargs - ) - - def _update( - self, - spec, - document, - upsert=False, - manipulate=False, - multi=False, - check_keys=False, - hint=None, - session=None, - collation=None, - let=None, - array_filters=None, - **kwargs, - ): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if hint: - raise NotImplementedError( - "The hint argument of update is valid but has not been implemented in " - "mongomock yet" - ) - if collation: - raise_not_implemented( - "collation", - "The collation argument of update is valid but has not been implemented in " - "mongomock yet", - ) - if array_filters: - raise_not_implemented( - "array_filters", "Array filters are not implemented in mongomock yet." - ) - if let: - raise_not_implemented( - "let", - "The let argument of update is valid but has not been implemented in mongomock " - "yet", - ) - spec = helpers.patch_datetime_awareness_in_document(spec) - document = helpers.patch_datetime_awareness_in_document(document) - validate_is_mapping("spec", spec) - validate_is_mapping("document", document) - - if self.database.client.server_info()["versionArray"] < [5]: - for operator in _updaters: - if not document.get(operator, True): - raise WriteError( - "'%s' is empty. You must specify a field like so: {%s: {: ...}}" - % (operator, operator), - ) - - updated_existing = False - upserted_id = None - num_updated = 0 - num_matched = 0 - for existing_document in itertools.chain(self._iter_documents(spec), [None]): - # we need was_insert for the setOnInsert update operation - was_insert = False - # the sentinel document means we should do an upsert - if existing_document is None: - if not upsert or num_matched: - continue - # For upsert operation we have first to create a fake existing_document, - # update it like a regular one, then finally insert it - if spec.get("_id") is not None: - _id = spec["_id"] - elif document.get("_id") is not None: - _id = document["_id"] - else: - _id = ObjectId() - to_insert = dict(spec, _id=_id) - to_insert = self._expand_dots(to_insert) - to_insert, _ = self._discard_operators(to_insert) - existing_document = to_insert - was_insert = True - else: - original_document_snapshot = copy.deepcopy(existing_document) - updated_existing = True - num_matched += 1 - first = True - subdocument = None - for k, v in document.items(): - if k in _updaters: - updater = _updaters[k] - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, v, spec, updater, subdocument - ) - ) - - elif k == "$rename": - for src, dst in v.items(): - if "." in src or "." in dst: - raise NotImplementedError( - "Using the $rename operator with dots is a valid MongoDB " - "operation, but it is not yet supported by mongomock" - ) - if self._has_key(existing_document, src): - existing_document[dst] = existing_document.pop(src) - - elif k == "$setOnInsert": - if not was_insert: - continue - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, v, spec, _set_updater, subdocument - ) - ) - - elif k == "$currentDate": - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, - v, - spec, - _current_date_updater, - subdocument, - ) - ) - - elif k == "$addToSet": - for field, value in v.items(): - nested_field_list = field.rsplit(".") - if len(nested_field_list) == 1: - if field not in existing_document: - existing_document[field] = [] - # document should be a list append to it - if isinstance(value, dict): - if "$each" in value: - # append the list to the field - existing_document[field] += [ - obj - for obj in list(value["$each"]) - if obj not in existing_document[field] - ] - continue - if value not in existing_document[field]: - existing_document[field].append(value) - continue - # push to array in a nested attribute - else: - # create nested attributes if they do not exist - subdocument = existing_document - for field_part in nested_field_list[:-1]: - if field_part == "$": - break - if field_part not in subdocument: - subdocument[field_part] = {} - - subdocument = subdocument[field_part] - - # get subdocument with $ oprator support - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # we're pushing a list - push_results = [] - if nested_field_list[-1] in subdocument: - # if the list exists, then use that list - push_results = subdocument[nested_field_list[-1]] - - if isinstance(value, dict) and "$each" in value: - push_results += [ - obj - for obj in list(value["$each"]) - if obj not in push_results - ] - elif value not in push_results: - push_results.append(value) - - subdocument[nested_field_list[-1]] = push_results - elif k == "$pull": - for field, value in v.items(): - nested_field_list = field.rsplit(".") - # nested fields includes a positional element - # need to find that element - if "$" in nested_field_list: - if not subdocument: - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # value should be a dictionary since we're pulling - pull_results = [] - # and the last subdoc should be an array - for obj in subdocument[nested_field_list[-1]]: - if isinstance(obj, dict): - for pull_key, pull_value in value.items(): - if obj[pull_key] != pull_value: - pull_results.append(obj) - continue - if obj != value: - pull_results.append(obj) - - # cannot write to doc directly as it doesn't save to - # existing_document - subdocument[nested_field_list[-1]] = pull_results - else: - arr = existing_document - for field_part in nested_field_list: - if field_part not in arr: - break - arr = arr[field_part] - if not isinstance(arr, list): - continue - - arr_copy = copy.deepcopy(arr) - if isinstance(value, dict): - for obj in arr_copy: - try: - is_matching = filter_applies(value, obj) - except OperationFailure: - is_matching = False - if is_matching: - arr.remove(obj) - continue - - if filter_applies({"field": value}, {"field": obj}): - arr.remove(obj) - else: - for obj in arr_copy: - if value == obj: - arr.remove(obj) - elif k == "$pullAll": - for field, value in v.items(): - nested_field_list = field.rsplit(".") - if len(nested_field_list) == 1: - if field in existing_document: - arr = existing_document[field] - existing_document[field] = [ - obj for obj in arr if obj not in value - ] - continue - else: - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - if nested_field_list[-1] in subdocument: - arr = subdocument[nested_field_list[-1]] - subdocument[nested_field_list[-1]] = [ - obj for obj in arr if obj not in value - ] - elif k == "$push": - for field, value in v.items(): - # Find the place where to push. - nested_field_list = field.rsplit(".") - subdocument, field = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # Push the new element or elements. - if isinstance(subdocument, dict) and field not in subdocument: - subdocument[field] = [] - push_results = subdocument[field] - if isinstance(value, dict) and "$each" in value: - if "$position" in value: - push_results = ( - push_results[0 : value["$position"]] - + list(value["$each"]) - + push_results[value["$position"] :] - ) - else: - push_results += list(value["$each"]) - - if "$sort" in value: - sort_spec = value["$sort"] - if isinstance(sort_spec, dict): - sort_key = set(sort_spec.keys()).pop() - push_results = sorted( - push_results, - key=lambda d: helpers.get_value_by_dot( - d, sort_key - ), - reverse=set(sort_spec.values()).pop() < 0, - ) - else: - push_results = sorted( - push_results, reverse=sort_spec < 0 - ) - - if "$slice" in value: - slice_value = value["$slice"] - if slice_value < 0: - push_results = push_results[slice_value:] - elif slice_value == 0: - push_results = [] - else: - push_results = push_results[:slice_value] - - unused_modifiers = set(value.keys()) - { - "$each", - "$slice", - "$position", - "$sort", - } - if unused_modifiers: - raise WriteError( - "Unrecognized clause in $push: " - + unused_modifiers.pop() - ) - else: - push_results.append(value) - subdocument[field] = push_results - else: - if first: - # replace entire document - for key in document.keys(): - if key.startswith("$"): - # can't mix modifiers with non-modifiers in - # update - raise ValueError( - "field names cannot start with $ [{}]".format(k) - ) - _id = spec.get("_id", existing_document.get("_id")) - existing_document.clear() - if _id is not None: - existing_document["_id"] = _id - if BSON: - # bson validation - check_keys = helpers.PYMONGO_VERSION < version.parse("3.6") - if not check_keys: - _validate_data_fields(document) - _bson_encode(document, self.codec_options) - existing_document.update(self._internalize_dict(document)) - if existing_document["_id"] != _id: - raise OperationFailure( - "The _id field cannot be changed from {0} to {1}".format( - existing_document["_id"], _id - ) - ) - break - else: - # can't mix modifiers with non-modifiers in update - raise ValueError("Invalid modifier specified: {}".format(k)) - first = False - # if empty document comes - if not document: - _id = spec.get("_id", existing_document.get("_id")) - existing_document.clear() - if _id: - existing_document["_id"] = _id - - if was_insert: - upserted_id = self._insert(existing_document) - num_updated += 1 - elif existing_document != original_document_snapshot: - # Document has been modified in-place. - - # Make sure the ID was not change. - if original_document_snapshot.get("_id") != existing_document.get( - "_id" - ): - # Rollback. - self._store[original_document_snapshot["_id"]] = ( - original_document_snapshot - ) - raise WriteError( - "After applying the update, the (immutable) field '_id' was found to have " - "been altered to _id: {}".format(existing_document.get("_id")) - ) - - # Make sure it still respect the unique indexes and, if not, to - # revert modifications - try: - self._ensure_uniques(existing_document) - num_updated += 1 - except DuplicateKeyError: - # Rollback. - self._store[original_document_snapshot["_id"]] = ( - original_document_snapshot - ) - raise - - if not multi: - break - - return { - "connectionId": self.database.client._id, - "err": None, - "n": num_matched, - "nModified": num_updated if updated_existing else 0, - "ok": 1, - "upserted": upserted_id, - "updatedExisting": updated_existing, - } - - def _get_subdocument(self, existing_document, spec, nested_field_list): - """This method retrieves the subdocument of the existing_document.nested_field_list. - - It uses the spec to filter through the items. It will continue to grab nested documents - until it can go no further. It will then return the subdocument that was last saved. - '$' is the positional operator, so we use the $elemMatch in the spec to find the right - subdocument in the array. - """ - # Current document in view. - doc = existing_document - # Previous document in view. - parent_doc = existing_document - # Current spec in view. - subspec = spec - # Whether spec is following the document. - is_following_spec = True - # Walk down the dictionary. - for index, subfield in enumerate(nested_field_list): - if subfield == "$": - if not is_following_spec: - raise WriteError( - "The positional operator did not find the match needed from the query" - ) - # Positional element should have the equivalent elemMatch in the query. - subspec = subspec["$elemMatch"] - is_following_spec = False - # Iterate through. - for spec_index, item in enumerate(doc): - if filter_applies(subspec, item): - subfield = spec_index - break - else: - raise WriteError( - "The positional operator did not find the match needed from the query" - ) - - parent_doc = doc - if isinstance(parent_doc, list): - subfield = int(subfield) - if is_following_spec and (subfield < 0 or subfield >= len(subspec)): - is_following_spec = False - - if index == len(nested_field_list) - 1: - return parent_doc, subfield - - if not isinstance(parent_doc, list): - if subfield not in parent_doc: - parent_doc[subfield] = {} - if is_following_spec and subfield not in subspec: - is_following_spec = False - - doc = parent_doc[subfield] - if is_following_spec: - subspec = subspec[subfield] - - def _expand_dots(self, doc): - expanded = {} - paths = {} - for k, v in doc.items(): - - def _raise_incompatible(subkey): - raise WriteError( - "cannot infer query fields to set, both paths '%s' and '%s' are matched" - % (k, paths[subkey]) - ) - - if k in paths: - _raise_incompatible(k) - - key_parts = k.split(".") - sub_expanded = expanded - - paths[k] = k - for i, key_part in enumerate(key_parts[:-1]): - if key_part not in sub_expanded: - sub_expanded[key_part] = {} - sub_expanded = sub_expanded[key_part] - key = ".".join(key_parts[: i + 1]) - if not isinstance(sub_expanded, dict): - _raise_incompatible(key) - paths[key] = k - sub_expanded[key_parts[-1]] = v - return expanded - - def _discard_operators(self, doc): - if not doc or not isinstance(doc, dict): - return doc, False - new_doc = OrderedDict() - for k, v in doc.items(): - if k == "$eq": - return v, False - if k.startswith("$"): - continue - new_v, discarded = self._discard_operators(v) - if not discarded: - new_doc[k] = new_v - return new_doc, not bool(new_doc) - - def find( - self, - filter=None, - projection=None, - skip=0, - limit=0, - no_cursor_timeout=False, - cursor_type=None, - sort=None, - allow_partial_results=False, - oplog_replay=False, - modifiers=None, - batch_size=0, - manipulate=True, - collation=None, - session=None, - max_time_ms=None, - allow_disk_use=False, - **kwargs, - ): - spec = filter - if spec is None: - spec = {} - validate_is_mapping("filter", spec) - for kwarg, value in kwargs.items(): - if value: - raise OperationFailure("Unrecognized field '%s'" % kwarg) - return ( - Cursor(self, spec, sort, projection, skip, limit, collation=collation) - .max_time_ms(max_time_ms) - .allow_disk_use(allow_disk_use) - ) - - def _get_dataset(self, spec, sort, fields, as_class): - dataset = self._iter_documents(spec) - if sort: - for sort_key, sort_direction in reversed(sort): - if sort_key == "$natural": - if sort_direction < 0: - dataset = iter(reversed(list(dataset))) - continue - if sort_key.startswith("$"): - raise NotImplementedError( - "Sorting by {} is not implemented in mongomock yet".format( - sort_key - ) - ) - dataset = iter( - sorted( - dataset, - key=lambda x: filtering.resolve_sort_key(sort_key, x), - reverse=sort_direction < 0, - ) - ) - for document in dataset: - yield self._copy_only_fields(document, fields, as_class) - - def _extract_projection_operators(self, fields): - """Removes and returns fields with projection operators.""" - result = {} - allowed_projection_operators = {"$elemMatch", "$slice"} - for key, value in fields.items(): - if isinstance(value, dict): - for op in value: - if op not in allowed_projection_operators: - raise ValueError("Unsupported projection option: {}".format(op)) - result[key] = value - - for key in result: - del fields[key] - - return result - - def _apply_projection_operators(self, ops, doc, doc_copy): - """Applies projection operators to copied document.""" - for field, op in ops.items(): - if field not in doc_copy: - if field in doc: - # field was not copied yet (since we are in include mode) - doc_copy[field] = doc[field] - else: - # field doesn't exist in original document, no work to do - continue - - if "$slice" in op: - if not isinstance(doc_copy[field], list): - raise OperationFailure( - "Unsupported type {} for slicing operation: {}".format( - type(doc_copy[field]), op - ) - ) - op_value = op["$slice"] - slice_ = None - if isinstance(op_value, list): - if len(op_value) != 2: - raise OperationFailure( - "Unsupported slice format {} for slicing operation: {}".format( - op_value, op - ) - ) - skip, limit = op_value - if skip < 0: - skip = len(doc_copy[field]) + skip - last = min(skip + limit, len(doc_copy[field])) - slice_ = slice(skip, last) - elif isinstance(op_value, int): - count = op_value - start = 0 - end = len(doc_copy[field]) - if count < 0: - start = max(0, len(doc_copy[field]) + count) - else: - end = min(count, len(doc_copy[field])) - slice_ = slice(start, end) - - if slice_: - doc_copy[field] = doc_copy[field][slice_] - else: - raise OperationFailure( - "Unsupported slice value {} for slicing operation: {}".format( - op_value, op - ) - ) - - if "$elemMatch" in op: - if isinstance(doc_copy[field], list): - # find the first item that matches - matched = False - for item in doc_copy[field]: - if filter_applies(op["$elemMatch"], item): - matched = True - doc_copy[field] = [item] - break - - # None have matched - if not matched: - del doc_copy[field] - - else: - # remove the field since there is None to iterate - del doc_copy[field] - - def _copy_only_fields(self, doc, fields, container): - """Copy only the specified fields.""" - - # https://pymongo.readthedocs.io/en/stable/migrate-to-pymongo4.html#collection-find-returns-entire-document-with-empty-projection - if ( - fields is None - or not fields - and helpers.PYMONGO_VERSION >= version.parse("4.0") - ): - return _copy_field(doc, container) - - if not fields: - fields = {"_id": 1} - if not isinstance(fields, dict): - fields = helpers.fields_list_to_dict(fields) - - # we can pass in something like {'_id':0, 'field':1}, so pull the id - # value out and hang on to it until later - id_value = fields.pop("_id", 1) - - # filter out fields with projection operators, we will take care of them later - projection_operators = self._extract_projection_operators(fields) - - # other than the _id field, all fields must be either includes or - # excludes, this can evaluate to 0 - if len(set(list(fields.values()))) > 1: - raise ValueError("You cannot currently mix including and excluding fields.") - - # if we have novalues passed in, make a doc_copy based on the - # id_value - if not fields: - if id_value == 1: - doc_copy = container() - else: - doc_copy = _copy_field(doc, container) - else: - doc_copy = _project_by_spec( - doc, - _combine_projection_spec(fields), - is_include=list(fields.values())[0], - container=container, - ) - - # set the _id value if we requested it, otherwise remove it - if id_value == 0: - doc_copy.pop("_id", None) - else: - if "_id" in doc: - doc_copy["_id"] = doc["_id"] - - fields["_id"] = id_value # put _id back in fields - - # time to apply the projection operators and put back their fields - self._apply_projection_operators(projection_operators, doc, doc_copy) - for field, op in projection_operators.items(): - fields[field] = op - return doc_copy - - def _update_document_fields(self, doc, fields, updater): - """Implements the $set behavior on an existing document""" - for k, v in fields.items(): - self._update_document_single_field(doc, k, v, updater) - - def _update_document_fields_positional( - self, doc, fields, spec, updater, subdocument=None - ): - """Implements the $set behavior on an existing document""" - for k, v in fields.items(): - if "$" in k: - field_name_parts = k.split(".") - if not subdocument: - current_doc = doc - subspec = spec - for part in field_name_parts[:-1]: - if part == "$": - subspec_dollar = subspec.get("$elemMatch", subspec) - for item in current_doc: - if filter_applies(subspec_dollar, item): - current_doc = item - break - continue - - new_spec = {} - for el in subspec: - if el.startswith(part): - if len(el.split(".")) > 1: - new_spec[".".join(el.split(".")[1:])] = subspec[el] - else: - new_spec = subspec[el] - subspec = new_spec - current_doc = current_doc[part] - - subdocument = current_doc - if field_name_parts[-1] == "$" and isinstance(subdocument, list): - for i, doc in enumerate(subdocument): - subspec_dollar = subspec.get("$elemMatch", subspec) - if filter_applies(subspec_dollar, doc): - subdocument[i] = v - break - continue - - updater(subdocument, field_name_parts[-1], v) - continue - # otherwise, we handle it the standard way - self._update_document_single_field(doc, k, v, updater) - - return subdocument - - def _update_document_fields_with_positional_awareness( - self, existing_document, v, spec, updater, subdocument - ): - positional = any("$" in key for key in v.keys()) - - if positional: - return self._update_document_fields_positional( - existing_document, v, spec, updater, subdocument - ) - self._update_document_fields(existing_document, v, updater) - return subdocument - - def _update_document_single_field(self, doc, field_name, field_value, updater): - field_name_parts = field_name.split(".") - for part in field_name_parts[:-1]: - if isinstance(doc, list): - try: - if part == "$": - doc = doc[0] - else: - doc = doc[int(part)] - continue - except ValueError: - pass - elif isinstance(doc, dict): - if updater is _unset_updater and part not in doc: - # If the parent doesn't exists, so does it child. - return - doc = doc.setdefault(part, {}) - else: - return - field_name = field_name_parts[-1] - updater(doc, field_name, field_value, codec_options=self._codec_options) - - def _iter_documents(self, filter): - # Validate the filter even if no documents can be returned. - if self._store.is_empty: - filter_applies(filter, {}) - - return ( - document - for document in list(self._store.documents) - if filter_applies(filter, document) - ) - - def find_one(self, filter=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - # Allow calling find_one with a non-dict argument that gets used as - # the id for the query. - if filter is None: - filter = {} - if not isinstance(filter, Mapping): - filter = {"_id": filter} - - try: - return next(self.find(filter, *args, **kwargs)) - except StopIteration: - return None - - def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs): - kwargs["remove"] = True - validate_is_mapping("filter", filter) - return self._find_and_modify(filter, projection, sort=sort, **kwargs) - - def find_one_and_replace( - self, - filter, - replacement, - projection=None, - sort=None, - upsert=False, - return_document=ReturnDocument.BEFORE, - **kwargs, - ): - validate_is_mapping("filter", filter) - validate_ok_for_replace(replacement) - return self._find_and_modify( - filter, projection, replacement, upsert, sort, return_document, **kwargs - ) - - def find_one_and_update( - self, - filter, - update, - projection=None, - sort=None, - upsert=False, - return_document=ReturnDocument.BEFORE, - **kwargs, - ): - validate_is_mapping("filter", filter) - validate_ok_for_update(update) - return self._find_and_modify( - filter, projection, update, upsert, sort, return_document, **kwargs - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def find_and_modify( - self, - query={}, - update=None, - upsert=False, - sort=None, - full_response=False, - manipulate=False, - fields=None, - **kwargs, - ): - warnings.warn( - "find_and_modify is deprecated, use find_one_and_delete" - ", find_one_and_replace, or find_one_and_update instead", - DeprecationWarning, - stacklevel=2, - ) - if "projection" in kwargs: - raise TypeError( - "find_and_modify() got an unexpected keyword argument 'projection'" - ) - return self._find_and_modify( - query, - update=update, - upsert=upsert, - sort=sort, - projection=fields, - **kwargs, - ) - - def _find_and_modify( - self, - query, - projection=None, - update=None, - upsert=False, - sort=None, - return_document=ReturnDocument.BEFORE, - session=None, - **kwargs, - ): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - remove = kwargs.get("remove", False) - if kwargs.get("new", False) and remove: - # message from mongodb - raise OperationFailure("remove and returnNew can't co-exist") - - if not (remove or update): - raise ValueError("Must either update or remove") - - if remove and update: - raise ValueError("Can't do both update and remove") - - old = self.find_one(query, projection=projection, sort=sort) - if not old and not upsert: - return - - if old and "_id" in old: - query = {"_id": old["_id"]} - - if remove: - self.delete_one(query) - else: - updated = self._update(query, update, upsert) - if updated["upserted"]: - query = {"_id": updated["upserted"]} - - if return_document is ReturnDocument.AFTER or kwargs.get("new"): - return self.find_one(query, projection) - return old - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def save(self, to_save, manipulate=True, check_keys=True, **kwargs): - warnings.warn( - "save is deprecated. Use insert_one or replace_one " "instead", - DeprecationWarning, - stacklevel=2, - ) - validate_is_mutable_mapping("to_save", to_save) - validate_write_concern_params(**kwargs) - - if "_id" not in to_save: - return self.insert(to_save) - self._update( - {"_id": to_save["_id"]}, - to_save, - True, - manipulate, - check_keys=True, - **kwargs, - ) - return to_save.get("_id", None) - - def delete_one(self, filter, collation=None, hint=None, session=None): - validate_is_mapping("filter", filter) - return DeleteResult( - self._delete(filter, collation=collation, hint=hint, session=session), True - ) - - def delete_many(self, filter, collation=None, hint=None, session=None): - validate_is_mapping("filter", filter) - return DeleteResult( - self._delete( - filter, collation=collation, hint=hint, multi=True, session=session - ), - True, - ) - - def _delete(self, filter, collation=None, hint=None, multi=False, session=None): - if hint: - raise NotImplementedError( - "The hint argument of delete is valid but has not been implemented in " - "mongomock yet" - ) - if collation: - raise_not_implemented( - "collation", - "The collation argument of delete is valid but has not been " - "implemented in mongomock yet", - ) - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - filter = helpers.patch_datetime_awareness_in_document(filter) - if filter is None: - filter = {} - if not isinstance(filter, Mapping): - filter = {"_id": filter} - to_delete = list(self.find(filter)) - deleted_count = 0 - for doc in to_delete: - doc_id = doc["_id"] - if isinstance(doc_id, dict): - doc_id = helpers.hashdict(doc_id) - del self._store[doc_id] - deleted_count += 1 - if not multi: - break - - return { - "connectionId": self.database.client._id, - "n": deleted_count, - "ok": 1.0, - "err": None, - } - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def remove(self, spec_or_id=None, multi=True, **kwargs): - warnings.warn( - "remove is deprecated. Use delete_one or delete_many " "instead.", - DeprecationWarning, - stacklevel=2, - ) - validate_write_concern_params(**kwargs) - return self._delete(spec_or_id, multi=multi) - - def count(self, filter=None, **kwargs): - warnings.warn( - "count is deprecated. Use estimated_document_count or " - "count_documents instead. Please note that $where must be replaced " - "by $expr, $near must be replaced by $geoWithin with $center, and " - "$nearSphere must be replaced by $geoWithin with $centerSphere", - DeprecationWarning, - stacklevel=2, - ) - if kwargs.pop("session", None): - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - if filter is None: - return len(self._store) - spec = helpers.patch_datetime_awareness_in_document(filter) - return len(list(self._iter_documents(spec))) - - def count_documents(self, filter, **kwargs): - if kwargs.pop("collation", None): - raise_not_implemented( - "collation", - "The collation argument of count_documents is valid but has not been " - "implemented in mongomock yet", - ) - if kwargs.pop("session", None): - raise_not_implemented("session", "Mongomock does not handle sessions yet") - skip = kwargs.pop("skip", 0) - if "limit" in kwargs: - limit = kwargs.pop("limit") - if not isinstance(limit, (int, float)): - raise OperationFailure("the limit must be specified as a number") - if limit <= 0: - raise OperationFailure("the limit must be positive") - limit = math.floor(limit) - else: - limit = None - unknown_kwargs = set(kwargs) - {"maxTimeMS", "hint"} - if unknown_kwargs: - raise OperationFailure("unrecognized field '%s'" % unknown_kwargs.pop()) - - spec = helpers.patch_datetime_awareness_in_document(filter) - doc_num = len(list(self._iter_documents(spec))) - count = max(doc_num - skip, 0) - return count if limit is None else min(count, limit) - - def estimated_document_count(self, **kwargs): - if kwargs.pop("session", None): - raise ConfigurationError( - "estimated_document_count does not support sessions" - ) - unknown_kwargs = set(kwargs) - {"limit", "maxTimeMS", "hint"} - if self.database.client.server_info()["versionArray"] < [5]: - unknown_kwargs.discard("skip") - if unknown_kwargs: - raise OperationFailure( - "BSON field 'count.%s' is an unknown field." % list(unknown_kwargs)[0] - ) - return self.count_documents({}, **kwargs) - - def drop(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - self.database.drop_collection(self.name) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def ensure_index(self, key_or_list, cache_for=300, **kwargs): - return self.create_index(key_or_list, cache_for, **kwargs) - - def create_index(self, key_or_list, cache_for=300, session=None, **kwargs): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - index_list = helpers.create_index_list(key_or_list) - is_unique = kwargs.pop("unique", False) - is_sparse = kwargs.pop("sparse", False) - - index_name = kwargs.pop("name", helpers.gen_index_name(index_list)) - index_dict = {"key": index_list} - if is_sparse: - index_dict["sparse"] = True - if is_unique: - index_dict["unique"] = True - if "expireAfterSeconds" in kwargs and kwargs["expireAfterSeconds"] is not None: - index_dict["expireAfterSeconds"] = kwargs.pop("expireAfterSeconds") - if ( - "partialFilterExpression" in kwargs - and kwargs["partialFilterExpression"] is not None - ): - index_dict["partialFilterExpression"] = kwargs.pop( - "partialFilterExpression" - ) - - existing_index = self._store.indexes.get(index_name) - if existing_index and index_dict != existing_index: - raise OperationFailure( - "Index with name: %s already exists with different options" % index_name - ) - - # Check that documents already verify the uniquess of this new index. - if is_unique: - indexed = set() - indexed_list = [] - documents_gen = self._store.documents - for doc in documents_gen: - index = [] - for key, unused_order in index_list: - try: - index.append(helpers.get_value_by_dot(doc, key)) - except KeyError: - if is_sparse: - continue - index.append(None) - if is_sparse and not index: - continue - index = tuple(index) - try: - if index in indexed: - # Need to throw this inside the generator so it can clean the locks - documents_gen.throw( - DuplicateKeyError("E11000 Duplicate Key Error", 11000), - None, - None, - ) - indexed.add(index) - except TypeError as err: - # index is not hashable. - if index in indexed_list: - documents_gen.throw( - DuplicateKeyError("E11000 Duplicate Key Error", 11000), - None, - err, - ) - indexed_list.append(index) - - self._store.create_index(index_name, index_dict) - - return index_name - - def create_indexes(self, indexes, session=None): - for index in indexes: - if not isinstance(index, IndexModel): - raise TypeError( - "%s is not an instance of pymongo.operations.IndexModel" % index - ) - - return [ - self.create_index( - index.document["key"].items(), - session=session, - expireAfterSeconds=index.document.get("expireAfterSeconds"), - unique=index.document.get("unique", False), - sparse=index.document.get("sparse", False), - name=index.document.get("name"), - ) - for index in indexes - ] - - def drop_index(self, index_or_name, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if isinstance(index_or_name, list): - name = helpers.gen_index_name(index_or_name) - else: - name = index_or_name - try: - self._store.drop_index(name) - except KeyError as err: - raise OperationFailure("index not found with name [%s]" % name) from err - - def drop_indexes(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - self._store.indexes = {} - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def reindex(self, session=None): - if session: - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - - def _list_all_indexes(self): - if not self._store.is_created: - return - yield "_id_", {"key": [("_id", 1)]} - for name, information in self._store.indexes.items(): - yield name, information - - def list_indexes(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - for name, information in self._list_all_indexes(): - yield dict(information, key=dict(information["key"]), name=name, v=2) - - def index_information(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - return {name: dict(index, v=2) for name, index in self._list_all_indexes()} - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def map_reduce( - self, - map_func, - reduce_func, - out, - full_response=False, - query=None, - limit=0, - session=None, - ): - if execjs is None: - raise NotImplementedError( - "PyExecJS is required in order to run Map-Reduce. " - "Use 'pip install pyexecjs pymongo' to support Map-Reduce mock." - ) - if session: - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - if limit == 0: - limit = None - start_time = time.perf_counter() - out_collection = None - reduced_rows = None - full_dict = { - "counts": {"input": 0, "reduce": 0, "emit": 0, "output": 0}, - "timeMillis": 0, - "ok": 1.0, - "result": None, - } - map_ctx = execjs.compile( - """ - function doMap(fnc, docList) { - var mappedDict = {}; - function emit(key, val) { - if (key['$oid']) { - mapped_key = '$oid' + key['$oid']; - } - else { - mapped_key = key; - } - if(!mappedDict[mapped_key]) { - mappedDict[mapped_key] = []; - } - mappedDict[mapped_key].push(val); - } - mapper = eval('('+fnc+')'); - var mappedList = new Array(); - for(var i=0; i 1: - full_dict["counts"]["reduce"] += 1 - full_dict["counts"]["output"] = len(reduced_rows) - if isinstance(out, (str, bytes)): - out_collection = getattr(self.database, out) - out_collection.drop() - out_collection.insert(reduced_rows) - ret_val = out_collection - full_dict["result"] = out - elif isinstance(out, SON) and out.get("replace") and out.get("db"): - # Must be of the format SON([('replace','results'),('db','outdb')]) - out_db = getattr(self.database._client, out["db"]) - out_collection = getattr(out_db, out["replace"]) - out_collection.insert(reduced_rows) - ret_val = out_collection - full_dict["result"] = {"db": out["db"], "collection": out["replace"]} - elif isinstance(out, dict) and out.get("inline"): - ret_val = reduced_rows - full_dict["result"] = reduced_rows - else: - raise TypeError("'out' must be an instance of string, dict or bson.SON") - time_millis = (time.perf_counter() - start_time) * 1000 - full_dict["timeMillis"] = int(round(time_millis)) - if full_response: - ret_val = full_dict - return ret_val - - def inline_map_reduce( - self, - map_func, - reduce_func, - full_response=False, - query=None, - limit=0, - session=None, - ): - return self.map_reduce( - map_func, - reduce_func, - {"inline": 1}, - full_response, - query, - limit, - session=session, - ) - - def distinct(self, key, filter=None, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - return self.find(filter).distinct(key) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def group(self, key, condition, initial, reduce, finalize=None): - if helpers.PYMONGO_VERSION >= version.parse("3.6"): - raise OperationFailure("no such command: 'group'") - if execjs is None: - raise NotImplementedError( - "PyExecJS is required in order to use group. " - "Use 'pip install pyexecjs pymongo' to support group mock." - ) - reduce_ctx = execjs.compile( - """ - function doReduce(fnc, docList) { - reducer = eval('('+fnc+')'); - for(var i=0, l=docList.length; i 0: - doc += [None] * len_diff - doc[field_index] = value - - -def _unset_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc.pop(field_name, None) - - -def _inc_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = doc.get(field_name, 0) + value - - if isinstance(doc, list): - field_index = int(field_name) - if field_index < 0: - raise WriteError("Negative index provided") - try: - doc[field_index] += value - except IndexError: - len_diff = field_index - (len(doc) - 1) - doc += [None] * len_diff - doc[field_index] = value - - -def _max_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = max(doc.get(field_name, value), value) - - -def _min_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = min(doc.get(field_name, value), value) - - -def _pop_updater(doc, field_name, value, codec_options=None): - if value not in {1, -1}: - raise WriteError("$pop expects 1 or -1, found: " + str(value)) - - if isinstance(doc, dict): - if isinstance(doc[field_name], (tuple, list)): - doc[field_name] = list(doc[field_name]) - _pop_from_list(doc[field_name], value) - return - raise WriteError("Path contains element of non-array type") - - if isinstance(doc, list): - field_index = int(field_name) - if field_index < 0: - raise WriteError("Negative index provided") - if field_index >= len(doc): - return - _pop_from_list(doc[field_index], value) - - -def _pop_from_list(list_instance, mongo_pop_value, codec_options=None): - if not list_instance: - return - - if mongo_pop_value == 1: - list_instance.pop() - elif mongo_pop_value == -1: - list_instance.pop(0) - - -def _current_date_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - if value == {"$type": "timestamp"}: - # TODO(juannyg): get_current_timestamp should also be using helpers utcnow, - # as it currently using time.time internally - doc[field_name] = helpers.get_current_timestamp() - else: - doc[field_name] = utcnow() - - -_updaters = { - "$set": _set_updater, - "$unset": _unset_updater, - "$inc": _inc_updater, - "$max": _max_updater, - "$min": _min_updater, - "$pop": _pop_updater, -} diff --git a/packages/syft/tests/mongomock/command_cursor.py b/packages/syft/tests/mongomock/command_cursor.py deleted file mode 100644 index 025bb836e24..00000000000 --- a/packages/syft/tests/mongomock/command_cursor.py +++ /dev/null @@ -1,37 +0,0 @@ -class CommandCursor(object): - def __init__(self, collection, curser_info=None, address=None, retrieved=0): - self._collection = iter(collection) - self._id = None - self._address = address - self._data = {} - self._retrieved = retrieved - self._batch_size = 0 - self._killed = self._id == 0 - - @property - def address(self): - return self._address - - def close(self): - pass - - def batch_size(self, batch_size): - return self - - @property - def alive(self): - return True - - def __iter__(self): - return self - - def next(self): - return next(self._collection) - - __next__ = next - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return diff --git a/packages/syft/tests/mongomock/database.py b/packages/syft/tests/mongomock/database.py deleted file mode 100644 index 3b1a7e59f70..00000000000 --- a/packages/syft/tests/mongomock/database.py +++ /dev/null @@ -1,301 +0,0 @@ -# stdlib -import warnings - -# third party -from packaging import version - -# relative -from . import CollectionInvalid -from . import InvalidName -from . import OperationFailure -from . import codec_options as mongomock_codec_options -from . import helpers -from . import read_preferences -from . import store -from .collection import Collection -from .filtering import filter_applies - -try: - # third party - from pymongo import ReadPreference - - _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY -except ImportError: - _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY - -try: - # third party - from pymongo.read_concern import ReadConcern -except ImportError: - # relative - from .read_concern import ReadConcern - -_LIST_COLLECTION_FILTER_ALLOWED_OPERATORS = frozenset(["$regex", "$eq", "$ne"]) - - -def _verify_list_collection_supported_op(keys): - if set(keys) - _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS: - raise NotImplementedError( - "list collection names filter operator {0} is not implemented yet in mongomock " - "allowed operators are {1}".format( - keys, _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS - ) - ) - - -class Database(object): - def __init__( - self, - client, - name, - _store, - read_preference=None, - codec_options=None, - read_concern=None, - ): - self.name = name - self._client = client - self._collection_accesses = {} - self._store = _store or store.DatabaseStore() - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - mongomock_codec_options.is_supported(codec_options) - self._codec_options = codec_options or mongomock_codec_options.CodecOptions() - if read_concern and not isinstance(read_concern, ReadConcern): - raise TypeError( - "read_concern must be an instance of pymongo.read_concern.ReadConcern" - ) - self._read_concern = read_concern or ReadConcern() - - def __getitem__(self, coll_name): - return self.get_collection(coll_name) - - def __getattr__(self, attr): - if attr.startswith("_"): - raise AttributeError( - "%s has no attribute '%s'. To access the %s collection, use database['%s']." - % (self.__class__.__name__, attr, attr, attr) - ) - return self[attr] - - def __repr__(self): - return "Database({0}, '{1}')".format(self._client, self.name) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self._client == other._client and self.name == other.name - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash((self._client, self.name)) - - @property - def client(self): - return self._client - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - @property - def read_concern(self): - return self._read_concern - - def _get_created_collections(self): - return self._store.list_created_collection_names() - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def collection_names(self, include_system_collections=True, session=None): - warnings.warn( - "collection_names is deprecated. Use list_collection_names instead." - ) - if include_system_collections: - return list(self._get_created_collections()) - return self.list_collection_names(session=session) - - def list_collections(self, filter=None, session=None, nameOnly=False): - raise NotImplementedError( - "list_collections is a valid method of Database but has not been implemented in " - "mongomock yet." - ) - - def list_collection_names(self, filter=None, session=None): - """filter: only name field type with eq,ne or regex operator - - session: not supported - for supported operator please see _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS - """ - field_name = "name" - - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - - if filter: - if not filter.get("name"): - raise NotImplementedError( - "list collection {0} might be valid but is not " - "implemented yet in mongomock".format(filter) - ) - - filter = ( - {field_name: {"$eq": filter.get(field_name)}} - if isinstance(filter.get(field_name), str) - else filter - ) - - _verify_list_collection_supported_op(filter.get(field_name).keys()) - - return [ - name - for name in list(self._store._collections) - if filter_applies(filter, {field_name: name}) - and not name.startswith("system.") - ] - - return [ - name - for name in self._get_created_collections() - if not name.startswith("system.") - ] - - def get_collection( - self, - name, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - if read_preference is not None: - read_preferences.ensure_read_preference_type( - "read_preference", read_preference - ) - mongomock_codec_options.is_supported(codec_options) - try: - return self._collection_accesses[name].with_options( - codec_options=codec_options or self._codec_options, - read_preference=read_preference or self.read_preference, - read_concern=read_concern, - write_concern=write_concern, - ) - except KeyError: - self._ensure_valid_collection_name(name) - collection = self._collection_accesses[name] = Collection( - self, - name=name, - read_concern=read_concern, - write_concern=write_concern, - read_preference=read_preference or self.read_preference, - codec_options=codec_options or self._codec_options, - _db_store=self._store, - ) - return collection - - def drop_collection(self, name_or_collection, session=None): - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - if isinstance(name_or_collection, Collection): - name_or_collection._store.drop() - else: - self._store[name_or_collection].drop() - - def _ensure_valid_collection_name(self, name): - # These are the same checks that are done in pymongo. - if not isinstance(name, str): - raise TypeError("name must be an instance of str") - if not name or ".." in name: - raise InvalidName("collection names cannot be empty") - if name[0] == "." or name[-1] == ".": - raise InvalidName("collection names must not start or end with '.'") - if "$" in name: - raise InvalidName("collection names must not contain '$'") - if "\x00" in name: - raise InvalidName("collection names must not contain the null character") - - def create_collection(self, name, **kwargs): - self._ensure_valid_collection_name(name) - if name in self.list_collection_names(): - raise CollectionInvalid("collection %s already exists" % name) - - if kwargs: - raise NotImplementedError("Special options not supported") - - self._store.create_collection(name) - return self[name] - - def rename_collection(self, name, new_name, dropTarget=False): - """Changes the name of an existing collection.""" - self._ensure_valid_collection_name(new_name) - - # Reference for server implementation: - # https://docs.mongodb.com/manual/reference/command/renameCollection/ - if not self._store[name].is_created: - raise OperationFailure( - 'The collection "{0}" does not exist.'.format(name), 10026 - ) - if new_name in self._store: - if dropTarget: - self.drop_collection(new_name) - else: - raise OperationFailure( - 'The target collection "{0}" already exists'.format(new_name), 10027 - ) - self._store.rename(name, new_name) - return {"ok": 1} - - def dereference(self, dbref, session=None): - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - - if not hasattr(dbref, "collection") or not hasattr(dbref, "id"): - raise TypeError("cannot dereference a %s" % type(dbref)) - if dbref.database is not None and dbref.database != self.name: - raise ValueError( - "trying to dereference a DBRef that points to " - "another database (%r not %r)" % (dbref.database, self.name) - ) - return self[dbref.collection].find_one({"_id": dbref.id}) - - def command(self, command, **unused_kwargs): - if isinstance(command, str): - command = {command: 1} - if "ping" in command: - return {"ok": 1.0} - # TODO(pascal): Differentiate NotImplementedError for valid commands - # and OperationFailure if the command is not valid. - raise NotImplementedError( - "command is a valid Database method but is not implemented in Mongomock yet" - ) - - def with_options( - self, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - mongomock_codec_options.is_supported(codec_options) - - if write_concern: - raise NotImplementedError( - "write_concern is a valid parameter for with_options but is not implemented yet in " - "mongomock" - ) - - if read_preference is None or read_preference == self._read_preference: - return self - - return Database( - self._client, - self.name, - self._store, - read_preference=read_preference or self._read_preference, - codec_options=codec_options or self._codec_options, - read_concern=read_concern or self._read_concern, - ) diff --git a/packages/syft/tests/mongomock/filtering.py b/packages/syft/tests/mongomock/filtering.py deleted file mode 100644 index 345b94c7b88..00000000000 --- a/packages/syft/tests/mongomock/filtering.py +++ /dev/null @@ -1,601 +0,0 @@ -# stdlib -from datetime import datetime -import itertools -import numbers -import operator -import re -import uuid - -# relative -from . import OperationFailure -from .helpers import ObjectId -from .helpers import RE_TYPE - -try: - # stdlib - from types import NoneType -except ImportError: - NoneType = type(None) - -try: - # third party - from bson import DBRef - from bson import Regex - - _RE_TYPES = (RE_TYPE, Regex) -except ImportError: - DBRef = None - _RE_TYPES = (RE_TYPE,) - -try: - # third party - from bson.decimal128 import Decimal128 -except ImportError: - Decimal128 = None - -_TOP_LEVEL_OPERATORS = {"$expr", "$text", "$where", "$jsonSchema"} - - -_NOT_IMPLEMENTED_OPERATORS = { - "$bitsAllClear", - "$bitsAllSet", - "$bitsAnyClear", - "$bitsAnySet", - "$geoIntersects", - "$geoWithin", - "$maxDistance", - "$minDistance", - "$near", - "$nearSphere", -} - - -def filter_applies(search_filter, document): - """Applies given filter - - This function implements MongoDB's matching strategy over documents in the find() method - and other related scenarios (like $elemMatch) - """ - return _filterer_inst.apply(search_filter, document) - - -class _Filterer(object): - """An object to help applying a filter, using the MongoDB query language.""" - - # This is populated using register_parse_expression further down. - parse_expression = [] - - def __init__(self): - self._operator_map = dict( - { - "$eq": _list_expand(operator_eq), - "$ne": _list_expand( - lambda dv, sv: not operator_eq(dv, sv), negative=True - ), - "$all": self._all_op, - "$in": _in_op, - "$nin": lambda dv, sv: not _in_op(dv, sv), - "$exists": lambda dv, sv: bool(sv) == (dv is not None), - "$regex": _not_None_and(_regex), - "$elemMatch": self._elem_match_op, - "$size": _size_op, - "$type": _type_op, - }, - **{ - key: _not_None_and(_list_expand(_compare_objects(op))) - for key, op in SORTING_OPERATOR_MAP.items() - }, - ) - - def apply(self, search_filter, document): - if not isinstance(search_filter, dict): - raise OperationFailure( - "the match filter must be an expression in an object" - ) - - for key, search in search_filter.items(): - # Top level operators. - if key == "$comment": - continue - if key in LOGICAL_OPERATOR_MAP: - if not search: - raise OperationFailure( - "BadValue $and/$or/$nor must be a nonempty array" - ) - if not LOGICAL_OPERATOR_MAP[key](document, search, self.apply): - return False - continue - if key == "$expr": - parse_expression = self.parse_expression[0] - if not parse_expression(search, document, ignore_missing_keys=True): - return False - continue - if key in _TOP_LEVEL_OPERATORS: - raise NotImplementedError( - "The {} operator is not implemented in mongomock yet".format(key) - ) - if key.startswith("$"): - raise OperationFailure("unknown top level operator: " + key) - - is_match = False - - is_checking_negative_match = isinstance(search, dict) and { - "$ne", - "$nin", - } & set(search.keys()) - is_checking_positive_match = not isinstance(search, dict) or ( - set(search.keys()) - {"$ne", "$nin"} - ) - has_candidates = False - - if search == {"$exists": False} and not iter_key_candidates(key, document): - continue - - if isinstance(search, dict) and "$all" in search: - if not self._all_op(iter_key_candidates(key, document), search["$all"]): - return False - # if there are no query operators then continue - if len(search) == 1: - continue - - for doc_val in iter_key_candidates(key, document): - has_candidates |= doc_val is not None - is_ops_filter = ( - search - and isinstance(search, dict) - and all(key.startswith("$") for key in search.keys()) - ) - if is_ops_filter: - if "$options" in search and "$regex" in search: - search = _combine_regex_options(search) - unknown_operators = set(search) - set(self._operator_map) - {"$not"} - if unknown_operators: - not_implemented_operators = ( - unknown_operators & _NOT_IMPLEMENTED_OPERATORS - ) - if not_implemented_operators: - raise NotImplementedError( - "'%s' is a valid operation but it is not supported by Mongomock " - "yet." % list(not_implemented_operators)[0] - ) - raise OperationFailure( - "unknown operator: " + list(unknown_operators)[0] - ) - is_match = ( - all( - operator_string in self._operator_map - and self._operator_map[operator_string](doc_val, search_val) - or operator_string == "$not" - and self._not_op(document, key, search_val) - for operator_string, search_val in search.items() - ) - and search - ) - elif isinstance(search, _RE_TYPES) and isinstance(doc_val, (str, list)): - is_match = _regex(doc_val, search) - elif key in LOGICAL_OPERATOR_MAP: - if not search: - raise OperationFailure( - "BadValue $and/$or/$nor must be a nonempty array" - ) - is_match = LOGICAL_OPERATOR_MAP[key](document, search, self.apply) - elif isinstance(doc_val, (list, tuple)): - is_match = search in doc_val or search == doc_val - if isinstance(search, ObjectId): - is_match |= str(search) in doc_val - else: - is_match = (doc_val == search) or ( - search is None and doc_val is None - ) - - # When checking negative match, all the elements should match. - if is_checking_negative_match and not is_match: - return False - - # If not checking negative matches, the first match is enouh for this criteria. - if is_match and not is_checking_negative_match: - break - - if not is_match and (has_candidates or is_checking_positive_match): - return False - - return True - - def _not_op(self, d, k, s): - if isinstance(s, dict): - for key in s.keys(): - if key not in self._operator_map and key not in LOGICAL_OPERATOR_MAP: - raise OperationFailure("unknown operator: %s" % key) - elif isinstance(s, _RE_TYPES): - pass - else: - raise OperationFailure("$not needs a regex or a document") - return not self.apply({k: s}, d) - - def _elem_match_op(self, doc_val, query): - if not isinstance(doc_val, list): - return False - if not isinstance(query, dict): - raise OperationFailure("$elemMatch needs an Object") - for item in doc_val: - try: - if self.apply(query, item): - return True - except OperationFailure: - if self.apply({"field": query}, {"field": item}): - return True - return False - - def _all_op(self, doc_val, search_val): - if isinstance(doc_val, list) and doc_val and isinstance(doc_val[0], list): - doc_val = list(itertools.chain.from_iterable(doc_val)) - dv = _force_list(doc_val) - matches = [] - for x in search_val: - if isinstance(x, dict) and "$elemMatch" in x: - matches.append(self._elem_match_op(doc_val, x["$elemMatch"])) - else: - matches.append(x in dv) - return all(matches) - - -def iter_key_candidates(key, doc): - """Get possible subdocuments or lists that are referred to by the key in question - - Returns the appropriate nested value if the key includes dot notation. - """ - if not key: - return [doc] - - if doc is None: - return () - - if isinstance(doc, list): - return _iter_key_candidates_sublist(key, doc) - - if not isinstance(doc, dict): - return () - - key_parts = key.split(".") - if len(key_parts) == 1: - return [doc.get(key, None)] - - sub_key = ".".join(key_parts[1:]) - sub_doc = doc.get(key_parts[0], {}) - return iter_key_candidates(sub_key, sub_doc) - - -def _iter_key_candidates_sublist(key, doc): - """Iterates of candidates - - :param doc: a list to be searched for candidates for our key - :param key: the string key to be matched - """ - key_parts = key.split(".") - sub_key = key_parts.pop(0) - key_remainder = ".".join(key_parts) - try: - sub_key_int = int(sub_key) - except ValueError: - sub_key_int = None - - if sub_key_int is None: - # subkey is not an integer... - ret = [] - for sub_doc in doc: - if isinstance(sub_doc, dict): - if sub_key in sub_doc: - ret.extend(iter_key_candidates(key_remainder, sub_doc[sub_key])) - else: - ret.append(None) - return ret - - # subkey is an index - if sub_key_int >= len(doc): - return () # dead end - sub_doc = doc[sub_key_int] - if key_parts: - return iter_key_candidates(".".join(key_parts), sub_doc) - return [sub_doc] - - -def _force_list(v): - return v if isinstance(v, (list, tuple)) else [v] - - -def _in_op(doc_val, search_val): - if not isinstance(search_val, (list, tuple)): - raise OperationFailure("$in needs an array") - if doc_val is None and None in search_val: - return True - doc_val = _force_list(doc_val) - is_regex_list = [isinstance(x, _RE_TYPES) for x in search_val] - if not any(is_regex_list): - return any(x in search_val for x in doc_val) - for x, is_regex in zip(search_val, is_regex_list): - if (is_regex and _regex(doc_val, x)) or (x in doc_val): - return True - return False - - -def _not_None_and(f): - """wrap an operator to return False if the first arg is None""" - return lambda v, l: v is not None and f(v, l) - - -def _compare_objects(op): - """Wrap an operator to also compare objects following BSON comparison. - - See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#objects - """ - - def _wrapped(a, b): - # Do not compare uncomparable types, see Type Bracketing: - # https://docs.mongodb.com/manual/reference/method/db.collection.find/#type-bracketing - return bson_compare(op, a, b, can_compare_types=False) - - return _wrapped - - -def bson_compare(op, a, b, can_compare_types=True): - """Compare two elements using BSON comparison. - - Args: - op: the basic operation to compare (e.g. operator.lt, operator.ge). - a: the first operand - b: the second operand - can_compare_types: if True, according to BSON's definition order - between types is used, otherwise always return False when types are - different. - """ - a_type = _get_compare_type(a) - b_type = _get_compare_type(b) - if a_type != b_type: - return can_compare_types and op(a_type, b_type) - - # Compare DBRefs as dicts - if type(a).__name__ == "DBRef" and hasattr(a, "as_doc"): - a = a.as_doc() - if type(b).__name__ == "DBRef" and hasattr(b, "as_doc"): - b = b.as_doc() - - if isinstance(a, dict): - # MongoDb server compares the type before comparing the keys - # https://github.com/mongodb/mongo/blob/f10f214/src/mongo/bson/bsonelement.cpp#L516 - # even though the documentation does not say anything about that. - a = [(_get_compare_type(v), k, v) for k, v in a.items()] - b = [(_get_compare_type(v), k, v) for k, v in b.items()] - - if isinstance(a, (tuple, list)): - for item_a, item_b in zip(a, b): - if item_a != item_b: - return bson_compare(op, item_a, item_b) - return bson_compare(op, len(a), len(b)) - - if isinstance(a, NoneType): - return op(0, 0) - - # bson handles bytes as binary in python3+: - # https://api.mongodb.com/python/current/api/bson/index.html - if isinstance(a, bytes): - # Performs the same operation as described by: - # https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#bindata - if len(a) != len(b): - return op(len(a), len(b)) - # bytes is always treated as subtype 0 by the bson library - return op(a, b) - - -def _get_compare_type(val): - """Get a number representing the base type of the value used for comparison. - - See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/ - also https://github.com/mongodb/mongo/blob/46b28bb/src/mongo/bson/bsontypes.h#L175 - for canonical values. - """ - if isinstance(val, NoneType): - return 5 - if isinstance(val, bool): - return 40 - if isinstance(val, numbers.Number): - return 10 - if isinstance(val, str): - return 15 - if isinstance(val, dict): - return 20 - if isinstance(val, (tuple, list)): - return 25 - if isinstance(val, uuid.UUID): - return 30 - if isinstance(val, bytes): - return 30 - if isinstance(val, ObjectId): - return 35 - if isinstance(val, datetime): - return 45 - if isinstance(val, _RE_TYPES): - return 50 - if DBRef and isinstance(val, DBRef): - # According to the C++ code, this should be 55 but apparently sending a DBRef through - # pymongo is stored as a dict. - return 20 - return 0 - - -def _regex(doc_val, regex): - if not (isinstance(doc_val, (str, list)) or isinstance(doc_val, RE_TYPE)): - return False - if isinstance(regex, str): - regex = re.compile(regex) - if not isinstance(regex, RE_TYPE): - # bson.Regex - regex = regex.try_compile() - return any( - regex.search(item) for item in _force_list(doc_val) if isinstance(item, str) - ) - - -def _size_op(doc_val, search_val): - if isinstance(doc_val, (list, tuple, dict)): - return search_val == len(doc_val) - return search_val == 1 if doc_val and doc_val is not None else 0 - - -def _list_expand(f, negative=False): - def func(doc_val, search_val): - if isinstance(doc_val, (list, tuple)) and not isinstance( - search_val, (list, tuple) - ): - if negative: - return all(f(val, search_val) for val in doc_val) - return any(f(val, search_val) for val in doc_val) - return f(doc_val, search_val) - - return func - - -def _type_op(doc_val, search_val, in_array=False): - if search_val not in TYPE_MAP: - raise OperationFailure("%r is not a valid $type" % search_val) - elif TYPE_MAP[search_val] is None: - raise NotImplementedError( - "%s is a valid $type but not implemented" % search_val - ) - if TYPE_MAP[search_val](doc_val): - return True - if isinstance(doc_val, (list, tuple)) and not in_array: - return any(_type_op(val, search_val, in_array=True) for val in doc_val) - return False - - -def _combine_regex_options(search): - if not isinstance(search["$options"], str): - raise OperationFailure("$options has to be a string") - - options = None - for option in search["$options"]: - if option not in "imxs": - continue - re_option = getattr(re, option.upper()) - if options is None: - options = re_option - else: - options |= re_option - - search_copy = dict(search) - del search_copy["$options"] - - if options is None: - return search_copy - - if isinstance(search["$regex"], _RE_TYPES): - if isinstance(search["$regex"], RE_TYPE): - search_copy["$regex"] = re.compile( - search["$regex"].pattern, search["$regex"].flags | options - ) - else: - # bson.Regex - regex = search["$regex"] - search_copy["$regex"] = regex.__class__( - regex.pattern, regex.flags | options - ) - else: - search_copy["$regex"] = re.compile(search["$regex"], options) - return search_copy - - -def operator_eq(doc_val, search_val): - if doc_val is None and search_val is None: - return True - return operator.eq(doc_val, search_val) - - -SORTING_OPERATOR_MAP = { - "$gt": operator.gt, - "$gte": operator.ge, - "$lt": operator.lt, - "$lte": operator.le, -} - - -LOGICAL_OPERATOR_MAP = { - "$or": lambda d, subq, filter_func: any(filter_func(q, d) for q in subq), - "$and": lambda d, subq, filter_func: all(filter_func(q, d) for q in subq), - "$nor": lambda d, subq, filter_func: all(not filter_func(q, d) for q in subq), - "$not": lambda d, subq, filter_func: (not filter_func(q, d) for q in subq), -} - - -TYPE_MAP = { - "double": lambda v: isinstance(v, float), - "string": lambda v: isinstance(v, str), - "object": lambda v: isinstance(v, dict), - "array": lambda v: isinstance(v, list), - "binData": lambda v: isinstance(v, bytes), - "undefined": None, - "objectId": lambda v: isinstance(v, ObjectId), - "bool": lambda v: isinstance(v, bool), - "date": lambda v: isinstance(v, datetime), - "null": None, - "regex": None, - "dbPointer": None, - "javascript": None, - "symbol": None, - "javascriptWithScope": None, - "int": lambda v: ( - isinstance(v, int) and not isinstance(v, bool) and v.bit_length() <= 32 - ), - "timestamp": None, - "long": lambda v: ( - isinstance(v, int) and not isinstance(v, bool) and v.bit_length() > 32 - ), - "decimal": (lambda v: isinstance(v, Decimal128)) if Decimal128 else None, - "number": lambda v: ( - # pylint: disable-next=isinstance-second-argument-not-valid-type - isinstance(v, (int, float) + ((Decimal128,) if Decimal128 else ())) - and not isinstance(v, bool) - ), - "minKey": None, - "maxKey": None, -} - - -def resolve_key(key, doc): - return next(iter(iter_key_candidates(key, doc)), None) - - -def resolve_sort_key(key, doc): - value = resolve_key(key, doc) - # see http://docs.mongodb.org/manual/reference/method/cursor.sort/#ascending-descending-sort - if value is None: - return 1, BsonComparable(None) - - # List or tuples are sorted solely by their first value. - if isinstance(value, (tuple, list)): - if not value: - return 0, BsonComparable(None) - return 1, BsonComparable(value[0]) - - return 1, BsonComparable(value) - - -class BsonComparable(object): - """Wraps a value in an BSON like object that can be compared one to another.""" - - def __init__(self, obj): - self.obj = obj - - def __lt__(self, other): - return bson_compare(operator.lt, self.obj, other.obj) - - -_filterer_inst = _Filterer() - - -# Developer note: to avoid a cross-modules dependency (filtering requires aggregation, that requires -# filtering), the aggregation module needs to register its parse_expression function here. -def register_parse_expression(parse_expression): - """Register the parse_expression function from the aggregate module.""" - - del _Filterer.parse_expression[:] - _Filterer.parse_expression.append(parse_expression) diff --git a/packages/syft/tests/mongomock/gridfs.py b/packages/syft/tests/mongomock/gridfs.py deleted file mode 100644 index 13a59999855..00000000000 --- a/packages/syft/tests/mongomock/gridfs.py +++ /dev/null @@ -1,68 +0,0 @@ -# stdlib -from unittest import mock - -# relative -from . import Collection as MongoMockCollection -from . import Database as MongoMockDatabase -from ..collection import Cursor as MongoMockCursor - -try: - # third party - from gridfs.grid_file import GridOut as PyMongoGridOut - from gridfs.grid_file import GridOutCursor as PyMongoGridOutCursor - from pymongo.collection import Collection as PyMongoCollection - from pymongo.database import Database as PyMongoDatabase - - _HAVE_PYMONGO = True -except ImportError: - _HAVE_PYMONGO = False - - -# This is a copy of GridOutCursor but with a different base. Note that we -# need both classes as one might want to access both mongomock and real -# MongoDb. -class _MongoMockGridOutCursor(MongoMockCursor): - def __init__(self, collection, *args, **kwargs): - self.__root_collection = collection - super(_MongoMockGridOutCursor, self).__init__(collection.files, *args, **kwargs) - - def next(self): - next_file = super(_MongoMockGridOutCursor, self).next() - return PyMongoGridOut( - self.__root_collection, file_document=next_file, session=self.session - ) - - __next__ = next - - def add_option(self, *args, **kwargs): - raise NotImplementedError() - - def remove_option(self, *args, **kwargs): - raise NotImplementedError() - - def _clone_base(self, session): - return _MongoMockGridOutCursor(self.__root_collection, session=session) - - -def _create_grid_out_cursor(collection, *args, **kwargs): - if isinstance(collection, MongoMockCollection): - return _MongoMockGridOutCursor(collection, *args, **kwargs) - return PyMongoGridOutCursor(collection, *args, **kwargs) - - -def enable_gridfs_integration(): - """This function enables the use of mongomock Database's and Collection's inside gridfs - - Gridfs library use `isinstance` to make sure the passed elements - are valid `pymongo.Database/Collection` so we monkey patch those types in the gridfs modules - (luckily in the modules they are used, they are only used with isinstance). - """ - - if not _HAVE_PYMONGO: - raise NotImplementedError("gridfs mocking requires pymongo to work") - - mock.patch("gridfs.Database", (PyMongoDatabase, MongoMockDatabase)).start() - mock.patch( - "gridfs.grid_file.Collection", (PyMongoCollection, MongoMockCollection) - ).start() - mock.patch("gridfs.GridOutCursor", _create_grid_out_cursor).start() diff --git a/packages/syft/tests/mongomock/helpers.py b/packages/syft/tests/mongomock/helpers.py deleted file mode 100644 index 13f6892cae5..00000000000 --- a/packages/syft/tests/mongomock/helpers.py +++ /dev/null @@ -1,474 +0,0 @@ -# stdlib -from collections import OrderedDict -from collections import abc -from datetime import datetime -from datetime import timedelta -from datetime import tzinfo -import re -import time -from urllib.parse import unquote_plus -import warnings - -# third party -from packaging import version - -# relative -from . import InvalidURI - -# Get ObjectId from bson if available or import a crafted one. This is not used -# in this module but is made available for callers of this module. -try: - # third party - from bson import ObjectId # pylint: disable=unused-import - from bson import Timestamp - from pymongo import version as pymongo_version - - PYMONGO_VERSION = version.parse(pymongo_version) - HAVE_PYMONGO = True -except ImportError: - from .object_id import ObjectId # noqa - - Timestamp = None - # Default Pymongo version if not present. - PYMONGO_VERSION = version.parse("4.0") - HAVE_PYMONGO = False - -# Cache the RegExp pattern type. -RE_TYPE = type(re.compile("")) -_HOST_MATCH = re.compile(r"^([^@]+@)?([^:]+|\[[^\]]+\])(:([^:]+))?$") -_SIMPLE_HOST_MATCH = re.compile(r"^([^:]+|\[[^\]]+\])(:([^:]+))?$") - -try: - # third party - from bson.tz_util import utc -except ImportError: - - class _FixedOffset(tzinfo): - def __init__(self, offset, name): - self.__offset = timedelta(minutes=offset) - self.__name = name - - def __getinitargs__(self): - return self.__offset, self.__name - - def utcoffset(self, dt): - return self.__offset - - def tzname(self, dt): - return self.__name - - def dst(self, dt): - return timedelta(0) - - utc = _FixedOffset(0, "UTC") - - -ASCENDING = 1 -DESCENDING = -1 - - -def utcnow(): - """Simple wrapper for datetime.utcnow - - This provides a centralized definition of "now" in the mongomock realm, - allowing users to transform the value of "now" to the future or the past, - based on their testing needs. For example: - - ```python - def test_x(self): - with mock.patch("mongomock.utcnow") as mm_utc: - mm_utc = datetime.utcnow() + timedelta(hours=100) - # Test some things "100 hours" in the future - ``` - """ - return datetime.utcnow() - - -def print_deprecation_warning(old_param_name, new_param_name): - warnings.warn( - "'%s' has been deprecated to be in line with pymongo implementation, a new parameter '%s' " - "should be used instead. the old parameter will be kept for backward compatibility " - "purposes." % (old_param_name, new_param_name), - DeprecationWarning, - ) - - -def create_index_list(key_or_list, direction=None): - """Helper to generate a list of (key, direction) pairs. - - It takes such a list, or a single key, or a single key and direction. - """ - if isinstance(key_or_list, str): - return [(key_or_list, direction or ASCENDING)] - if not isinstance(key_or_list, (list, tuple, abc.Iterable)): - raise TypeError( - "if no direction is specified, " "key_or_list must be an instance of list" - ) - return key_or_list - - -def gen_index_name(index_list): - """Generate an index name based on the list of keys with directions.""" - - return "_".join(["%s_%s" % item for item in index_list]) - - -class hashdict(dict): - """hashable dict implementation, suitable for use as a key into other dicts. - - >>> h1 = hashdict({'apples': 1, 'bananas':2}) - >>> h2 = hashdict({'bananas': 3, 'mangoes': 5}) - >>> h1+h2 - hashdict(apples=1, bananas=3, mangoes=5) - >>> d1 = {} - >>> d1[h1] = 'salad' - >>> d1[h1] - 'salad' - >>> d1[h2] - Traceback (most recent call last): - ... - KeyError: hashdict(bananas=3, mangoes=5) - - based on answers from - http://stackoverflow.com/questions/1151658/python-hashable-dicts - """ - - def __key(self): - return frozenset( - ( - k, - ( - hashdict(v) - if isinstance(v, dict) - else tuple(v) - if isinstance(v, list) - else v - ), - ) - for k, v in self.items() - ) - - def __repr__(self): - return "{0}({1})".format( - self.__class__.__name__, - ", ".join( - "{0}={1}".format(str(i[0]), repr(i[1])) for i in sorted(self.__key()) - ), - ) - - def __hash__(self): - return hash(self.__key()) - - def __setitem__(self, key, value): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def __delitem__(self, key): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def clear(self): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def pop(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def popitem(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def setdefault(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def update(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def __add__(self, right): - result = hashdict(self) - dict.update(result, right) - return result - - -def fields_list_to_dict(fields): - """Takes a list of field names and returns a matching dictionary. - - ['a', 'b'] becomes {'a': 1, 'b': 1} - - and - - ['a.b.c', 'd', 'a.c'] becomes {'a.b.c': 1, 'd': 1, 'a.c': 1} - """ - as_dict = {} - for field in fields: - if not isinstance(field, str): - raise TypeError( - "fields must be a list of key names, each an instance of str" - ) - as_dict[field] = 1 - return as_dict - - -def parse_uri(uri, default_port=27017, warn=False): - """A simplified version of pymongo.uri_parser.parse_uri. - - Returns a dict with: - - nodelist, a tuple of (host, port) - - database the name of the database or None if no database is provided in the URI. - - An invalid MongoDB connection URI may raise an InvalidURI exception, - however, the URI is not fully parsed and some invalid URIs may not result - in an exception. - - 'mongodb://host1/database' becomes 'host1', 27017, 'database' - - and - - 'mongodb://host1' becomes 'host1', 27017, None - """ - SCHEME = "mongodb://" - - if not uri.startswith(SCHEME): - raise InvalidURI("Invalid URI scheme: URI " "must begin with '%s'" % (SCHEME,)) - - scheme_free = uri[len(SCHEME) :] - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - dbase = None - - # Check for unix domain sockets in the uri - if ".sock" in scheme_free: - host_part, _, path_part = scheme_free.rpartition("/") - if not host_part: - host_part = path_part - path_part = "" - if "/" in host_part: - raise InvalidURI( - "Any '/' in a unix domain socket must be" " URL encoded: %s" % host_part - ) - path_part = unquote_plus(path_part) - else: - host_part, _, path_part = scheme_free.partition("/") - - if not path_part and "?" in host_part: - raise InvalidURI("A '/' is required between " "the host list and any options.") - - nodelist = [] - if "," in host_part: - hosts = host_part.split(",") - else: - hosts = [host_part] - for host in hosts: - match = _HOST_MATCH.match(host) - if not match: - raise ValueError( - "Reserved characters such as ':' must be escaped according RFC " - "2396. An IPv6 address literal must be enclosed in '[' and ']' " - "according to RFC 2732." - ) - host = match.group(2) - if host.startswith("[") and host.endswith("]"): - host = host[1:-1] - - port = match.group(4) - if port: - try: - port = int(port) - if port < 0 or port > 65535: - raise ValueError() - except ValueError as err: - raise ValueError( - "Port must be an integer between 0 and 65535:", port - ) from err - else: - port = default_port - - nodelist.append((host, port)) - - if path_part and path_part[0] != "?": - dbase, _, _ = path_part.partition("?") - if "." in dbase: - dbase, _ = dbase.split(".", 1) - - if dbase is not None: - dbase = unquote_plus(dbase) - - return {"nodelist": tuple(nodelist), "database": dbase} - - -def split_hosts(hosts, default_port=27017): - """Split the entity into a list of tuples of host and port.""" - - nodelist = [] - for entity in hosts.split(","): - port = default_port - if entity.endswith(".sock"): - port = None - - match = _SIMPLE_HOST_MATCH.match(entity) - if not match: - raise ValueError( - "Reserved characters such as ':' must be escaped according RFC " - "2396. An IPv6 address literal must be enclosed in '[' and ']' " - "according to RFC 2732." - ) - host = match.group(1) - if host.startswith("[") and host.endswith("]"): - host = host[1:-1] - - if match.group(3): - try: - port = int(match.group(3)) - if port < 0 or port > 65535: - raise ValueError() - except ValueError as err: - raise ValueError( - "Port must be an integer between 0 and 65535:", port - ) from err - - nodelist.append((host, port)) - - return nodelist - - -_LAST_TIMESTAMP_INC = [] - - -def get_current_timestamp(): - """Get the current timestamp as a bson Timestamp object.""" - if not Timestamp: - raise NotImplementedError( - "timestamp is not supported. Import pymongo to use it." - ) - now = int(time.time()) - if _LAST_TIMESTAMP_INC and _LAST_TIMESTAMP_INC[0] == now: - _LAST_TIMESTAMP_INC[1] += 1 - else: - del _LAST_TIMESTAMP_INC[:] - _LAST_TIMESTAMP_INC.extend([now, 1]) - return Timestamp(now, _LAST_TIMESTAMP_INC[1]) - - -def patch_datetime_awareness_in_document(value): - # MongoDB is supposed to stock everything as timezone naive utc date - # Hence we have to convert incoming datetimes to avoid errors while - # mixing tz aware and naive. - # On top of that, MongoDB date precision is up to millisecond, where Python - # datetime use microsecond, so we must lower the precision to mimic mongo. - for best_type in (OrderedDict, dict): - if isinstance(value, best_type): - return best_type( - (k, patch_datetime_awareness_in_document(v)) for k, v in value.items() - ) - if isinstance(value, (tuple, list)): - return [patch_datetime_awareness_in_document(item) for item in value] - if isinstance(value, datetime): - mongo_us = (value.microsecond // 1000) * 1000 - if value.tzinfo: - return (value - value.utcoffset()).replace( - tzinfo=None, microsecond=mongo_us - ) - return value.replace(microsecond=mongo_us) - if Timestamp and isinstance(value, Timestamp) and not value.time and not value.inc: - return get_current_timestamp() - return value - - -def make_datetime_timezone_aware_in_document(value): - # MongoClient support tz_aware=True parameter to return timezone-aware - # datetime objects. Given the date is stored internally without timezone - # information, all returned datetime have utc as timezone. - if isinstance(value, dict): - return { - k: make_datetime_timezone_aware_in_document(v) for k, v in value.items() - } - if isinstance(value, (tuple, list)): - return [make_datetime_timezone_aware_in_document(item) for item in value] - if isinstance(value, datetime): - return value.replace(tzinfo=utc) - return value - - -def get_value_by_dot(doc, key, can_generate_array=False): - """Get dictionary value using dotted key""" - result = doc - key_items = key.split(".") - for key_index, key_item in enumerate(key_items): - if isinstance(result, dict): - result = result[key_item] - - elif isinstance(result, (list, tuple)): - try: - int_key = int(key_item) - except ValueError as err: - if not can_generate_array: - raise KeyError(key_index) from err - remaining_key = ".".join(key_items[key_index:]) - return [get_value_by_dot(subdoc, remaining_key) for subdoc in result] - - try: - result = result[int_key] - except (ValueError, IndexError) as err: - raise KeyError(key_index) from err - - else: - raise KeyError(key_index) - - return result - - -def set_value_by_dot(doc, key, value): - """Set dictionary value using dotted key""" - try: - parent_key, child_key = key.rsplit(".", 1) - parent = get_value_by_dot(doc, parent_key) - except ValueError: - child_key = key - parent = doc - - if isinstance(parent, dict): - parent[child_key] = value - elif isinstance(parent, (list, tuple)): - try: - parent[int(child_key)] = value - except (ValueError, IndexError) as err: - raise KeyError() from err - else: - raise KeyError() - - return doc - - -def delete_value_by_dot(doc, key): - """Delete dictionary value using dotted key. - - This function assumes that the value exists. - """ - try: - parent_key, child_key = key.rsplit(".", 1) - parent = get_value_by_dot(doc, parent_key) - except ValueError: - child_key = key - parent = doc - - del parent[child_key] - - return doc - - -def mongodb_to_bool(value): - """Converts any value to bool the way MongoDB does it""" - - return value not in [False, None, 0] diff --git a/packages/syft/tests/mongomock/mongo_client.py b/packages/syft/tests/mongomock/mongo_client.py deleted file mode 100644 index 560a7ce0f11..00000000000 --- a/packages/syft/tests/mongomock/mongo_client.py +++ /dev/null @@ -1,222 +0,0 @@ -# stdlib -import itertools -import warnings - -# third party -from packaging import version - -# relative -from . import ConfigurationError -from . import codec_options as mongomock_codec_options -from . import helpers -from . import read_preferences -from .database import Database -from .store import ServerStore - -try: - # third party - from pymongo import ReadPreference - from pymongo.uri_parser import parse_uri - from pymongo.uri_parser import split_hosts - - _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY -except ImportError: - # relative - from .helpers import parse_uri - from .helpers import split_hosts - - _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY - - -def _convert_version_to_list(version_str): - pieces = [int(part) for part in version_str.split(".")] - return pieces + [0] * (4 - len(pieces)) - - -class MongoClient(object): - HOST = "localhost" - PORT = 27017 - _CONNECTION_ID = itertools.count() - - def __init__( - self, - host=None, - port=None, - document_class=dict, - tz_aware=False, - connect=True, - _store=None, - read_preference=None, - uuidRepresentation=None, - type_registry=None, - **kwargs, - ): - if host: - self.host = host[0] if isinstance(host, (list, tuple)) else host - else: - self.host = self.HOST - self.port = port or self.PORT - - self._tz_aware = tz_aware - self._codec_options = mongomock_codec_options.CodecOptions( - tz_aware=tz_aware, - uuid_representation=uuidRepresentation, - type_registry=type_registry, - ) - self._database_accesses = {} - self._store = _store or ServerStore() - self._id = next(self._CONNECTION_ID) - self._document_class = document_class - if read_preference is not None: - read_preferences.ensure_read_preference_type( - "read_preference", read_preference - ) - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - - dbase = None - - if "://" in self.host: - res = parse_uri(self.host, default_port=self.port, warn=True) - self.host, self.port = res["nodelist"][0] - dbase = res["database"] - else: - self.host, self.port = split_hosts(self.host, default_port=self.port)[0] - - self.__default_database_name = dbase - # relative - from . import SERVER_VERSION - - self._server_version = SERVER_VERSION - - def __getitem__(self, db_name): - return self.get_database(db_name) - - def __getattr__(self, attr): - return self[attr] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __repr__(self): - return "mongomock.MongoClient('{0}', {1})".format(self.host, self.port) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.address == other.address - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash(self.address) - - def close(self): - pass - - @property - def is_mongos(self): - return True - - @property - def is_primary(self): - return True - - @property - def address(self): - return self.host, self.port - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - def server_info(self): - return { - "version": self._server_version, - "sysInfo": "Mock", - "versionArray": _convert_version_to_list(self._server_version), - "bits": 64, - "debug": False, - "maxBsonObjectSize": 16777216, - "ok": 1, - } - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def database_names(self): - warnings.warn( - "database_names is deprecated. Use list_database_names instead." - ) - return self.list_database_names() - - def list_database_names(self): - return self._store.list_created_database_names() - - def drop_database(self, name_or_db): - def drop_collections_for_db(_db): - db_store = self._store[_db.name] - for col_name in db_store.list_created_collection_names(): - _db.drop_collection(col_name) - - if isinstance(name_or_db, Database): - db = next(db for db in self._database_accesses.values() if db is name_or_db) - if db: - drop_collections_for_db(db) - - elif name_or_db in self._store: - db = self.get_database(name_or_db) - drop_collections_for_db(db) - - def get_database( - self, - name=None, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - if name is None: - db = self.get_default_database( - codec_options=codec_options, - read_preference=read_preference, - write_concern=write_concern, - read_concern=read_concern, - ) - else: - db = self._database_accesses.get(name) - if db is None: - db_store = self._store[name] - db = self._database_accesses[name] = Database( - self, - name, - read_preference=read_preference or self.read_preference, - codec_options=codec_options or self._codec_options, - _store=db_store, - read_concern=read_concern, - ) - return db - - def get_default_database(self, default=None, **kwargs): - name = self.__default_database_name - name = name if name is not None else default - if name is None: - raise ConfigurationError("No default database name defined or provided.") - - return self.get_database(name=name, **kwargs) - - def alive(self): - """The original MongoConnection.alive method checks the status of the server. - - In our case as we mock the actual server, we should always return True. - """ - return True - - def start_session(self, causal_consistency=True, default_transaction_options=None): - """Start a logical session.""" - raise NotImplementedError("Mongomock does not support sessions yet") diff --git a/packages/syft/tests/mongomock/not_implemented.py b/packages/syft/tests/mongomock/not_implemented.py deleted file mode 100644 index 990b89a411e..00000000000 --- a/packages/syft/tests/mongomock/not_implemented.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Module to handle features that are not implemented yet.""" - -_IGNORED_FEATURES = { - "array_filters": False, - "collation": False, - "let": False, - "session": False, -} - - -def _ensure_ignorable_feature(feature): - if feature not in _IGNORED_FEATURES: - raise KeyError( - "%s is not an error that can be ignored: maybe it has been implemented in Mongomock. " - "Here is the list of features that can be ignored: %s" - % (feature, _IGNORED_FEATURES.keys()) - ) - - -def ignore_feature(feature): - """Ignore a feature instead of raising a NotImplementedError.""" - _ensure_ignorable_feature(feature) - _IGNORED_FEATURES[feature] = True - - -def warn_on_feature(feature): - """Rasie a NotImplementedError the next times a feature is used.""" - _ensure_ignorable_feature(feature) - _IGNORED_FEATURES[feature] = False - - -def raise_for_feature(feature, reason): - _ensure_ignorable_feature(feature) - if _IGNORED_FEATURES[feature]: - return False - raise NotImplementedError(reason) diff --git a/packages/syft/tests/mongomock/object_id.py b/packages/syft/tests/mongomock/object_id.py deleted file mode 100644 index 281e8b02663..00000000000 --- a/packages/syft/tests/mongomock/object_id.py +++ /dev/null @@ -1,26 +0,0 @@ -# stdlib -import uuid - - -class ObjectId(object): - def __init__(self, id=None): - super(ObjectId, self).__init__() - if id is None: - self._id = uuid.uuid1() - else: - self._id = uuid.UUID(id) - - def __eq__(self, other): - return isinstance(other, ObjectId) and other._id == self._id - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash(self._id) - - def __repr__(self): - return "ObjectId({0})".format(self._id) - - def __str__(self): - return str(self._id) diff --git a/packages/syft/tests/mongomock/patch.py b/packages/syft/tests/mongomock/patch.py deleted file mode 100644 index 8db36497e75..00000000000 --- a/packages/syft/tests/mongomock/patch.py +++ /dev/null @@ -1,120 +0,0 @@ -# stdlib -import time - -# relative -from .mongo_client import MongoClient - -try: - # stdlib - from unittest import mock - - _IMPORT_MOCK_ERROR = None -except ImportError: - try: - # third party - import mock - - _IMPORT_MOCK_ERROR = None - except ImportError as error: - _IMPORT_MOCK_ERROR = error - -try: - # third party - import pymongo - from pymongo.uri_parser import parse_uri - from pymongo.uri_parser import split_hosts - - _IMPORT_PYMONGO_ERROR = None -except ImportError as error: - # relative - from .helpers import parse_uri - from .helpers import split_hosts - - _IMPORT_PYMONGO_ERROR = error - - -def _parse_any_host(host, default_port=27017): - if isinstance(host, tuple): - return _parse_any_host(host[0], host[1]) - if "://" in host: - return parse_uri(host, warn=True)["nodelist"] - return split_hosts(host, default_port=default_port) - - -def patch(servers="localhost", on_new="error"): - """Patch pymongo.MongoClient. - - This will patch the class MongoClient and use mongomock to mock MongoDB - servers. It keeps a consistant state of servers across multiple clients so - you can do: - - ``` - client = pymongo.MongoClient(host='localhost', port=27017) - client.db.coll.insert_one({'name': 'Pascal'}) - - other_client = pymongo.MongoClient('mongodb://localhost:27017') - client.db.coll.find_one() - ``` - - The data is persisted as long as the patch lives. - - Args: - on_new: Behavior when accessing a new server (not in servers): - 'create': mock a new empty server, accept any client connection. - 'error': raise a ValueError immediately when trying to access. - 'timeout': behave as pymongo when a server does not exist, raise an - error after a timeout. - 'pymongo': use an actual pymongo client. - servers: a list of server that are avaiable. - """ - - if _IMPORT_MOCK_ERROR: - raise _IMPORT_MOCK_ERROR # pylint: disable=raising-bad-type - - if _IMPORT_PYMONGO_ERROR: - PyMongoClient = None - else: - PyMongoClient = pymongo.MongoClient - - persisted_clients = {} - parsed_servers = set() - for server in servers if isinstance(servers, (list, tuple)) else [servers]: - parsed_servers.update(_parse_any_host(server)) - - def _create_persistent_client(*args, **kwargs): - if _IMPORT_PYMONGO_ERROR: - raise _IMPORT_PYMONGO_ERROR # pylint: disable=raising-bad-type - - client = MongoClient(*args, **kwargs) - - try: - persisted_client = persisted_clients[client.address] - client._store = persisted_client._store - return client - except KeyError: - pass - - if client.address in parsed_servers or on_new == "create": - persisted_clients[client.address] = client - return client - - if on_new == "timeout": - # TODO(pcorpet): Only wait when trying to access the server's data. - time.sleep(kwargs.get("serverSelectionTimeoutMS", 30000)) - raise pymongo.errors.ServerSelectionTimeoutError( - "%s:%d: [Errno 111] Connection refused" % client.address - ) - - if on_new == "pymongo": - return PyMongoClient(*args, **kwargs) - - raise ValueError( - "MongoDB server %s:%d does not exist.\n" % client.address - + "%s" % parsed_servers - ) - - class _PersistentClient: - def __new__(cls, *args, **kwargs): - return _create_persistent_client(*args, **kwargs) - - return mock.patch("pymongo.MongoClient", _PersistentClient) diff --git a/packages/syft/tests/mongomock/read_concern.py b/packages/syft/tests/mongomock/read_concern.py deleted file mode 100644 index 229e0f78bb4..00000000000 --- a/packages/syft/tests/mongomock/read_concern.py +++ /dev/null @@ -1,21 +0,0 @@ -class ReadConcern(object): - def __init__(self, level=None): - self._document = {} - - if level is not None: - self._document["level"] = level - - @property - def level(self): - return self._document.get("level") - - @property - def ok_for_legacy(self): - return True - - @property - def document(self): - return self._document.copy() - - def __eq__(self, other): - return other.document == self.document diff --git a/packages/syft/tests/mongomock/read_preferences.py b/packages/syft/tests/mongomock/read_preferences.py deleted file mode 100644 index a9349e6c576..00000000000 --- a/packages/syft/tests/mongomock/read_preferences.py +++ /dev/null @@ -1,42 +0,0 @@ -class _Primary(object): - @property - def mongos_mode(self): - return "primary" - - @property - def mode(self): - return 0 - - @property - def name(self): - return "Primary" - - @property - def document(self): - return {"mode": "primary"} - - @property - def tag_sets(self): - return [{}] - - @property - def max_staleness(self): - return -1 - - @property - def min_wire_version(self): - return 0 - - -def ensure_read_preference_type(key, value): - """Raise a TypeError if the value is not a type compatible for ReadPreference.""" - for attr in ("document", "mode", "mongos_mode", "max_staleness"): - if not hasattr(value, attr): - raise TypeError( - "{} must be an instance of {}".format( - key, "pymongo.read_preference.ReadPreference" - ) - ) - - -PRIMARY = _Primary() diff --git a/packages/syft/tests/mongomock/results.py b/packages/syft/tests/mongomock/results.py deleted file mode 100644 index 07633a6e82e..00000000000 --- a/packages/syft/tests/mongomock/results.py +++ /dev/null @@ -1,117 +0,0 @@ -try: - # third party - from pymongo.results import BulkWriteResult - from pymongo.results import DeleteResult - from pymongo.results import InsertManyResult - from pymongo.results import InsertOneResult - from pymongo.results import UpdateResult -except ImportError: - - class _WriteResult(object): - def __init__(self, acknowledged=True): - self.__acknowledged = acknowledged - - @property - def acknowledged(self): - return self.__acknowledged - - class InsertOneResult(_WriteResult): - __slots__ = ("__inserted_id", "__acknowledged") - - def __init__(self, inserted_id, acknowledged=True): - self.__inserted_id = inserted_id - super(InsertOneResult, self).__init__(acknowledged) - - @property - def inserted_id(self): - return self.__inserted_id - - class InsertManyResult(_WriteResult): - __slots__ = ("__inserted_ids", "__acknowledged") - - def __init__(self, inserted_ids, acknowledged=True): - self.__inserted_ids = inserted_ids - super(InsertManyResult, self).__init__(acknowledged) - - @property - def inserted_ids(self): - return self.__inserted_ids - - class UpdateResult(_WriteResult): - __slots__ = ("__raw_result", "__acknowledged") - - def __init__(self, raw_result, acknowledged=True): - self.__raw_result = raw_result - super(UpdateResult, self).__init__(acknowledged) - - @property - def raw_result(self): - return self.__raw_result - - @property - def matched_count(self): - if self.upserted_id is not None: - return 0 - return self.__raw_result.get("n", 0) - - @property - def modified_count(self): - return self.__raw_result.get("nModified") - - @property - def upserted_id(self): - return self.__raw_result.get("upserted") - - class DeleteResult(_WriteResult): - __slots__ = ("__raw_result", "__acknowledged") - - def __init__(self, raw_result, acknowledged=True): - self.__raw_result = raw_result - super(DeleteResult, self).__init__(acknowledged) - - @property - def raw_result(self): - return self.__raw_result - - @property - def deleted_count(self): - return self.__raw_result.get("n", 0) - - class BulkWriteResult(_WriteResult): - __slots__ = ("__bulk_api_result", "__acknowledged") - - def __init__(self, bulk_api_result, acknowledged): - self.__bulk_api_result = bulk_api_result - super(BulkWriteResult, self).__init__(acknowledged) - - @property - def bulk_api_result(self): - return self.__bulk_api_result - - @property - def inserted_count(self): - return self.__bulk_api_result.get("nInserted") - - @property - def matched_count(self): - return self.__bulk_api_result.get("nMatched") - - @property - def modified_count(self): - return self.__bulk_api_result.get("nModified") - - @property - def deleted_count(self): - return self.__bulk_api_result.get("nRemoved") - - @property - def upserted_count(self): - return self.__bulk_api_result.get("nUpserted") - - @property - def upserted_ids(self): - if self.__bulk_api_result: - return dict( - (upsert["index"], upsert["_id"]) - for upsert in self.bulk_api_result["upserted"] - ) diff --git a/packages/syft/tests/mongomock/store.py b/packages/syft/tests/mongomock/store.py deleted file mode 100644 index 9cef7206329..00000000000 --- a/packages/syft/tests/mongomock/store.py +++ /dev/null @@ -1,191 +0,0 @@ -# stdlib -import collections -import datetime -import functools - -# relative -from .helpers import utcnow -from .thread import RWLock - - -class ServerStore(object): - """Object holding the data for a whole server (many databases).""" - - def __init__(self): - self._databases = {} - - def __getitem__(self, db_name): - try: - return self._databases[db_name] - except KeyError: - db = self._databases[db_name] = DatabaseStore() - return db - - def __contains__(self, db_name): - return self[db_name].is_created - - def list_created_database_names(self): - return [name for name, db in self._databases.items() if db.is_created] - - -class DatabaseStore(object): - """Object holding the data for a database (many collections).""" - - def __init__(self): - self._collections = {} - - def __getitem__(self, col_name): - try: - return self._collections[col_name] - except KeyError: - col = self._collections[col_name] = CollectionStore(col_name) - return col - - def __contains__(self, col_name): - return self[col_name].is_created - - def list_created_collection_names(self): - return [name for name, col in self._collections.items() if col.is_created] - - def create_collection(self, name): - col = self[name] - col.create() - return col - - def rename(self, name, new_name): - col = self._collections.pop(name, CollectionStore(new_name)) - col.name = new_name - self._collections[new_name] = col - - @property - def is_created(self): - return any(col.is_created for col in self._collections.values()) - - -class CollectionStore(object): - """Object holding the data for a collection.""" - - def __init__(self, name): - self._documents = collections.OrderedDict() - self.indexes = {} - self._is_force_created = False - self.name = name - self._ttl_indexes = {} - - # 694 - Lock for safely iterating and mutating OrderedDicts - self._rwlock = RWLock() - - def create(self): - self._is_force_created = True - - @property - def is_created(self): - return self._documents or self.indexes or self._is_force_created - - def drop(self): - self._documents = collections.OrderedDict() - self.indexes = {} - self._ttl_indexes = {} - self._is_force_created = False - - def create_index(self, index_name, index_dict): - self.indexes[index_name] = index_dict - if index_dict.get("expireAfterSeconds") is not None: - self._ttl_indexes[index_name] = index_dict - - def drop_index(self, index_name): - self._remove_expired_documents() - - # The main index object should raise a KeyError, but the - # TTL indexes have no meaning to the outside. - del self.indexes[index_name] - self._ttl_indexes.pop(index_name, None) - - @property - def is_empty(self): - self._remove_expired_documents() - return not self._documents - - def __contains__(self, key): - self._remove_expired_documents() - with self._rwlock.reader(): - return key in self._documents - - def __getitem__(self, key): - self._remove_expired_documents() - with self._rwlock.reader(): - return self._documents[key] - - def __setitem__(self, key, val): - with self._rwlock.writer(): - self._documents[key] = val - - def __delitem__(self, key): - with self._rwlock.writer(): - del self._documents[key] - - def __len__(self): - self._remove_expired_documents() - with self._rwlock.reader(): - return len(self._documents) - - @property - def documents(self): - self._remove_expired_documents() - with self._rwlock.reader(): - for doc in self._documents.values(): - yield doc - - def _remove_expired_documents(self): - for index in self._ttl_indexes.values(): - self._expire_documents(index) - - def _expire_documents(self, index): - # TODO(juannyg): use a caching mechanism to avoid re-expiring the documents if - # we just did and no document was added / updated - - # Ignore non-integer values - try: - expiry = int(index["expireAfterSeconds"]) - except ValueError: - return - - # Ignore commpound keys - if len(index["key"]) > 1: - return - - # "key" structure = list of (field name, direction) tuples - ttl_field_name = next(iter(index["key"]))[0] - ttl_now = utcnow() - - with self._rwlock.reader(): - expired_ids = [ - doc["_id"] - for doc in self._documents.values() - if self._value_meets_expiry(doc.get(ttl_field_name), expiry, ttl_now) - ] - - for exp_id in expired_ids: - del self[exp_id] - - def _value_meets_expiry(self, val, expiry, ttl_now): - val_to_compare = _get_min_datetime_from_value(val) - try: - return (ttl_now - val_to_compare).total_seconds() >= expiry - except TypeError: - return False - - -def _get_min_datetime_from_value(val): - if not val: - return datetime.datetime.max - if isinstance(val, list): - return functools.reduce(_min_dt, [datetime.datetime.max] + val) - return val - - -def _min_dt(dt1, dt2): - try: - return dt1 if dt1 < dt2 else dt2 - except TypeError: - return dt1 diff --git a/packages/syft/tests/mongomock/thread.py b/packages/syft/tests/mongomock/thread.py deleted file mode 100644 index ff673e44309..00000000000 --- a/packages/syft/tests/mongomock/thread.py +++ /dev/null @@ -1,94 +0,0 @@ -# stdlib -from contextlib import contextmanager -import threading - - -class RWLock: - """Lock enabling multiple readers but only 1 exclusive writer - - Source: https://cutt.ly/Ij70qaq - """ - - def __init__(self): - self._read_switch = _LightSwitch() - self._write_switch = _LightSwitch() - self._no_readers = threading.Lock() - self._no_writers = threading.Lock() - self._readers_queue = threading.RLock() - - @contextmanager - def reader(self): - self._reader_acquire() - try: - yield - except Exception: # pylint: disable=W0706 - raise - finally: - self._reader_release() - - @contextmanager - def writer(self): - self._writer_acquire() - try: - yield - except Exception: # pylint: disable=W0706 - raise - finally: - self._writer_release() - - def _reader_acquire(self): - """Readers should block whenever a writer has acquired""" - self._readers_queue.acquire() - self._no_readers.acquire() - self._read_switch.acquire(self._no_writers) - self._no_readers.release() - self._readers_queue.release() - - def _reader_release(self): - self._read_switch.release(self._no_writers) - - def _writer_acquire(self): - """Acquire the writer lock. - - Only the first writer will lock the readtry and then - all subsequent writers can simply use the resource as - it gets freed by the previous writer. The very last writer must - release the readtry semaphore, thus opening the gate for readers - to try reading. - - No reader can engage in the entry section if the readtry semaphore - has been set by a writer previously - """ - self._write_switch.acquire(self._no_readers) - self._no_writers.acquire() - - def _writer_release(self): - self._no_writers.release() - self._write_switch.release(self._no_readers) - - -class _LightSwitch: - """An auxiliary "light switch"-like object - - The first thread turns on the "switch", the last one turns it off. - - Source: https://cutt.ly/Ij70qaq - """ - - def __init__(self): - self._counter = 0 - self._mutex = threading.RLock() - - def acquire(self, lock): - self._mutex.acquire() - self._counter += 1 - if self._counter == 1: - lock.acquire() - self._mutex.release() - - def release(self, lock): - self._mutex.acquire() - self._counter -= 1 - if self._counter == 0: - lock.release() - self._mutex.release() diff --git a/packages/syft/tests/mongomock/write_concern.py b/packages/syft/tests/mongomock/write_concern.py deleted file mode 100644 index 93760445647..00000000000 --- a/packages/syft/tests/mongomock/write_concern.py +++ /dev/null @@ -1,45 +0,0 @@ -def _with_default_values(document): - if "w" in document: - return document - return dict(document, w=1) - - -class WriteConcern(object): - def __init__(self, w=None, wtimeout=None, j=None, fsync=None): - self._document = {} - if w is not None: - self._document["w"] = w - if wtimeout is not None: - self._document["wtimeout"] = wtimeout - if j is not None: - self._document["j"] = j - if fsync is not None: - self._document["fsync"] = fsync - - def __eq__(self, other): - try: - return _with_default_values(other.document) == _with_default_values( - self.document - ) - except AttributeError: - return NotImplemented - - def __ne__(self, other): - try: - return _with_default_values(other.document) != _with_default_values( - self.document - ) - except AttributeError: - return NotImplemented - - @property - def acknowledged(self): - return True - - @property - def document(self): - return self._document.copy() - - @property - def is_server_default(self): - return not self._document diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index 39d6b1871bc..851a83cb7c2 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -23,20 +23,18 @@ def test_actionobject_method(worker): root_datasite_client = worker.root_client assert root_datasite_client.settings.enable_eager_execution(enable=True) - action_store = worker.services.action.store + action_store = worker.services.action.stash obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) - assert len(action_store.data) == 1 + assert len(action_store._data) == 1 res = pointer.capitalize() - assert len(action_store.data) == 2 + assert len(action_store._data) == 2 assert res[0] == "A" -@pytest.mark.parametrize("delete_original_admin", [False, True]) def test_new_admin_has_action_object_permission( worker: Worker, faker: Faker, - delete_original_admin: bool, ) -> None: root_client = worker.root_client @@ -60,10 +58,6 @@ def test_new_admin_has_action_object_permission( root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN) - if delete_original_admin: - res = root_client.api.services.user.delete(root_client.account.id) - assert not isinstance(res, SyftError) - assert admin.api.services.action.get(obj.id) == obj @@ -75,7 +69,7 @@ def test_lib_function_action(worker): assert isinstance(res, ActionObject) assert all(res == np.array([0, 0, 0])) - assert len(worker.services.action.store.data) > 0 + assert len(worker.services.action.stash._data) > 0 def test_call_lib_function_action2(worker): @@ -90,7 +84,7 @@ def test_lib_class_init_action(worker): assert isinstance(res, ActionObject) assert res == np.float32(4.0) - assert len(worker.services.action.store.data) > 0 + assert len(worker.services.action.stash._data) > 0 def test_call_lib_wo_permission(worker): diff --git a/packages/syft/tests/syft/api_test.py b/packages/syft/tests/syft/api_test.py index 6c511b45d48..ca2f61ac147 100644 --- a/packages/syft/tests/syft/api_test.py +++ b/packages/syft/tests/syft/api_test.py @@ -45,7 +45,7 @@ def test_api_cache_invalidation_login(root_verify_key, worker): name="q", email="a@b.org", password="aaa", password_verify="aaa" ) guest_client = guest_client.login(email="a@b.org", password="aaa") - user_id = worker.document_store.partitions["User"].all(root_verify_key).value[-1].id + user_id = worker.root_client.users[-1].id def get_role(verify_key): users = worker.services.user.stash.get_all(root_verify_key).ok() diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 8b8613498fb..47e33f7926d 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -1,6 +1,5 @@ # stdlib import io -import random # third party import numpy as np @@ -42,10 +41,7 @@ def test_blob_storage_allocate(authed_context, blob_storage): assert isinstance(blob_deposit, BlobDeposit) -def test_blob_storage_write(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_write(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key @@ -60,10 +56,7 @@ def test_blob_storage_write(): worker.cleanup() -def test_blob_storage_write_syft_object(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_write_syft_object(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key @@ -78,10 +71,7 @@ def test_blob_storage_write_syft_object(): worker.cleanup() -def test_blob_storage_read(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_read(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py index d177aaa508e..bfec6e00895 100644 --- a/packages/syft/tests/syft/dataset/dataset_stash_test.py +++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py @@ -1,55 +1,16 @@ # third party import pytest -from typeguard import TypeCheckError # syft absolute from syft.service.dataset.dataset import Dataset -from syft.service.dataset.dataset_stash import ActionIDsPartitionKey -from syft.service.dataset.dataset_stash import NamePartitionKey -from syft.store.document_store import QueryKey +from syft.service.dataset.dataset_stash import DatasetStash from syft.store.document_store_errors import NotFoundException from syft.types.uid import UID -def test_dataset_namepartitionkey() -> None: - mock_obj = "dummy_name_key" - - assert NamePartitionKey.key == "name" - assert NamePartitionKey.type_ == str - - name_partition_key = NamePartitionKey.with_obj(obj=mock_obj) - - assert isinstance(name_partition_key, QueryKey) - assert name_partition_key.key == "name" - assert name_partition_key.type_ == str - assert name_partition_key.value == mock_obj - - with pytest.raises(AttributeError): - NamePartitionKey.with_obj(obj=[UID()]) - - -def test_dataset_actionidpartitionkey() -> None: - mock_obj = [UID() for _ in range(3)] - - assert ActionIDsPartitionKey.key == "action_ids" - assert ActionIDsPartitionKey.type_ == list[UID] - - action_ids_partition_key = ActionIDsPartitionKey.with_obj(obj=mock_obj) - - assert isinstance(action_ids_partition_key, QueryKey) - assert action_ids_partition_key.key == "action_ids" - assert action_ids_partition_key.type_ == list[UID] - assert action_ids_partition_key.value == mock_obj - - with pytest.raises(AttributeError): - ActionIDsPartitionKey.with_obj(obj="dummy_str") - - # Not sure what Exception should be raised here, Type or Attibute - with pytest.raises(TypeCheckError): - ActionIDsPartitionKey.with_obj(obj=["first_str", "second_str"]) - - -def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset) -> None: +def test_dataset_get_by_name( + root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset: Dataset +) -> None: # retrieving existing dataset result = mock_dataset_stash.get_by_name(root_verify_key, mock_dataset.name) assert result.is_ok() @@ -63,7 +24,9 @@ def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset) assert type(result.err()) is NotFoundException -def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dataset): +def test_dataset_search_action_ids( + root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset +): action_id = mock_dataset.assets[0].action_id result = mock_dataset_stash.search_action_ids(root_verify_key, uid=action_id) @@ -72,12 +35,6 @@ def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dat assert isinstance(result.ok()[0], Dataset) assert result.ok()[0].id == mock_dataset.id - # retrieving dataset by list of action_ids - result = mock_dataset_stash.search_action_ids(root_verify_key, uid=[action_id]) - assert result.is_ok() - assert isinstance(result.ok()[0], Dataset) - assert result.ok()[0].id == mock_dataset.id - # retrieving dataset by non-existing action_id other_action_id = UID() result = mock_dataset_stash.search_action_ids(root_verify_key, uid=other_action_id) @@ -86,5 +43,5 @@ def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dat # passing random object random_obj = object() - with pytest.raises(AttributeError): + with pytest.raises(ValueError): result = mock_dataset_stash.search_action_ids(root_verify_key, uid=random_obj) diff --git a/packages/syft/tests/syft/dataset/fixtures.py b/packages/syft/tests/syft/dataset/fixtures.py index 9c062e756bc..bcb26bff262 100644 --- a/packages/syft/tests/syft/dataset/fixtures.py +++ b/packages/syft/tests/syft/dataset/fixtures.py @@ -60,7 +60,9 @@ def mock_asset(worker, root_datasite_client) -> Asset: @pytest.fixture -def mock_dataset(root_verify_key, mock_dataset_stash, mock_asset) -> Dataset: +def mock_dataset( + root_verify_key, mock_dataset_stash: DatasetStash, mock_asset +) -> Dataset: uploader = Contributor( role=str(Roles.UPLOADER), name="test", @@ -70,7 +72,7 @@ def mock_dataset(root_verify_key, mock_dataset_stash, mock_asset) -> Dataset: id=UID(), name="test_dataset", uploader=uploader, contributors=[uploader] ) mock_dataset.asset_list.append(mock_asset) - result = mock_dataset_stash.partition.set(root_verify_key, mock_dataset) + result = mock_dataset_stash.set(root_verify_key, mock_dataset) mock_dataset = result.ok() yield mock_dataset diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py index 18fd4b85394..243a18130a2 100644 --- a/packages/syft/tests/syft/eager_test.py +++ b/packages/syft/tests/syft/eager_test.py @@ -174,7 +174,7 @@ def test_setattribute(worker, guest_client): obj_pointer.dtype = np.int32 # local object is updated - assert obj_pointer.id.id in worker.action_store.data + assert obj_pointer.id.id in worker.action_store._data assert obj_pointer.id != original_id res = root_datasite_client.api.services.action.get(obj_pointer.id) @@ -206,7 +206,7 @@ def test_getattribute(worker, guest_client): size_pointer = obj_pointer.size # check result - assert size_pointer.id.id in worker.action_store.data + assert size_pointer.id.id in worker.action_store._data assert root_datasite_client.api.services.action.get(size_pointer.id) == 6 @@ -226,7 +226,7 @@ def test_eager_method(worker, guest_client): flat_pointer = obj_pointer.flatten() - assert flat_pointer.id.id in worker.action_store.data + assert flat_pointer.id.id in worker.action_store._data # check result assert all( root_datasite_client.api.services.action.get(flat_pointer.id) @@ -250,7 +250,7 @@ def test_eager_dunder_method(worker, guest_client): first_row_pointer = obj_pointer[0] - assert first_row_pointer.id.id in worker.action_store.data + assert first_row_pointer.id.id in worker.action_store._data # check result assert all( root_datasite_client.api.services.action.get(first_row_pointer.id) diff --git a/packages/syft/tests/syft/migrations/data_migration_test.py b/packages/syft/tests/syft/migrations/data_migration_test.py index 708c56ac75a..a5203e2a0f8 100644 --- a/packages/syft/tests/syft/migrations/data_migration_test.py +++ b/packages/syft/tests/syft/migrations/data_migration_test.py @@ -115,7 +115,7 @@ def test_get_migration_data(worker, tmp_path): @contextmanager def named_worker_context(name): # required to launch worker with same name twice within the same test + ensure cleanup - worker = sy.Worker.named(name=name) + worker = sy.Worker.named(name=name, db_url="sqlite://") try: yield worker finally: diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py index 059c729d921..64c670d5ea3 100644 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py @@ -20,6 +20,7 @@ from syft.service.service import ServiceConfigRegistry from syft.service.service import service_method from syft.service.user.user_roles import GUEST_ROLE_LEVEL +from syft.store.db.db import DBManager from syft.store.document_store import DocumentStore from syft.store.document_store import NewBaseStash from syft.store.document_store import PartitionSettings @@ -85,7 +86,7 @@ class SyftMockObjectStash(NewBaseStash): object_type=syft_object, ) - def __init__(self, store: DocumentStore) -> None: + def __init__(self, store: DBManager) -> None: super().__init__(store=store) return SyftMockObjectStash @@ -103,8 +104,7 @@ class SyftMockObjectService(AbstractService): stash: stash_klass # type: ignore __module__: str = "syft.test" - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = stash_klass(store=store) @service_method( diff --git a/packages/syft/tests/syft/network_test.py b/packages/syft/tests/syft/network_test.py new file mode 100644 index 00000000000..3bb4b5e84e3 --- /dev/null +++ b/packages/syft/tests/syft/network_test.py @@ -0,0 +1,31 @@ +# syft absolute +from syft.abstract_server import ServerType +from syft.server.credentials import SyftSigningKey +from syft.service.network.network_service import NetworkStash +from syft.service.network.server_peer import ServerPeer +from syft.service.network.server_peer import ServerPeerUpdate +from syft.types.uid import UID + + +def test_add_route() -> None: + uid = UID() + peer = ServerPeer( + id=uid, + name="test", + verify_key=SyftSigningKey.generate().verify_key, + server_type=ServerType.DATASITE, + admin_email="info@openmined.org", + ) + network_stash = NetworkStash.random() + + network_stash.set( + credentials=network_stash.db.root_verify_key, + obj=peer, + ).unwrap() + peer_update = ServerPeerUpdate(id=uid, name="new name") + peer = network_stash.update( + credentials=network_stash.db.root_verify_key, + obj=peer_update, + ).unwrap() + + assert peer.name == "new name" diff --git a/packages/syft/tests/syft/notifications/notification_service_test.py b/packages/syft/tests/syft/notifications/notification_service_test.py index f48e77ab97d..a8319d32a80 100644 --- a/packages/syft/tests/syft/notifications/notification_service_test.py +++ b/packages/syft/tests/syft/notifications/notification_service_test.py @@ -144,20 +144,12 @@ def test_get_all_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_inbox_for_verify_key(*args, **kwargs) -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_inbox_for_verify_key", - mock_get_all_inbox_for_verify_key, - ) - response = test_notification_service.get_all(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_client_verify_key = None + response[0].syft_server_location = None assert response[0] == expected_message @@ -188,9 +180,6 @@ def mock_get_all_inbox_for_verify_key( def test_get_sent_success( - root_verify_key, - monkeypatch: MonkeyPatch, - notification_service: NotificationService, authed_context: AuthedServiceContext, document_store: DocumentStore, ) -> None: @@ -207,20 +196,12 @@ def test_get_sent_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_sent_for_verify_key(credentials, verify_key) -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_sent_for_verify_key", - mock_get_all_sent_for_verify_key, - ) - response = test_notification_service.get_all_sent(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message @@ -340,19 +321,12 @@ def test_get_all_read_success( NotificationStatus.READ, ) - def mock_get_all_by_verify_key_for_status() -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_by_verify_key_for_status", - mock_get_all_by_verify_key_for_status, - ) - response = test_notification_service.get_all_read(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message @@ -404,19 +378,11 @@ def test_get_all_unread_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_by_verify_key_for_status() -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_by_verify_key_for_status", - mock_get_all_by_verify_key_for_status, - ) - response = test_notification_service.get_all_unread(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message diff --git a/packages/syft/tests/syft/notifications/notification_stash_test.py b/packages/syft/tests/syft/notifications/notification_stash_test.py index b848324a2b7..7864fb10d19 100644 --- a/packages/syft/tests/syft/notifications/notification_stash_test.py +++ b/packages/syft/tests/syft/notifications/notification_stash_test.py @@ -8,13 +8,7 @@ # syft absolute from syft.server.credentials import SyftSigningKey from syft.server.credentials import SyftVerifyKey -from syft.service.notification.notification_stash import ( - OrderByCreatedAtTimeStampPartitionKey, -) -from syft.service.notification.notification_stash import FromUserVerifyKeyPartitionKey from syft.service.notification.notification_stash import NotificationStash -from syft.service.notification.notification_stash import StatusPartitionKey -from syft.service.notification.notification_stash import ToUserVerifyKeyPartitionKey from syft.service.notification.notifications import Notification from syft.service.notification.notifications import NotificationExpiryStatus from syft.service.notification.notifications import NotificationStatus @@ -60,86 +54,14 @@ def add_mock_notification( return mock_notification -def test_fromuserverifykey_partitionkey() -> None: - random_verify_key = SyftSigningKey.generate().verify_key - - assert FromUserVerifyKeyPartitionKey.type_ == SyftVerifyKey - assert FromUserVerifyKeyPartitionKey.key == "from_user_verify_key" - - result = FromUserVerifyKeyPartitionKey.with_obj(random_verify_key) - - assert result.type_ == SyftVerifyKey - assert result.key == "from_user_verify_key" - - assert result.value == random_verify_key - - signing_key = SyftSigningKey.from_string(test_signing_key_string) - with pytest.raises(AttributeError): - FromUserVerifyKeyPartitionKey.with_obj(signing_key) - - -def test_touserverifykey_partitionkey() -> None: - random_verify_key = SyftSigningKey.generate().verify_key - - assert ToUserVerifyKeyPartitionKey.type_ == SyftVerifyKey - assert ToUserVerifyKeyPartitionKey.key == "to_user_verify_key" - - result = ToUserVerifyKeyPartitionKey.with_obj(random_verify_key) - - assert result.type_ == SyftVerifyKey - assert result.key == "to_user_verify_key" - assert result.value == random_verify_key - - signing_key = SyftSigningKey.from_string(test_signing_key_string) - with pytest.raises(AttributeError): - ToUserVerifyKeyPartitionKey.with_obj(signing_key) - - -def test_status_partitionkey() -> None: - assert StatusPartitionKey.key == "status" - assert StatusPartitionKey.type_ == NotificationStatus - - result1 = StatusPartitionKey.with_obj(NotificationStatus.UNREAD) - result2 = StatusPartitionKey.with_obj(NotificationStatus.READ) - - assert result1.type_ == NotificationStatus - assert result1.key == "status" - assert result1.value == NotificationStatus.UNREAD - assert result2.type_ == NotificationStatus - assert result2.key == "status" - assert result2.value == NotificationStatus.READ - - notification_expiry_status_auto = NotificationExpiryStatus(0) - - with pytest.raises(AttributeError): - StatusPartitionKey.with_obj(notification_expiry_status_auto) - - -def test_orderbycreatedattimestamp_partitionkey() -> None: - random_datetime = DateTime.now() - - assert OrderByCreatedAtTimeStampPartitionKey.key == "created_at" - assert OrderByCreatedAtTimeStampPartitionKey.type_ == DateTime - - result = OrderByCreatedAtTimeStampPartitionKey.with_obj(random_datetime) - - assert result.type_ == DateTime - assert result.key == "created_at" - assert result.value == random_datetime - - def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_inbox_for_verify_key( + result = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ) - - assert response.is_ok() - - result = response.ok() + ).unwrap() assert len(result) == 0 # list of mock notifications @@ -152,14 +74,11 @@ def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: notification_list.append(mock_notification) # returned list of notifications from stash that's sorted by created_at - response2 = test_stash.get_all_inbox_for_verify_key( + result = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ) + ).unwrap() - assert response2.is_ok() - - result = response2.ok() - assert len(response2.value) == 5 + assert len(result) == 5 for notification in notification_list: # check if all notifications are present in the result @@ -205,12 +124,9 @@ def test_get_all_sent_for_verify_key(root_verify_key, document_store) -> None: def test_get_all_for_verify_key(root_verify_key, document_store) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key - query_key = FromUserVerifyKeyPartitionKey.with_obj(test_verify_key) test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_for_verify_key( - root_verify_key, random_verify_key, query_key - ) + response = test_stash.get_all_for_verify_key(root_verify_key, random_verify_key) assert response.is_ok() @@ -221,11 +137,8 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None: root_verify_key, test_stash, test_verify_key, random_verify_key ) - query_key2 = FromUserVerifyKeyPartitionKey.with_obj( - mock_notification.from_user_verify_key - ) response_from_verify_key = test_stash.get_all_for_verify_key( - root_verify_key, mock_notification.from_user_verify_key, query_key2 + root_verify_key, mock_notification.from_user_verify_key ) assert response_from_verify_key.is_ok() @@ -235,7 +148,7 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None: assert result[0] == mock_notification response_from_verify_key_string = test_stash.get_all_for_verify_key( - root_verify_key, test_verify_key_string, query_key2 + root_verify_key, test_verify_key_string ) assert response_from_verify_key_string.is_ok() @@ -249,28 +162,21 @@ def test_get_all_by_verify_key_for_status(root_verify_key, document_store) -> No random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_by_verify_key_for_status( + result = test_stash.get_all_by_verify_key_for_status( root_verify_key, random_verify_key, NotificationStatus.READ - ) - - assert response.is_ok() - - result = response.ok() + ).unwrap() assert len(result) == 0 mock_notification = add_mock_notification( root_verify_key, test_stash, test_verify_key, random_verify_key ) - response2 = test_stash.get_all_by_verify_key_for_status( + result2 = test_stash.get_all_by_verify_key_for_status( root_verify_key, mock_notification.to_user_verify_key, NotificationStatus.UNREAD - ) - assert response2.is_ok() + ).unwrap() + assert len(result2) == 1 - result = response2.ok() - assert len(result) == 1 - - assert result[0] == mock_notification + assert result2[0] == mock_notification with pytest.raises(AttributeError): test_stash.get_all_by_verify_key_for_status( @@ -288,7 +194,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None: root_verify_key, uid=random_uid, status=NotificationStatus.READ ).unwrap() - assert exc.type is SyftException + assert issubclass(exc.type, SyftException) assert exc.value.public_message mock_notification = add_mock_notification( @@ -314,7 +220,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None: status=notification_expiry_status_auto, ).unwrap() - assert exc.type is SyftException + assert issubclass(exc.type, SyftException) assert exc.value.public_message @@ -326,6 +232,10 @@ def test_update_notification_status_error_on_get_by_uid( test_stash = NotificationStash(store=document_store) expected_error_msg = f"No notification exists for id: {random_verify_key}" + add_mock_notification( + root_verify_key, test_stash, test_verify_key, random_verify_key + ) + @as_result(StashException) def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn: raise StashException(public_message=f"No notification exists for id: {uid}") @@ -335,11 +245,6 @@ def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn: "get_by_uid", mock_get_by_uid, ) - - add_mock_notification( - root_verify_key, test_stash, test_verify_key, random_verify_key - ) - with pytest.raises(StashException) as exc: test_stash.update_notification_status( root_verify_key, random_verify_key, NotificationStatus.READ @@ -354,11 +259,9 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None: random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) - response = test_stash.delete_all_for_verify_key(root_verify_key, test_verify_key) - - assert response.is_ok() - - result = response.ok() + result = test_stash.delete_all_for_verify_key( + root_verify_key, test_verify_key + ).unwrap() assert result is True add_mock_notification( @@ -367,23 +270,23 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None: inbox_before = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ).value + ).unwrap() assert len(inbox_before) == 1 - response2 = test_stash.delete_all_for_verify_key(root_verify_key, random_verify_key) - - assert response2.is_ok() - - result = response2.ok() - assert result is True + result2 = test_stash.delete_all_for_verify_key( + root_verify_key, random_verify_key + ).unwrap() + assert result2 is True inbox_after = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ).value + ).unwrap() assert len(inbox_after) == 0 with pytest.raises(AttributeError): - test_stash.delete_all_for_verify_key(root_verify_key, random_signing_key) + test_stash.delete_all_for_verify_key( + root_verify_key, random_signing_key + ).unwrap() def test_delete_all_for_verify_key_error_on_get_all_inbox_for_verify_key( diff --git a/packages/syft/tests/syft/request/request_stash_test.py b/packages/syft/tests/syft/request/request_stash_test.py index a492c2f6b9f..a9115d5c934 100644 --- a/packages/syft/tests/syft/request/request_stash_test.py +++ b/packages/syft/tests/syft/request/request_stash_test.py @@ -1,9 +1,4 @@ -# stdlib -from typing import NoReturn - # third party -import pytest -from pytest import MonkeyPatch # syft absolute from syft.client.client import SyftClient @@ -12,10 +7,6 @@ from syft.service.request.request import Request from syft.service.request.request import SubmitRequest from syft.service.request.request_stash import RequestStash -from syft.service.request.request_stash import RequestingUserVerifyKeyPartitionKey -from syft.store.document_store import PartitionKey -from syft.store.document_store import QueryKeys -from syft.types.errors import SyftException def test_requeststash_get_all_for_verify_key_no_requests( @@ -33,7 +24,6 @@ def test_requeststash_get_all_for_verify_key_no_requests( assert len(requests.ok()) == 0 -@pytest.mark.xfail def test_requeststash_get_all_for_verify_key_success( root_verify_key, request_stash: RequestStash, @@ -77,60 +67,3 @@ def test_requeststash_get_all_for_verify_key_success( requests.ok()[1] == stash_set_result_2.ok() or requests.ok()[0] == stash_set_result_2.ok() ) - - -def test_requeststash_get_all_for_verify_key_fail( - root_verify_key, - request_stash: RequestStash, - monkeypatch: MonkeyPatch, - guest_datasite_client: SyftClient, -) -> None: - verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key - mock_error_message = ( - "verify key not in the document store's unique or searchable keys" - ) - - def mock_query_all_error( - credentials: SyftVerifyKey, qks: QueryKeys, order_by: PartitionKey | None - ) -> NoReturn: - raise SyftException(public_message=mock_error_message) - - monkeypatch.setattr(request_stash, "query_all", mock_query_all_error) - - with pytest.raises(SyftException) as exc: - request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap() - - assert exc.type is SyftException - assert exc.value.public_message == mock_error_message - - -def test_requeststash_get_all_for_verify_key_find_index_fail( - root_verify_key, - request_stash: RequestStash, - monkeypatch: MonkeyPatch, - guest_datasite_client: SyftClient, -) -> None: - verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key - qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)]) - - mock_error_message = f"Failed to query index or search with {qks.all[0]}" - - def mock_find_index_or_search_keys_error( - credentials: SyftVerifyKey, - index_qks: QueryKeys, - search_qks: QueryKeys, - order_by: PartitionKey | None, - ) -> NoReturn: - raise SyftException(public_message=mock_error_message) - - monkeypatch.setattr( - request_stash.partition, - "find_index_or_search_keys", - mock_find_index_or_search_keys_error, - ) - - with pytest.raises(SyftException) as exc: - request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap() - - assert exc.type == SyftException - assert exc.value.public_message == mock_error_message diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index 76fd4d82685..a57a61e1509 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -506,15 +506,15 @@ def test_actionobject_syft_get_path(testcase): def test_actionobject_syft_send_get(worker, testcase): root_datasite_client = worker.root_client root_datasite_client._fetch_api(root_datasite_client.credentials) - action_store = worker.services.action.store + action_store = worker.services.action.stash orig_obj = testcase obj = helper_make_action_obj(orig_obj) - assert len(action_store.data) == 0 + assert len(action_store._data) == 0 ptr = obj.send(root_datasite_client) - assert len(action_store.data) == 1 + assert len(action_store._data) == 1 retrieved = ptr.get() assert obj.syft_action_data == retrieved @@ -1001,7 +1001,7 @@ def test_actionobject_syft_getattr_float_history(): @pytest.mark.skipif( - sys.platform != "linux", + sys.platform == "win32", reason="This is a hackish way to test attribute set/get, and it might fail on Windows or OSX", ) def test_actionobject_syft_getattr_np(worker): diff --git a/packages/syft/tests/syft/service/action/action_service_test.py b/packages/syft/tests/syft/service/action/action_service_test.py index 5a6b5561d99..e97362a6340 100644 --- a/packages/syft/tests/syft/service/action/action_service_test.py +++ b/packages/syft/tests/syft/service/action/action_service_test.py @@ -22,6 +22,6 @@ def test_action_service_sanity(worker): obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) - assert len(service.store.data) == 1 + assert len(service.stash._data) == 1 res = pointer.capitalize() assert res[0] == "A" diff --git a/packages/syft/tests/syft/service_permission_test.py b/packages/syft/tests/syft/service_permission_test.py index afc266005f3..ceb6d63923c 100644 --- a/packages/syft/tests/syft/service_permission_test.py +++ b/packages/syft/tests/syft/service_permission_test.py @@ -9,12 +9,8 @@ @pytest.fixture def guest_mock_user(root_verify_key, user_stash, guest_user): - result = user_stash.partition.set(root_verify_key, guest_user) - assert result.is_ok() - - user = result.ok() + user = user_stash.set(root_verify_key, guest_user).unwrap() assert user is not None - yield user diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index aaa7b0460fc..7555aadd91e 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -34,7 +34,6 @@ from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException from syft.types.errors import SyftException -from syft.types.result import Ok from syft.types.result import as_result @@ -100,37 +99,22 @@ def test_settingsservice_set_success( ) -> None: response = settings_service.set(authed_context, settings) assert isinstance(response, ServerSettings) + response.syft_client_verify_key = None + response.syft_server_location = None + response.pwd_token_config.syft_client_verify_key = None + response.pwd_token_config.syft_server_location = None + response.welcome_markdown.syft_client_verify_key = None + response.welcome_markdown.syft_server_location = None assert response == settings -def test_settingsservice_set_fail( - monkeypatch: MonkeyPatch, - settings_service: SettingsService, - settings: ServerSettings, - authed_context: AuthedServiceContext, -) -> None: - mock_error_message = "database failure" - - @as_result(StashException) - def mock_stash_set_error(credentials, settings: ServerSettings) -> NoReturn: - raise StashException(public_message=mock_error_message) - - monkeypatch.setattr(settings_service.stash, "set", mock_stash_set_error) - - with pytest.raises(StashException) as exc: - settings_service.set(authed_context, settings) - - assert exc.type == StashException - assert exc.value.public_message == mock_error_message - - def add_mock_settings( root_verify_key: SyftVerifyKey, settings_stash: SettingsStash, settings: ServerSettings, ) -> ServerSettings: # create a mock settings in the stash so that we can update it - result = settings_stash.partition.set(root_verify_key, settings) + result = settings_stash.set(root_verify_key, settings) assert result.is_ok() created_settings = result.ok() @@ -150,9 +134,7 @@ def test_settingsservice_update_success( notifier_stash: NotifierStash, ) -> None: # add a mock settings to the stash - mock_settings = add_mock_settings( - authed_context.credentials, settings_stash, settings - ) + mock_settings = settings_stash.set(authed_context.credentials, settings).unwrap() # get a new settings according to update_settings new_settings = deepcopy(settings) @@ -164,14 +146,6 @@ def test_settingsservice_update_success( assert new_settings != mock_settings assert mock_settings == settings - mock_stash_get_all_output = [mock_settings, mock_settings] - - def mock_stash_get_all(root_verify_key) -> Ok: - return Ok(mock_stash_get_all_output) - - monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all) - - # Mock the get_service method to return a mocked notifier_service with the notifier_stash class MockNotifierService: def __init__(self, stash): self.stash = stash @@ -194,33 +168,7 @@ def mock_get_service(service_name: str): # update the settings in the settings stash using settings_service response = settings_service.update(context=authed_context, settings=update_settings) - # not_updated_settings = response.ok()[1] - assert isinstance(response, SyftSuccess) - # assert ( - # not_updated_settings.to_dict() == settings.to_dict() - # ) # the second settings is not updated - - -def test_settingsservice_update_stash_get_all_fail( - monkeypatch: MonkeyPatch, - settings_service: SettingsService, - update_settings: ServerSettingsUpdate, - authed_context: AuthedServiceContext, -) -> None: - mock_error_message = "database failure" - - @as_result(StashException) - def mock_stash_get_all_error(credentials) -> NoReturn: - raise StashException(public_message=mock_error_message) - - monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all_error) - - with pytest.raises(StashException) as exc: - settings_service.update(context=authed_context, settings=update_settings) - - assert exc.type == StashException - assert exc.value.public_message == mock_error_message def test_settingsservice_update_stash_empty( @@ -230,9 +178,7 @@ def test_settingsservice_update_stash_empty( ) -> None: with pytest.raises(NotFoundException) as exc: settings_service.update(context=authed_context, settings=update_settings) - - assert exc.type == NotFoundException - assert exc.value.public_message == "Server settings not found" + assert exc.value.public_message == "Server settings not found" def test_settingsservice_update_fail( @@ -248,7 +194,7 @@ def test_settingsservice_update_fail( mock_stash_get_all_output = [settings, settings] @as_result(StashException) - def mock_stash_get_all(credentials) -> list[ServerSettings]: + def mock_stash_get_all(credentials, **kwargs) -> list[ServerSettings]: return mock_stash_get_all_output monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all) @@ -256,7 +202,7 @@ def mock_stash_get_all(credentials) -> list[ServerSettings]: mock_update_error_message = "Failed to update obj ServerMetadata" @as_result(StashException) - def mock_stash_update_error(credentials, settings: ServerSettings) -> NoReturn: + def mock_stash_update_error(credentials, obj: ServerSettings) -> NoReturn: raise StashException(public_message=mock_update_error_message) monkeypatch.setattr(settings_service.stash, "update", mock_stash_update_error) @@ -309,7 +255,7 @@ def test_settings_allow_guest_registration( new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") guest_datasite_client = worker.guest_client root_datasite_client = worker.root_client @@ -343,7 +289,7 @@ def test_settings_allow_guest_registration( new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") guest_datasite_client = worker.guest_client root_datasite_client = worker.root_client @@ -402,7 +348,7 @@ def get_mock_client(faker, root_client, role): new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") root_client = worker.root_client emails_added = [] diff --git a/packages/syft/tests/syft/settings/settings_stash_test.py b/packages/syft/tests/syft/settings/settings_stash_test.py index 3fbbd28f9e9..2d976b52108 100644 --- a/packages/syft/tests/syft/settings/settings_stash_test.py +++ b/packages/syft/tests/syft/settings/settings_stash_test.py @@ -1,54 +1,26 @@ -# third party - # syft absolute from syft.service.settings.settings import ServerSettings from syft.service.settings.settings import ServerSettingsUpdate from syft.service.settings.settings_stash import SettingsStash -def add_mock_settings( - root_verify_key, settings_stash: SettingsStash, settings: ServerSettings -) -> ServerSettings: - # prepare: add mock settings - result = settings_stash.partition.set(root_verify_key, settings) - assert result.is_ok() - - created_settings = result.ok() - assert created_settings is not None - - return created_settings - - def test_settingsstash_set( - root_verify_key, settings_stash: SettingsStash, settings: ServerSettings -) -> None: - result = settings_stash.set(root_verify_key, settings) - assert result.is_ok() - - created_settings = result.ok() - assert isinstance(created_settings, ServerSettings) - assert created_settings == settings - assert settings.id in settings_stash.partition.data - - -def test_settingsstash_update( root_verify_key, settings_stash: SettingsStash, settings: ServerSettings, update_settings: ServerSettingsUpdate, ) -> None: - # prepare: add a mock settings - mock_settings = add_mock_settings(root_verify_key, settings_stash, settings) + created_settings = settings_stash.set(root_verify_key, settings).unwrap() + assert isinstance(created_settings, ServerSettings) + assert created_settings == settings + assert settings_stash.exists(root_verify_key, settings.id) # update mock_settings according to update_settings update_kwargs = update_settings.to_dict(exclude_empty=True).items() for field_name, value in update_kwargs: - setattr(mock_settings, field_name, value) + setattr(settings, field_name, value) # update the settings in the stash - result = settings_stash.update(root_verify_key, settings=mock_settings) - - assert result.is_ok() - updated_settings = result.ok() + updated_settings = settings_stash.update(root_verify_key, obj=settings).unwrap() assert isinstance(updated_settings, ServerSettings) - assert mock_settings == updated_settings + assert settings == updated_settings diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 375204908c1..5c2fe63be0d 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -1,24 +1,26 @@ # stdlib -import sys -from typing import Any # third party import pytest # syft absolute +from syft.server.credentials import SyftSigningKey from syft.server.credentials import SyftVerifyKey +from syft.service.action.action_object import ActionObject +from syft.service.action.action_permissions import ActionObjectOWNER +from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectOWNER from syft.service.action.action_store import ActionObjectREAD +from syft.service.action.action_store import ActionObjectStash from syft.service.action.action_store import ActionObjectWRITE +from syft.service.user.user import User +from syft.service.user.user_roles import ServiceRole +from syft.service.user.user_stash import UserStash +from syft.store.db.db import DBManager from syft.types.uid import UID # relative -from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN -from .store_constants_test import TEST_VERIFY_KEY_STRING_CLIENT -from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER -from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT -from .store_mocks_test import MockSyftObject +from ..worker_test import action_object_stash # noqa: F401 permissions = [ ActionObjectOWNER, @@ -28,134 +30,119 @@ ] -@pytest.mark.parametrize( - "store", - [ - pytest.lazy_fixture("dict_action_store"), - pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), - ], -) -def test_action_store_sanity(store: Any): - assert hasattr(store, "store_config") - assert hasattr(store, "settings") - assert hasattr(store, "data") - assert hasattr(store, "permissions") - assert hasattr(store, "root_verify_key") - assert store.root_verify_key.verify == TEST_VERIFY_KEY_STRING_ROOT +def add_user(db_manager: DBManager, role: ServiceRole) -> SyftVerifyKey: + user_stash = UserStash(store=db_manager) + verify_key = SyftSigningKey.generate().verify_key + user_stash.set( + credentials=db_manager.root_verify_key, + obj=User(verify_key=verify_key, role=role, id=UID()), + ).unwrap() + return verify_key + + +def add_test_object( + stash: ActionObjectStash, verify_key: SyftVerifyKey +) -> ActionObject: + test_object = ActionObject.from_obj([1, 2, 3]) + uid = test_object.id + stash.set_or_update( + uid=uid, + credentials=verify_key, + syft_object=test_object, + has_result_read_permission=True, + ).unwrap() + return uid @pytest.mark.parametrize( - "store", + "stash", [ - pytest.lazy_fixture("dict_action_store"), - pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), + pytest.lazy_fixture("action_object_stash"), ], ) @pytest.mark.parametrize("permission", permissions) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -@pytest.mark.skipif(sys.platform == "darwin", reason="skip on mac") -def test_action_store_test_permissions(store: Any, permission: Any): - client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT) - root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - hacker_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - new_admin_key = TEST_VERIFY_KEY_NEW_ADMIN - - access = permission(uid=UID(), credentials=client_key) - access_root = permission(uid=UID(), credentials=root_key) - access_hacker = permission(uid=UID(), credentials=hacker_key) - access_new_admin = permission(uid=UID(), credentials=new_admin_key) - - # add permission - store.add_permission(access) - - assert store.has_permission(access) - assert store.has_permission(access_root) - assert store.has_permission(access_new_admin) - assert not store.has_permission(access_hacker) +def test_action_store_test_permissions( + stash: ActionObjectStash, permission: ActionObjectPermission +) -> None: + client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST) + root_key = add_user(stash.db, ServiceRole.ADMIN) + hacker_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST) + new_admin_key = add_user(stash.db, ServiceRole.ADMIN) + + test_item_id = add_test_object(stash, client_key) + + access = permission(uid=test_item_id, credentials=client_key) + access_root = permission(uid=test_item_id, credentials=root_key) + access_hacker = permission(uid=test_item_id, credentials=hacker_key) + access_new_admin = permission(uid=test_item_id, credentials=new_admin_key) + + stash.add_permission(access) + assert stash.has_permission(access) + assert stash.has_permission(access_root) + assert stash.has_permission(access_new_admin) + assert not stash.has_permission(access_hacker) # remove permission - store.remove_permission(access) + stash.remove_permission(access) - assert not store.has_permission(access) - assert store.has_permission(access_root) - assert store.has_permission(access_new_admin) - assert not store.has_permission(access_hacker) + assert not stash.has_permission(access) + assert stash.has_permission(access_root) + assert stash.has_permission(access_new_admin) + assert not stash.has_permission(access_hacker) # take ownership with new UID - client_uid2 = UID() - access = permission(uid=client_uid2, credentials=client_key) + item2_id = add_test_object(stash, client_key) + access = permission(uid=item2_id, credentials=client_key) - store.take_ownership(client_uid2, client_key) - assert store.has_permission(access) - assert store.has_permission(access_root) - assert store.has_permission(access_new_admin) - assert not store.has_permission(access_hacker) + stash.add_permission(ActionObjectREAD(uid=item2_id, credentials=client_key)) + assert stash.has_permission(access) + assert stash.has_permission(access_root) + assert stash.has_permission(access_new_admin) + assert not stash.has_permission(access_hacker) # delete UID as hacker - access_hacker_ro = ActionObjectREAD(uid=UID(), credentials=hacker_key) - store.add_permission(access_hacker_ro) - res = store.delete(client_uid2, hacker_key) + res = stash.delete_by_uid(hacker_key, item2_id) assert res.is_err() - assert store.has_permission(access) - assert store.has_permission(access_new_admin) - assert store.has_permission(access_hacker_ro) + assert stash.has_permission(access) + assert stash.has_permission(access_root) + assert stash.has_permission(access_new_admin) + assert not stash.has_permission(access_hacker) # delete UID as owner - res = store.delete(client_uid2, client_key) + res = stash.delete_by_uid(client_key, item2_id) assert res.is_ok() - assert not store.has_permission(access) - assert store.has_permission(access_new_admin) - assert not store.has_permission(access_hacker) + assert not stash.has_permission(access) + assert stash.has_permission(access_new_admin) + assert not stash.has_permission(access_hacker) @pytest.mark.parametrize( - "store", + "stash", [ - pytest.lazy_fixture("dict_action_store"), - pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), + pytest.lazy_fixture("action_object_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_action_store_test_dataset_get(store: Any): - client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT) - root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) +def test_action_store_test_dataset_get(stash: ActionObjectStash) -> None: + client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST) + root_key = add_user(stash.db, ServiceRole.ADMIN) - permission_only_uid = UID() - access = ActionObjectWRITE(uid=permission_only_uid, credentials=client_key) - access_root = ActionObjectWRITE(uid=permission_only_uid, credentials=root_key) - read_permission = ActionObjectREAD(uid=permission_only_uid, credentials=client_key) + data_uid = add_test_object(stash, client_key) + access = ActionObjectWRITE(uid=data_uid, credentials=client_key) + access_root = ActionObjectWRITE(uid=data_uid, credentials=root_key) + read_permission = ActionObjectREAD(uid=data_uid, credentials=client_key) # add permission - store.add_permission(access) + stash.add_permission(access) - assert store.has_permission(access) - assert store.has_permission(access_root) + assert stash.has_permission(access) + assert stash.has_permission(access_root) - store.add_permission(read_permission) - assert store.has_permission(read_permission) + stash.add_permission(read_permission) + assert stash.has_permission(read_permission) # check that trying to get action data that doesn't exist returns an error, even if have permissions - res = store.get(permission_only_uid, client_key) - assert res.is_err() - - # add data - data_uid = UID() - obj = MockSyftObject(data=1) - - res = store.set(data_uid, client_key, obj, has_result_read_permission=True) - assert res.is_ok() - res = store.get(data_uid, client_key) - assert res.is_ok() - assert res.ok() == obj - - assert store.exists(data_uid) - res = store.delete(data_uid, client_key) - assert res.is_ok() - res = store.delete(data_uid, client_key) + stash.delete_by_uid(client_key, data_uid) + res = stash.get(data_uid, client_key) assert res.is_err() diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index c61806fb31b..8ed9a312b86 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,15 +12,15 @@ # syft absolute from syft.serde.serializable import serializable -from syft.store.dict_document_store import DictDocumentStore -from syft.store.document_store import NewBaseUIDStoreStash +from syft.service.queue.queue_stash import Status +from syft.service.request.request_service import RequestService +from syft.store.db.sqlite import SQLiteDBConfig +from syft.store.db.sqlite import SQLiteDBManager +from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey -from syft.store.document_store import PartitionSettings -from syft.store.document_store import QueryKey -from syft.store.document_store import QueryKeys -from syft.store.document_store import UIDPartitionKey from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException +from syft.store.linked_obj import LinkedObject from syft.types.errors import SyftException from syft.types.syft_object import SyftObject from syft.types.uid import UID @@ -35,6 +35,8 @@ class MockObject(SyftObject): desc: str importance: int value: int + linked_obj: LinkedObject | None = None + status: Status = Status.CREATED __attr_searchable__ = ["id", "name", "desc", "importance"] __attr_unique__ = ["id", "name"] @@ -45,24 +47,14 @@ class MockObject(SyftObject): ImportancePartitionKey = PartitionKey(key="importance", type_=int) -class MockStash(NewBaseUIDStoreStash): - object_type = MockObject - settings = PartitionSettings( - name=MockObject.__canonical_name__, object_type=MockObject - ) +class MockStash(ObjectStash[MockObject]): + pass def get_object_values(obj: SyftObject) -> tuple[Any]: return tuple(obj.to_dict().values()) -def add_mock_object(root_verify_key, stash: MockStash, obj: MockObject) -> MockObject: - result = stash.set(root_verify_key, obj) - assert result.is_ok() - - return result.ok() - - T = TypeVar("T") P = ParamSpec("P") @@ -80,7 +72,11 @@ def create_unique( @pytest.fixture def base_stash(root_verify_key) -> MockStash: - yield MockStash(store=DictDocumentStore(UID(), root_verify_key)) + config = SQLiteDBConfig() + db_manager = SQLiteDBManager(config, UID(), root_verify_key) + mock_stash = MockStash(store=db_manager) + db_manager.init_tables() + yield mock_stash def random_sentence(faker: Faker) -> str: @@ -119,8 +115,7 @@ def mock_objects(faker: Faker) -> list[MockObject]: def test_basestash_set( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - result = add_mock_object(root_verify_key, base_stash, mock_object) - + result = base_stash.set(root_verify_key, mock_object).unwrap() assert result is not None assert result == mock_object @@ -132,11 +127,10 @@ def test_basestash_set_duplicate( MockObject(**kwargs) for kwargs in multiple_object_kwargs(faker, n=2, same=True) ) - result = base_stash.set(root_verify_key, original) - assert result.is_ok() + base_stash.set(root_verify_key, original).unwrap() - result = base_stash.set(root_verify_key, duplicate) - assert result.is_err() + with pytest.raises(StashException): + base_stash.set(root_verify_key, duplicate).unwrap() def test_basestash_set_duplicate_unique_key( @@ -157,28 +151,19 @@ def test_basestash_set_duplicate_unique_key( def test_basestash_delete( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) - - result = base_stash.delete( - root_verify_key, UIDPartitionKey.with_obj(mock_object.id) - ) - assert result.is_ok() - - assert len(base_stash.get_all(root_verify_key).ok()) == 0 + base_stash.set(root_verify_key, mock_object).unwrap() + base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap() + assert len(base_stash.get_all(root_verify_key).unwrap()) == 0 def test_basestash_cannot_delete_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() random_uid = create_unique(UID, [mock_object.id]) - for result in [ - base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(random_uid)), - base_stash.delete_by_uid(root_verify_key, random_uid), - ]: - result = base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(UID())) - assert result.is_err() + result = base_stash.delete_by_uid(root_verify_key, random_uid) + assert result.is_err() assert ( len( @@ -193,7 +178,7 @@ def test_basestash_cannot_delete_non_existent( def test_basestash_update( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() updated_obj = mock_object.copy() updated_obj.name = faker.name() @@ -208,7 +193,7 @@ def test_basestash_update( def test_basestash_cannot_update_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() updated_obj = mock_object.copy() updated_obj.id = create_unique(UID, [mock_object.id]) @@ -227,10 +212,7 @@ def test_basestash_set_get_all( stored_objects = base_stash.get_all( root_verify_key, - ) - assert stored_objects.is_ok() - - stored_objects = stored_objects.ok() + ).unwrap() assert len(stored_objects) == len(mock_objects) stored_objects_values = {get_object_values(obj) for obj in stored_objects} @@ -241,11 +223,10 @@ def test_basestash_set_get_all( def test_basestash_get_by_uid( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() - result = base_stash.get_by_uid(root_verify_key, mock_object.id) - assert result.is_ok() - assert result.ok() == mock_object + result = base_stash.get_by_uid(root_verify_key, mock_object.id).unwrap() + assert result == mock_object random_uid = create_unique(UID, [mock_object.id]) bad_uid = base_stash.get_by_uid(root_verify_key, random_uid) @@ -262,12 +243,9 @@ def test_basestash_get_by_uid( def test_basestash_delete_by_uid( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() - result = base_stash.delete_by_uid(root_verify_key, mock_object.id) - assert result.is_ok() - - response = result.ok() + response = base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap() assert isinstance(response, UID) result = base_stash.get_by_uid(root_verify_key, mock_object.id) @@ -288,43 +266,73 @@ def test_basestash_query_one( base_stash.set(root_verify_key, obj) obj = random.choice(mock_objects) + result = base_stash.get_one( + root_verify_key, + filters={"name": obj.name}, + ).unwrap() - for result in ( - base_stash.query_one_kwargs(root_verify_key, name=obj.name), - base_stash.query_one( - root_verify_key, QueryKey.from_obj(NamePartitionKey, obj.name) - ), - ): - assert result.is_ok() - assert result.ok() == obj + assert result == obj existing_names = {obj.name for obj in mock_objects} random_name = create_unique(faker.name, existing_names) - for result in ( - base_stash.query_one_kwargs(root_verify_key, name=random_name), - base_stash.query_one( - root_verify_key, QueryKey.from_obj(NamePartitionKey, random_name) - ), - ): - assert result.is_err() - assert isinstance(result.err(), NotFoundException) + with pytest.raises(NotFoundException): + result = base_stash.get_one( + root_verify_key, + filters={"name": random_name}, + ).unwrap() params = {"name": obj.name, "desc": obj.desc} - for result in [ - base_stash.query_one_kwargs(root_verify_key, **params), - base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - assert result.ok() == obj + result = base_stash.get_one( + root_verify_key, + filters=params, + ).unwrap() + assert result == obj params = {"name": random_name, "desc": random_sentence(faker)} - for result in [ - base_stash.query_one_kwargs(root_verify_key, **params), - base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_err() - assert isinstance(result.err(), NotFoundException) + with pytest.raises(NotFoundException): + result = base_stash.get_one( + root_verify_key, + filters=params, + ).unwrap() + + +def test_basestash_query_enum( + root_verify_key, base_stash: MockStash, mock_object: MockObject +) -> None: + base_stash.set(root_verify_key, mock_object).unwrap() + result = base_stash.get_one( + root_verify_key, + filters={"status": Status.CREATED}, + ).unwrap() + + assert result == mock_object + with pytest.raises(NotFoundException): + result = base_stash.get_one( + root_verify_key, + filters={"status": Status.PROCESSING}, + ).unwrap() + + +def test_basestash_query_linked_obj( + root_verify_key, base_stash: MockStash, mock_object: MockObject +) -> None: + mock_object.linked_obj = LinkedObject( + object_type=MockObject, + object_uid=UID(), + id=UID(), + tags=["tag1", "tag2"], + server_uid=UID(), + service_type=RequestService, + ) + base_stash.set(root_verify_key, mock_object).unwrap() + + result = base_stash.get_one( + root_verify_key, + filters={"linked_obj.id": mock_object.linked_obj.id}, + ).unwrap() + + assert result == mock_object def test_basestash_query_all( @@ -339,46 +347,30 @@ def test_basestash_query_all( for obj in all_objects: base_stash.set(root_verify_key, obj) - for result in [ - base_stash.query_all_kwargs(root_verify_key, desc=desc), - base_stash.query_all( - root_verify_key, QueryKey.from_obj(DescPartitionKey, desc) - ), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == n_same - assert all(obj.desc == desc for obj in objects) - original_object_values = {get_object_values(obj) for obj in similar_objects} - retrived_objects_values = {get_object_values(obj) for obj in objects} - assert original_object_values == retrived_objects_values + objects = base_stash.get_all(root_verify_key, filters={"desc": desc}).unwrap() + assert len(objects) == n_same + assert all(obj.desc == desc for obj in objects) + original_object_values = {get_object_values(obj) for obj in similar_objects} + retrived_objects_values = {get_object_values(obj) for obj in objects} + assert original_object_values == retrived_objects_values random_desc = create_unique( random_sentence, [obj.desc for obj in all_objects], faker ) - for result in [ - base_stash.query_all_kwargs(root_verify_key, desc=random_desc), - base_stash.query_all( - root_verify_key, QueryKey.from_obj(DescPartitionKey, random_desc) - ), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 0 + + objects = base_stash.get_all( + root_verify_key, filters={"desc": random_desc} + ).unwrap() + assert len(objects) == 0 obj = random.choice(similar_objects) params = {"name": obj.name, "desc": obj.desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == sum( - 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc) - ) - assert objects[0] == obj + objects = base_stash.get_all(root_verify_key, filters=params).unwrap() + assert len(objects) == sum( + 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc) + ) + assert objects[0] == obj def test_basestash_query_all_kwargs_multiple_params( @@ -397,66 +389,23 @@ def test_basestash_query_all_kwargs_multiple_params( base_stash.set(root_verify_key, obj) params = {"importance": importance, "desc": desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == n_same - assert all(obj.desc == desc for obj in objects) - original_object_values = {get_object_values(obj) for obj in similar_objects} - retrived_objects_values = {get_object_values(obj) for obj in objects} - assert original_object_values == retrived_objects_values + objects = base_stash.get_all(root_verify_key, filters=params).unwrap() + assert len(objects) == n_same + assert all(obj.desc == desc for obj in objects) + original_object_values = {get_object_values(obj) for obj in similar_objects} + retrived_objects_values = {get_object_values(obj) for obj in objects} + assert original_object_values == retrived_objects_values params = { "name": create_unique(faker.name, [obj.name for obj in all_objects]), "desc": random_sentence(faker), } - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 0 + objects = base_stash.get_all(root_verify_key, filters=params).unwrap() + assert len(objects) == 0 obj = random.choice(similar_objects) params = {"id": obj.id, "name": obj.name, "desc": obj.desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 1 - assert objects[0] == obj - - -def test_basestash_cannot_query_non_searchable( - root_verify_key, base_stash: MockStash, mock_objects: list[MockObject] -) -> None: - for obj in mock_objects: - base_stash.set(root_verify_key, obj) - - obj = random.choice(mock_objects) - - assert base_stash.query_one_kwargs(root_verify_key, value=10).is_err() - assert base_stash.query_all_kwargs(root_verify_key, value=10).is_err() - assert base_stash.query_one_kwargs( - root_verify_key, value=10, name=obj.name - ).is_err() - assert base_stash.query_all_kwargs( - root_verify_key, value=10, name=obj.name - ).is_err() - - ValuePartitionKey = PartitionKey(key="value", type_=int) - qk = ValuePartitionKey.with_obj(10) - - assert base_stash.query_one(root_verify_key, qk).is_err() - assert base_stash.query_all(root_verify_key, qk).is_err() - assert base_stash.query_all(root_verify_key, QueryKeys(qks=[qk])).is_err() - assert base_stash.query_all( - root_verify_key, QueryKeys(qks=[qk, UIDPartitionKey.with_obj(obj.id)]) - ).is_err() + objects = base_stash.get_all(root_verify_key, filters=params).unwrap() + assert len(objects) == 1 + assert objects[0] == obj diff --git a/packages/syft/tests/syft/stores/dict_document_store_test.py b/packages/syft/tests/syft/stores/dict_document_store_test.py deleted file mode 100644 index e04414d666c..00000000000 --- a/packages/syft/tests/syft/stores/dict_document_store_test.py +++ /dev/null @@ -1,358 +0,0 @@ -# stdlib -from threading import Thread - -# syft absolute -from syft.store.dict_document_store import DictStorePartition -from syft.store.document_store import QueryKeys -from syft.types.uid import UID - -# relative -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - - -def test_dict_store_partition_sanity(dict_store_partition: DictStorePartition) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - assert hasattr(dict_store_partition, "data") - assert hasattr(dict_store_partition, "unique_keys") - assert hasattr(dict_store_partition, "searchable_keys") - - -def test_dict_store_partition_set( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - obj = MockSyftObject(id=UID(), data=1) - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = dict_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -def test_dict_store_partition_delete( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - objs = [] - repeats = 5 - for v in range(repeats): - obj = MockSyftObject(data=v) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = dict_store_partition.settings.store_key.with_obj(obj) - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = dict_store_partition.settings.store_key.with_obj(v) - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -def test_dict_store_partition_update( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - dict_store_partition.init_store() - - # add item - obj = MockSyftObject(data=1) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert len(dict_store_partition.all(root_verify_key).ok()) == 1 - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = dict_store_partition.settings.store_key.with_obj(rand_obj) - res = dict_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = dict_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = dict_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, unly the values are updated. - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = dict_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -def test_dict_store_partition_set_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - thread_cnt = 3 - repeats = 5 - - dict_store_partition.init_store() - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - obj = MockObjectType(data=idx) - - for _ in range(10): - res = dict_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - stored_cnt = len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == repeats * thread_cnt - - -def test_dict_store_partition_update_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - thread_cnt = 3 - repeats = 5 - dict_store_partition.init_store() - - obj = MockSyftObject(data=0) - key = dict_store_partition.settings.store_key.with_obj(obj) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = dict_store_partition.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -def test_dict_store_partition_set_delete_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - dict_store_partition.init_store() - - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = dict_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = dict_store_partition.settings.store_key.with_obj(obj) - - res = dict_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - stored_cnt = len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py deleted file mode 100644 index 95df806c189..00000000000 --- a/packages/syft/tests/syft/stores/mongo_document_store_test.py +++ /dev/null @@ -1,1045 +0,0 @@ -# stdlib -from secrets import token_hex -from threading import Thread - -# third party -import pytest - -# syft absolute -from syft.server.credentials import SyftVerifyKey -from syft.service.action.action_permissions import ActionObjectPermission -from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_permissions import StoragePermission -from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectOWNER -from syft.service.action.action_store import ActionObjectREAD -from syft.service.action.action_store import ActionObjectWRITE -from syft.store.document_store import PartitionSettings -from syft.store.document_store import QueryKey -from syft.store.document_store import QueryKeys -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoStoreConfig -from syft.store.mongo_document_store import MongoStorePartition -from syft.types.errors import SyftException -from syft.types.uid import UID - -# relative -from ...mongomock.collection import Collection as MongoCollection -from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER -from .store_fixtures_test import mongo_store_partition_fn -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - -PERMISSIONS = [ - ActionObjectOWNER, - ActionObjectREAD, - ActionObjectWRITE, - ActionObjectEXECUTE, -] - - -def test_mongo_store_partition_sanity( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - assert hasattr(mongo_store_partition, "_collection") - assert hasattr(mongo_store_partition, "_permissions") - - -@pytest.mark.skip(reason="Test gets stuck at store.init_store()") -def test_mongo_store_partition_init_failed(root_verify_key) -> None: - # won't connect - mongo_config = MongoStoreClientConfig( - connectTimeoutMS=1, - timeoutMS=1, - ) - - store_config = MongoStoreConfig(client_config=mongo_config) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - store = MongoStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - res = store.init_store() - assert res.is_err() - - -def test_mongo_store_partition_set( - root_verify_key, mongo_store_partition: MongoStorePartition -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - obj = MockSyftObject(data=1) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = mongo_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -def test_mongo_store_partition_delete( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - repeats = 5 - - objs = [] - for v in range(repeats): - obj = MockSyftObject(data=v) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(obj) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = mongo_store_partition.settings.store_key.with_obj(v) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -def test_mongo_store_partition_update( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - mongo_store_partition.init_store() - - # add item - obj = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(rand_obj) - res = mongo_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = mongo_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = mongo_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, only the values are updated. - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = mongo_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -def test_mongo_store_partition_set_threading(root_verify_key, mongo_client) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for idx in range(repeats): - obj = MockObjectType(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - return execution_err - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == thread_cnt * repeats - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." -# ) -# def test_mongo_store_partition_set_joblib( -# root_verify_key, -# mongo_client, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# for idx in range(repeats): -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockObjectType(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == thread_cnt * repeats - - -def test_mongo_store_partition_update_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - - mongo_db_name = token_hex(8) - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - obj = MockSyftObject(data=0) - key = mongo_store_partition.settings.store_key.with_obj(obj) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition_local = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = mongo_store_partition_local.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." -# ) -# def test_mongo_store_partition_update_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 - -# mongo_db_name = token_hex(8) - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockSyftObject(data=0) -# key = mongo_store_partition.settings.store_key.with_obj(obj) -# mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition_local = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# for repeat in range(repeats): -# obj = MockSyftObject(data=repeat) - -# for _ in range(10): -# res = mongo_store_partition_local.update(root_verify_key, key, obj) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - - -def test_mongo_store_partition_set_delete_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = mongo_store_partition.settings.store_key.with_obj(obj) - - res = mongo_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." -# ) -# def test_mongo_store_partition_set_delete_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, root_verify_key, mongo_db_name=mongo_db_name -# ) - -# for idx in range(repeats): -# obj = MockSyftObject(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# key = mongo_store_partition.settings.store_key.with_obj(obj) - -# res = mongo_store_partition.delete(root_verify_key, key) -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == 0 - - -def test_mongo_store_partition_permissions_collection( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - collection_permissions_status = mongo_store_partition.permissions - assert not collection_permissions_status.is_err() - collection_permissions = collection_permissions_status.ok() - assert isinstance(collection_permissions, MongoCollection) - - -def test_mongo_store_partition_add_remove_permission( - root_verify_key: SyftVerifyKey, mongo_store_partition: MongoStorePartition -) -> None: - """ - Test the add_permission and remove_permission functions of MongoStorePartition - """ - # setting up - res = mongo_store_partition.init_store() - assert res.is_ok() - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add the first permission - obj_read_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_read_permission) - find_res_1 = permissions_collection.find_one({"_id": obj_read_permission.uid}) - assert find_res_1 is not None - assert len(find_res_1["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # add the second permission - obj_write_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_write_permission) - - find_res_2 = permissions_collection.find_one({"_id": obj.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - assert find_res_2["permissions"] == { - obj_read_permission.permission_string, - obj_write_permission.permission_string, - } - - # add duplicated permission - mongo_store_partition.add_permission(obj_write_permission) - find_res_3 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_3["permissions"]) == 2 - assert find_res_3["permissions"] == find_res_2["permissions"] - - # remove the write permission - mongo_store_partition.remove_permission(obj_write_permission) - find_res_4 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_4["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # remove a non-existent permission - with pytest.raises(SyftException): - mongo_store_partition.remove_permission( - ActionObjectPermission( - uid=obj.id, - permission=ActionPermission.OWNER, - credentials=root_verify_key, - ) - ) - find_res_5 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_5["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # there is only one permission object - assert permissions_collection.count_documents({}) == 1 - - # add permissions in a loop - new_permissions = [] - repeats = 5 - for idx in range(1, repeats + 1): - new_obj = MockSyftObject(data=idx) - new_obj_read_permission = ActionObjectPermission( - uid=new_obj.id, - permission=ActionPermission.READ, - credentials=root_verify_key, - ) - new_permissions.append(new_obj_read_permission) - mongo_store_partition.add_permission(new_obj_read_permission) - assert permissions_collection.count_documents({}) == 1 + idx - - # remove all the permissions added in the loop - for permission in new_permissions: - mongo_store_partition.remove_permission(permission) - - assert permissions_collection.count_documents({}) == 1 - - -def test_mongo_store_partition_add_remove_storage_permission( - root_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the add_storage_permission and remove_storage_permission functions of MongoStorePartition - """ - - obj = MockSyftObject(data=1) - - storage_permission = StoragePermission( - uid=obj.id, - server_uid=UID(), - ) - assert not mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.add_storage_permission(storage_permission) - assert mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.remove_storage_permission(storage_permission) - assert not mongo_store_partition.has_storage_permission(storage_permission) - - obj2 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj2, add_storage_permission=False) - storage_permission3 = StoragePermission( - uid=obj2.id, server_uid=mongo_store_partition.server_uid - ) - assert not mongo_store_partition.has_storage_permission(storage_permission3) - - obj3 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj3, add_storage_permission=True) - storage_permission4 = StoragePermission( - uid=obj3.id, server_uid=mongo_store_partition.server_uid - ) - assert mongo_store_partition.has_storage_permission(storage_permission4) - - -def test_mongo_store_partition_add_permissions( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add multiple permissions for the first object - permission_1 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - permission_2 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.OWNER, credentials=root_verify_key - ) - permission_3 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=guest_verify_key - ) - permissions: list[ActionObjectPermission] = [ - permission_1, - permission_2, - permission_3, - ] - mongo_store_partition.add_permissions(permissions) - - # check if the permissions have been added properly - assert permissions_collection.count_documents({}) == 1 - find_res = permissions_collection.find_one({"_id": obj.id}) - assert find_res is not None - assert len(find_res["permissions"]) == 3 - - # add permissions for the second object - obj_2 = MockSyftObject(data=2) - permission_4 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - permission_5 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permissions([permission_4, permission_5]) - - assert permissions_collection.count_documents({}) == 2 - find_res_2 = permissions_collection.find_one({"_id": obj_2.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - - -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_has_permission( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - res = mongo_store_partition.init_store() - assert res.is_ok() - - # root permission - obj = MockSyftObject(data=1) - permission_root = permission(uid=obj.id, credentials=root_verify_key) - permission_client = permission(uid=obj.id, credentials=guest_verify_key) - permission_hacker = permission(uid=obj.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_root) - # only the root user has access to this permission - assert mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - # client permission for another object - obj_2 = MockSyftObject(data=2) - permission_client_2 = permission(uid=obj_2.id, credentials=guest_verify_key) - permission_root_2 = permission(uid=obj_2.id, credentials=root_verify_key) - permisson_hacker_2 = permission(uid=obj_2.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_client_2) - # the root (admin) and guest client should have this permission - assert mongo_store_partition.has_permission(permission_root_2) - assert mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - # remove permissions - mongo_store_partition.remove_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - mongo_store_partition.remove_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permission_root_2) - assert not mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_take_ownership( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - obj = MockSyftObject(data=1) - - # the guest client takes ownership of obj - mongo_store_partition.take_ownership( - uid=obj.id, credentials=guest_verify_key - ).unwrap() - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=guest_verify_key) - ) - # the root client will also has the permission - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - - # hacker or root try to take ownership of the obj and will fail - res = mongo_store_partition.take_ownership( - uid=obj.id, credentials=hacker_verify_key - ) - res_2 = mongo_store_partition.take_ownership( - uid=obj.id, credentials=root_verify_key - ) - assert res.is_err() - assert res_2.is_err() - assert ( - res.value.public_message - == res_2.value.public_message - == f"UID: {obj.id} already owned." - ) - - # another object - obj_2 = MockSyftObject(data=2) - # root client takes ownership - mongo_store_partition.take_ownership(uid=obj_2.id, credentials=root_verify_key) - assert mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=guest_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=hacker_verify_key) - ) - - -def test_mongo_store_partition_permissions_set( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the permissions functionalities when using MongoStorePartition._set function - """ - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - res = mongo_store_partition.init_store() - assert res.is_ok() - - # set the object to mongo_store_partition.collection - obj = MockSyftObject(data=1) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj - - # check if the corresponding permissions has been added to the permissions - # collection after the root client claim it - pemissions_collection = mongo_store_partition.permissions.ok() - assert isinstance(pemissions_collection, MongoCollection) - permissions = pemissions_collection.find_one({"_id": obj.id}) - assert permissions is not None - assert isinstance(permissions["permissions"], set) - assert len(permissions["permissions"]) == 4 - for permission in PERMISSIONS: - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - # the hacker tries to set duplicated object but should not be able to claim it - res_2 = mongo_store_partition.set(guest_verify_key, obj, ignore_duplicates=True) - assert res_2.is_ok() - for permission in PERMISSIONS: - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - -def test_mongo_store_partition_permissions_get_all( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - # set several objects for the root and guest client - num_root_objects: int = 5 - num_guest_objects: int = 3 - for i in range(num_root_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - for i in range(num_guest_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj, ignore_duplicates=False - ) - - assert ( - len(mongo_store_partition.all(root_verify_key).ok()) - == num_root_objects + num_guest_objects - ) - assert len(mongo_store_partition.all(guest_verify_key).ok()) == num_guest_objects - assert len(mongo_store_partition.all(hacker_verify_key).ok()) == 0 - - -def test_mongo_store_partition_permissions_delete( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - collection: MongoCollection = mongo_store_partition.collection.ok() - pemissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - # guest or hacker can't delete it - assert not mongo_store_partition.delete(guest_verify_key, qk).is_ok() - assert not mongo_store_partition.delete(hacker_verify_key, qk).is_ok() - # only the root client can delete it - assert mongo_store_partition.delete(root_verify_key, qk).is_ok() - # check if the object and its permission have been deleted - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set an object - obj_2 = MockSyftObject(data=2) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_2, ignore_duplicates=False - ) - qk_2: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_2) - # the hacker can't delete it - assert not mongo_store_partition.delete(hacker_verify_key, qk_2).is_ok() - # the guest client can delete it - assert mongo_store_partition.delete(guest_verify_key, qk_2).is_ok() - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set another object - obj_3 = MockSyftObject(data=3) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_3, ignore_duplicates=False - ) - qk_3: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_3) - # the root client also has the permission to delete it - assert mongo_store_partition.delete(root_verify_key, qk_3).is_ok() - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - -def test_mongo_store_partition_permissions_update( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - assert len(mongo_store_partition.all(credentials=root_verify_key).ok()) == 1 - - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - permsissions: MongoCollection = mongo_store_partition.permissions.ok() - repeats = 5 - - for v in range(repeats): - # the guest client should not have permission to update obj - obj_new = MockSyftObject(data=v) - res = mongo_store_partition.update( - credentials=guest_verify_key, qk=qk, obj=obj_new - ) - assert res.is_err() - # the root client has the permission to update obj - res = mongo_store_partition.update( - credentials=root_verify_key, qk=qk, obj=obj_new - ) - assert res.is_ok() - # the id of the object in the permission collection should not be changed - assert permsissions.find_one(qk.as_dict_mongo)["_id"] == obj.id diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py index 312766e7c4e..d4e2cd25747 100644 --- a/packages/syft/tests/syft/stores/queue_stash_test.py +++ b/packages/syft/tests/syft/stores/queue_stash_test.py @@ -1,26 +1,20 @@ # stdlib -import threading -from threading import Thread -import time -from typing import Any +from concurrent.futures import ThreadPoolExecutor # third party import pytest # syft absolute from syft.service.queue.queue_stash import QueueItem +from syft.service.queue.queue_stash import QueueStash from syft.service.worker.worker_pool import WorkerPool from syft.service.worker.worker_pool_service import SyftWorkerPoolService from syft.store.linked_obj import LinkedObject from syft.types.errors import SyftException from syft.types.uid import UID -# relative -from .store_fixtures_test import mongo_queue_stash_fn -from .store_fixtures_test import sqlite_queue_stash_fn - -def mock_queue_object(): +def mock_queue_object() -> QueueItem: worker_pool_obj = WorkerPool( name="mypool", image_id=UID(), @@ -47,405 +41,246 @@ def mock_queue_object(): @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -def test_queue_stash_sanity(queue: Any) -> None: +def test_queue_stash_sanity(queue: QueueStash) -> None: assert len(queue) == 0 - assert hasattr(queue, "store") - assert hasattr(queue, "partition") @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -# @pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: - objs = [] +# +def test_queue_stash_set_get(root_verify_key, queue: QueueStash) -> None: + objs: list[QueueItem] = [] repeats = 5 for idx in range(repeats): obj = mock_queue_object() objs.append(obj) - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() assert len(queue) == idx + 1 with pytest.raises(SyftException): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() assert len(queue) == idx + 1 assert len(queue.get_all(root_verify_key).ok()) == idx + 1 - item = queue.find_one(root_verify_key, id=obj.id) - assert item.is_ok() - assert item.ok() == obj + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item == obj cnt = len(objs) for obj in objs: - res = queue.find_and_delete(root_verify_key, id=obj.id) - assert res.is_ok() - + queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap() cnt -= 1 assert len(queue) == cnt - item = queue.find_one(root_verify_key, id=obj.id) + item = queue.get_by_uid(root_verify_key, uid=obj.id) assert item.is_err() @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_stash_update(root_verify_key, queue: Any) -> None: +def test_queue_stash_update(queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key obj = mock_queue_object() - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() repeats = 5 for idx in range(repeats): obj.args = [idx] - res = queue.update(root_verify_key, obj) - assert res.is_ok() + queue.update(root_verify_key, obj).unwrap() assert len(queue) == 1 - item = queue.find_one(root_verify_key, id=obj.id) - assert item.is_ok() - assert item.ok().args == [idx] + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item.args == [idx] - res = queue.find_and_delete(root_verify_key, id=obj.id) - assert res.is_ok() + queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap() assert len(queue) == 0 @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_set_existing_queue_threading(root_verify_key, queue: Any) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for _ in range(repeats): - obj = mock_queue_object() - - for _ in range(10): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - assert len(queue) == thread_cnt * repeats +def test_queue_set_existing_queue_threading(root_verify_key, queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key + items_to_create = 100 + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + lambda obj: queue.set( + root_verify_key, + mock_queue_object(), + ), + range(items_to_create), + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" + assert len(queue) == items_to_create @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_update_existing_queue_threading(root_verify_key, queue: Any) -> None: - thread_cnt = 3 - repeats = 5 - +def test_queue_update_existing_queue_threading(queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key obj = mock_queue_object() - queue.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for repeat in range(repeats): - obj.args = [repeat] - for _ in range(10): - res = queue.update(root_verify_key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() + def update_queue(): + obj.args = [UID()] + res = queue.update(root_verify_key, obj) + return res - tids.append(thread) + queue.set(root_verify_key, obj, ignore_duplicates=False) - for thread in tids: - thread.join() + with ThreadPoolExecutor(max_workers=3) as executor: + # Run the update_queue function in multiple threads + results = list( + executor.map( + lambda _: update_queue(), + range(5), + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" - assert execution_err is None + assert len(queue) == 1 + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item.args != [] @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_set_delete_existing_queue_threading( - root_verify_key, - queue: Any, + queue: QueueStash, ) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - objs = [] - - for _ in range(repeats * thread_cnt): - obj = mock_queue_object() - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert res.is_ok() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - item_idx = tid * repeats + idx - - for _ in range(10): - res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - assert len(queue) == 0 - - -def helper_queue_set_threading(root_verify_key, create_queue_cbk) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - lock = threading.Lock() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - with lock: - queue = create_queue_cbk() - - for _ in range(repeats): - obj = mock_queue_object() - - for _ in range(10): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - queue = create_queue_cbk() - - assert execution_err is None - assert len(queue) == thread_cnt * repeats - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_set_sqlite(root_verify_key, sqlite_workspace): - def create_queue_cbk(): - return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) - - helper_queue_set_threading(root_verify_key, create_queue_cbk) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_set_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_set_threading(root_verify_key, create_queue_cbk) - - -def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None: - thread_cnt = 3 - repeats = 5 - - queue = create_queue_cbk() - time.sleep(1) - + root_verify_key = queue.db.root_verify_key + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + lambda obj: queue.set( + root_verify_key, + mock_queue_object(), + ), + range(15), + ) + ) + objs = [item.unwrap() for item in results] + + results = list( + executor.map( + lambda obj: queue.delete_by_uid(root_verify_key, uid=obj.id), + objs, + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" + + +def test_queue_set(queue_stash: QueueStash): + root_verify_key = queue_stash.db.root_verify_key + config = queue_stash.db.config + server_uid = queue_stash.db.server_uid + + def set_in_new_thread(_): + queue_stash = QueueStash.random( + root_verify_key=root_verify_key, + config=config, + server_uid=server_uid, + ) + return queue_stash.set(root_verify_key, mock_queue_object()) + + total_repeats = 50 + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + set_in_new_thread, + range(total_repeats), + ) + ) + + assert all(res.is_ok() for res in results), "Error occurred during execution" + assert len(queue_stash) == total_repeats + + +def test_queue_update_threading(queue_stash: QueueStash): + root_verify_key = queue_stash.db.root_verify_key + config = queue_stash.db.config + server_uid = queue_stash.db.server_uid obj = mock_queue_object() - queue.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - lock = threading.Lock() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - with lock: - queue_local = create_queue_cbk() - - for repeat in range(repeats): - obj.args = [repeat] - - for _ in range(10): - res = queue_local.update(root_verify_key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace): - def create_queue_cbk(): - return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) - - helper_queue_update_threading(root_verify_key, create_queue_cbk) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_update_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_update_threading(root_verify_key, create_queue_cbk) - - -def helper_queue_set_delete_threading( - root_verify_key, - create_queue_cbk, -) -> None: - thread_cnt = 3 - repeats = 5 - - queue = create_queue_cbk() - execution_err = None - objs = [] - - for _ in range(repeats * thread_cnt): - obj = mock_queue_object() - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert res.is_ok() - - lock = threading.Lock() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - with lock: - queue = create_queue_cbk() - for idx in range(repeats): - item_idx = tid * repeats + idx - - for _ in range(10): - res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - assert len(queue) == 0 - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace): - def create_queue_cbk(): - return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) - - helper_queue_set_delete_threading(root_verify_key, create_queue_cbk) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_set_delete_threading(root_verify_key, create_queue_cbk) + queue_stash.set(root_verify_key, obj).unwrap() + + def update_in_new_thread(_): + queue_stash = QueueStash.random( + root_verify_key=root_verify_key, + config=config, + server_uid=server_uid, + ) + obj.args = [UID()] + return queue_stash.update(root_verify_key, obj) + + total_repeats = 50 + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + update_in_new_thread, + range(total_repeats), + ) + ) + + assert all(res.is_ok() for res in results), "Error occurred during execution" + assert len(queue_stash) == 1 + + +def test_queue_delete_threading(queue_stash: QueueStash): + root_verify_key = queue_stash.db.root_verify_key + root_verify_key = queue_stash.db.root_verify_key + config = queue_stash.db.config + server_uid = queue_stash.db.server_uid + + def delete_in_new_thread(obj: QueueItem): + queue_stash = QueueStash.random( + root_verify_key=root_verify_key, + config=config, + server_uid=server_uid, + ) + return queue_stash.delete_by_uid(root_verify_key, uid=obj.id) + + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + lambda obj: queue_stash.set( + root_verify_key, + mock_queue_object(), + ), + range(50), + ) + ) + objs = [item.unwrap() for item in results] + + results = list( + executor.map( + delete_in_new_thread, + objs, + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" + + assert len(queue_stash) == 0 diff --git a/packages/syft/tests/syft/stores/sqlite_document_store_test.py b/packages/syft/tests/syft/stores/sqlite_document_store_test.py deleted file mode 100644 index 46ee540aa9c..00000000000 --- a/packages/syft/tests/syft/stores/sqlite_document_store_test.py +++ /dev/null @@ -1,520 +0,0 @@ -# stdlib -from threading import Thread - -# third party -import pytest - -# syft absolute -from syft.store.document_store import QueryKeys -from syft.store.sqlite_document_store import SQLiteStorePartition - -# relative -from .store_fixtures_test import sqlite_store_partition_fn -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - - -def test_sqlite_store_partition_sanity( - sqlite_store_partition: SQLiteStorePartition, -) -> None: - assert hasattr(sqlite_store_partition, "data") - assert hasattr(sqlite_store_partition, "unique_keys") - assert hasattr(sqlite_store_partition, "searchable_keys") - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - obj = MockSyftObject(data=1) - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = sqlite_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_delete( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - objs = [] - repeats = 5 - for v in range(repeats): - obj = MockSyftObject(data=v) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = sqlite_store_partition.settings.store_key.with_obj(obj) - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = sqlite_store_partition.settings.store_key.with_obj(v) - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_update( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - # add item - obj = MockSyftObject(data=1) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = sqlite_store_partition.settings.store_key.with_obj(rand_obj) - res = sqlite_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = sqlite_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = sqlite_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, unly the values are updated. - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = sqlite_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set_threading( - sqlite_workspace: tuple, - root_verify_key, -) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - for idx in range(repeats): - for _ in range(10): - obj = MockObjectType(data=idx) - res = sqlite_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - return execution_err - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - stored_cnt = len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == thread_cnt * repeats - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_set_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# def _kv_cbk(tid: int) -> None: -# for idx in range(repeats): -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# obj = MockObjectType(data=idx) - -# for _ in range(10): -# res = sqlite_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# stored_cnt = len( -# sqlite_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == thread_cnt * repeats - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_update_threading( - root_verify_key, - sqlite_workspace: tuple, -) -> None: - thread_cnt = 3 - repeats = 5 - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - obj = MockSyftObject(data=0) - key = sqlite_store_partition.settings.store_key.with_obj(obj) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - sqlite_store_partition_local = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = sqlite_store_partition_local.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_update_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# obj = MockSyftObject(data=0) -# key = sqlite_store_partition.settings.store_key.with_obj(obj) -# sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - -# def _kv_cbk(tid: int) -> None: -# sqlite_store_partition_local = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# for repeat in range(repeats): -# obj = MockSyftObject(data=repeat) - -# for _ in range(10): -# res = sqlite_store_partition_local.update(root_verify_key, key, obj) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set_delete_threading( - root_verify_key, - sqlite_workspace: tuple, -) -> None: - thread_cnt = 3 - repeats = 5 - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = sqlite_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = sqlite_store_partition.settings.store_key.with_obj(obj) - - res = sqlite_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - stored_cnt = len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_set_delete_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# def _kv_cbk(tid: int) -> None: -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) - -# for idx in range(repeats): -# obj = MockSyftObject(data=idx) - -# for _ in range(10): -# res = sqlite_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# key = sqlite_store_partition.settings.store_key.with_obj(obj) - -# res = sqlite_store_partition.delete(root_verify_key, key) -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) -# for execution_err in errs: -# assert execution_err is None - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# stored_cnt = len( -# sqlite_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == 0 diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index b64d14be8e3..47452e9740e 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -1,72 +1,30 @@ # stdlib -from collections.abc import Generator -import os -from pathlib import Path -from secrets import token_hex -import tempfile import uuid -# third party -import pytest - # syft absolute from syft.server.credentials import SyftVerifyKey from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_store import DictActionStore -from syft.service.action.action_store import MongoActionStore -from syft.service.action.action_store import SQLiteActionStore -from syft.service.queue.queue_stash import QueueStash from syft.service.user.user import User from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.dict_document_store import DictDocumentStore -from syft.store.dict_document_store import DictStoreConfig -from syft.store.dict_document_store import DictStorePartition +from syft.store.db.sqlite import SQLiteDBConfig +from syft.store.db.sqlite import SQLiteDBManager from syft.store.document_store import DocumentStore -from syft.store.document_store import PartitionSettings -from syft.store.locks import LockingConfig -from syft.store.locks import NoLockingConfig -from syft.store.locks import ThreadingLockingConfig -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoDocumentStore -from syft.store.mongo_document_store import MongoStoreConfig -from syft.store.mongo_document_store import MongoStorePartition -from syft.store.sqlite_document_store import SQLiteDocumentStore -from syft.store.sqlite_document_store import SQLiteStoreClientConfig -from syft.store.sqlite_document_store import SQLiteStoreConfig -from syft.store.sqlite_document_store import SQLiteStorePartition from syft.types.uid import UID # relative from .store_constants_test import TEST_SIGNING_KEY_NEW_ADMIN from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN -from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT -from .store_mocks_test import MockObjectType - -MONGO_CLIENT_CACHE = None - -locking_scenarios = [ - "nop", - "threading", -] - - -def str_to_locking_config(conf: str) -> LockingConfig: - if conf == "nop": - return NoLockingConfig() - elif conf == "threading": - return ThreadingLockingConfig() - else: - raise NotImplementedError(f"unknown locking config {conf}") def document_store_with_admin( server_uid: UID, verify_key: SyftVerifyKey ) -> DocumentStore: - document_store = DictDocumentStore( - server_uid=server_uid, root_verify_key=verify_key + config = SQLiteDBConfig() + document_store = SQLiteDBManager( + server_uid=server_uid, root_verify_key=verify_key, config=config ) password = uuid.uuid4().hex @@ -94,307 +52,3 @@ def document_store_with_admin( ) return document_store - - -@pytest.fixture(scope="function") -def sqlite_workspace() -> Generator: - sqlite_db_name = token_hex(8) + ".sqlite" - root = os.getenv("SYFT_TEMP_ROOT", "syft") - sqlite_workspace_folder = Path( - tempfile.gettempdir(), root, "fixture_sqlite_workspace" - ) - sqlite_workspace_folder.mkdir(parents=True, exist_ok=True) - - db_path = sqlite_workspace_folder / sqlite_db_name - - if db_path.exists(): - db_path.unlink() - - yield sqlite_workspace_folder, sqlite_db_name - - try: - db_path.exists() and db_path.unlink() - except BaseException as e: - print("failed to cleanup sqlite db", e) - - -def sqlite_store_partition_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "nop", -): - workspace, db_name = sqlite_workspace - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, locking_config=locking_config - ) - - settings = PartitionSettings(name="test", object_type=MockObjectType) - - store = SQLiteStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - store.init_store().unwrap() - - return store - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_store_partition( - root_verify_key, sqlite_workspace: tuple[Path, str], request -): - locking_config_name = request.param - store = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace, locking_config_name=locking_config_name - ) - - yield store - - -def sqlite_document_store_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "nop", -): - workspace, db_name = sqlite_workspace - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, locking_config=locking_config - ) - - return SQLiteDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_document_store(root_verify_key, sqlite_workspace: tuple[Path, str], request): - locking_config_name = request.param - store = sqlite_document_store_fn( - root_verify_key, sqlite_workspace, locking_config_name=locking_config_name - ) - yield store - - -def sqlite_queue_stash_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "threading", -): - store = sqlite_document_store_fn( - root_verify_key, - sqlite_workspace, - locking_config_name=locking_config_name, - ) - return QueueStash(store=store) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_queue_stash(root_verify_key, sqlite_workspace: tuple[Path, str], request): - locking_config_name = request.param - yield sqlite_queue_stash_fn( - root_verify_key, sqlite_workspace, locking_config_name=locking_config_name - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_action_store(sqlite_workspace: tuple[Path, str], request): - workspace, db_name = sqlite_workspace - locking_config_name = request.param - - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, - locking_config=locking_config, - ) - - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - - yield SQLiteActionStore( - server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - -def mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name: str = "mongo_db", - locking_config_name: str = "nop", -): - mongo_config = MongoStoreClientConfig(client=mongo_client) - - locking_config = str_to_locking_config(locking_config_name) - - store_config = MongoStoreConfig( - client_config=mongo_config, - db_name=mongo_db_name, - locking_config=locking_config, - ) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - return MongoStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_store_partition(root_verify_key, mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - - partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - yield partition - - # cleanup db - try: - mongo_client.drop_database(mongo_db_name) - except BaseException as e: - print("failed to cleanup mongo fixture", e) - - -def mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name: str = "mongo_db", - locking_config_name: str = "nop", -): - locking_config = str_to_locking_config(locking_config_name) - mongo_config = MongoStoreClientConfig(client=mongo_client) - store_config = MongoStoreConfig( - client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config - ) - - mongo_client.drop_database(mongo_db_name) - - return MongoDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_document_store(root_verify_key, mongo_client, request): - locking_config_name = request.param - mongo_db_name = token_hex(8) - yield mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - - -def mongo_queue_stash_fn(mongo_document_store): - return QueueStash(store=mongo_document_store) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_queue_stash(root_verify_key, mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - - store = mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - yield mongo_queue_stash_fn(store) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_action_store(mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - locking_config = str_to_locking_config(locking_config_name) - - mongo_config = MongoStoreClientConfig(client=mongo_client) - store_config = MongoStoreConfig( - client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config - ) - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - mongo_action_store = MongoActionStore( - server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - yield mongo_action_store - - -def dict_store_partition_fn( - root_verify_key, - locking_config_name: str = "nop", -): - locking_config = str_to_locking_config(locking_config_name) - store_config = DictStoreConfig(locking_config=locking_config) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - return DictStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_store_partition(root_verify_key, request): - locking_config_name = request.param - yield dict_store_partition_fn( - root_verify_key, locking_config_name=locking_config_name - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_action_store(request): - locking_config_name = request.param - locking_config = str_to_locking_config(locking_config_name) - - store_config = DictStoreConfig(locking_config=locking_config) - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - - yield DictActionStore( - server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - -def dict_document_store_fn(root_verify_key, locking_config_name: str = "nop"): - locking_config = str_to_locking_config(locking_config_name) - store_config = DictStoreConfig(locking_config=locking_config) - return DictDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_document_store(root_verify_key, request): - locking_config_name = request.param - yield dict_document_store_fn( - root_verify_key, locking_config_name=locking_config_name - ) - - -def dict_queue_stash_fn(dict_document_store): - return QueueStash(store=dict_document_store) - - -@pytest.fixture(scope="function") -def dict_queue_stash(dict_document_store): - yield dict_queue_stash_fn(dict_document_store) diff --git a/packages/syft/tests/syft/types/errors_test.py b/packages/syft/tests/syft/types/errors_test.py index 4ac185ba421..ca8e557ef11 100644 --- a/packages/syft/tests/syft/types/errors_test.py +++ b/packages/syft/tests/syft/types/errors_test.py @@ -5,6 +5,7 @@ import pytest # syft absolute +import syft from syft.service.context import AuthedServiceContext from syft.service.user.user_roles import ServiceRole from syft.types.errors import SyftException @@ -52,3 +53,24 @@ def test_get_message(role, private_msg, public_msg, expected_message): mock_context.dev_mode = False exception = SyftException(private_msg, public_message=public_msg) assert exception.get_message(mock_context) == expected_message + + +def test_syfterror_raise_works_in_pytest(): + """ + SyftError has own exception handler that wasnt working in notebook testing environments, + this is just a sanity check to make sure it works in pytest. + """ + with pytest.raises(SyftException): + raise SyftException(public_message="-") + + with syft.raises(SyftException(public_message="-")): + raise SyftException(public_message="-") + + # syft.raises works with wildcard + with syft.raises(SyftException(public_message="*test message*")): + raise SyftException(public_message="longer test message") + + # syft.raises with different public message should raise + with pytest.raises(AssertionError): + with syft.raises(SyftException(public_message="*different message*")): + raise SyftException(public_message="longer test message") diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 9410cecd695..5de444197ba 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -16,6 +16,7 @@ from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.service.user.user import User +from syft.service.user.user import UserView from syft.service.user.user_roles import ServiceRole from syft.types.errors import SyftException @@ -46,12 +47,10 @@ def test_repr_markdown_not_throwing_error(guest_client: DatasiteClient) -> None: assert result[0]._repr_markdown_() -@pytest.mark.parametrize("delete_original_admin", [False, True]) def test_new_admin_can_list_user_code( worker: Worker, ds_client: DatasiteClient, faker: Faker, - delete_original_admin: bool, ) -> None: root_client = worker.root_client @@ -66,17 +65,16 @@ def test_new_admin_can_list_user_code( admin = root_client.login(email=email, password=pw) - root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN) - - if delete_original_admin: - res = root_client.api.services.user.delete(root_client.account.id) - assert not isinstance(res, SyftError) + result: UserView = root_client.api.services.user.update( + uid=admin.account.id, role=ServiceRole.ADMIN + ) + assert result.role == ServiceRole.ADMIN user_code_stash = worker.services.user_code.stash - user_code = user_code_stash.get_all(user_code_stash.store.root_verify_key).ok() + user_codes = user_code_stash._data - assert len(user_code) == len(admin.code.get_all()) - assert {c.id for c in user_code} == {c.id for c in admin.code} + assert 1 == len(admin.code.get_all()) + assert {c.id for c in user_codes} == {c.id for c in admin.code} def test_user_code(worker) -> None: diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 45e31da18fe..59e905ee657 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -9,7 +9,9 @@ from pytest import MonkeyPatch # syft absolute +import syft as sy from syft import orchestra +from syft.client.client import SyftClient from syft.server.credentials import SyftVerifyKey from syft.server.worker import Worker from syft.service.context import AuthedServiceContext @@ -208,7 +210,7 @@ def test_userservice_get_all_success( expected_output = [x.to(UserView) for x in mock_get_all_output] @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> list[User]: + def mock_get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: return mock_get_all_output monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) @@ -222,23 +224,6 @@ def mock_get_all(credentials: SyftVerifyKey) -> list[User]: ) -def test_userservice_get_all_error( - monkeypatch: MonkeyPatch, - user_service: UserService, - authed_context: AuthedServiceContext, -) -> None: - @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> NoReturn: - raise StashException - - monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) - - with pytest.raises(StashException) as exc: - user_service.get_all(authed_context) - - assert exc.type == StashException - - def test_userservice_search( monkeypatch: MonkeyPatch, user_service: UserService, @@ -246,13 +231,13 @@ def test_userservice_search( guest_user: User, ) -> None: @as_result(SyftException) - def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: + def get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: for key in kwargs.keys(): if hasattr(guest_user, key): return [guest_user] return [] - monkeypatch.setattr(user_service.stash, "find_all", mock_find_all) + monkeypatch.setattr(user_service.stash, "get_all", get_all) expected_output = [guest_user.to(UserView)] @@ -541,27 +526,10 @@ def mock_get_by_email(credentials: SyftVerifyKey, email: str) -> NoReturn: assert exc.value.public_message == expected_output -def test_userservice_admin_verify_key_error( - monkeypatch: MonkeyPatch, user_service: UserService -) -> None: - expected_output = "failed to get admin verify_key" - - def mock_admin_verify_key() -> UID: - raise SyftException(public_message=expected_output) - - monkeypatch.setattr(user_service.stash, "admin_verify_key", mock_admin_verify_key) - - with pytest.raises(SyftException) as exc: - user_service.admin_verify_key() - - assert exc.type == SyftException - assert exc.value.public_message == expected_output - - def test_userservice_admin_verify_key_success( monkeypatch: MonkeyPatch, user_service: UserService, worker ) -> None: - response = user_service.admin_verify_key() + response = user_service.root_verify_key assert isinstance(response, SyftVerifyKey) assert response == worker.root_client.credentials.verify_key @@ -586,7 +554,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> User: new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) with pytest.raises(SyftException) as exc: @@ -616,7 +584,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> NoReturn: new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) with pytest.raises(StashException) as exc: @@ -645,7 +613,7 @@ def mock_set(*args, **kwargs) -> User: new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email) @@ -684,7 +652,7 @@ def mock_set( new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email) @@ -790,3 +758,43 @@ def test_userservice_update_via_client_with_mixed_args(): email="new_user@openmined.org", password="newpassword" ) assert user_client.account.name == "User name" + + +def test_reset_password(): + server = orchestra.launch(name="datasite-test", reset=True) + + datasite_client = server.login(email="info@openmined.org", password="changethis") + datasite_client.register( + email="new_syft_user@openmined.org", + password="verysecurepassword", + password_verify="verysecurepassword", + name="New User", + ) + guest_client: SyftClient = server.login_as_guest() + guest_client.forgot_password(email="new_syft_user@openmined.org") + temp_token = datasite_client.users.request_password_reset( + datasite_client.notifications[-1].linked_obj.resolve.id + ) + guest_client.reset_password(token=temp_token, new_password="Password123") + server.login(email="new_syft_user@openmined.org", password="Password123") + + +def test_root_cannot_be_deleted(): + server = orchestra.launch(name="datasite-test", reset=True) + datasite_client = server.login(email="info@openmined.org", password="changethis") + + new_admin_email = "admin@openmined.org" + new_admin_pass = "changethis2" + datasite_client.register( + name="second admin", + email=new_admin_email, + password=new_admin_pass, + password_verify=new_admin_pass, + ) + # update role + new_user_id = datasite_client.users.search(email=new_admin_email)[0].id + datasite_client.users.update(uid=new_user_id, role="admin") + + new_admin_client = server.login(email=new_admin_email, password=new_admin_pass) + with sy.raises(sy.SyftException): + new_admin_client.users.delete(datasite_client.account.id) diff --git a/packages/syft/tests/syft/users/user_stash_test.py b/packages/syft/tests/syft/users/user_stash_test.py index ee8e4b1edc9..584e616d093 100644 --- a/packages/syft/tests/syft/users/user_stash_test.py +++ b/packages/syft/tests/syft/users/user_stash_test.py @@ -1,5 +1,6 @@ # third party from faker import Faker +import pytest # syft absolute from syft.server.credentials import SyftSigningKey @@ -14,7 +15,7 @@ def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> User: # prepare: add mock data - result = user_stash.partition.set(root_datasite_client.credentials.verify_key, user) + result = user_stash.set(root_datasite_client.credentials.verify_key, user) assert result.is_ok() user = result.ok() @@ -26,30 +27,29 @@ def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> Us def test_userstash_set( root_datasite_client, user_stash: UserStash, guest_user: User ) -> None: - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_ok() - - created_user = result.ok() + created_user = user_stash.set( + root_datasite_client.credentials.verify_key, guest_user + ).unwrap() assert isinstance(created_user, User) assert guest_user == created_user - assert guest_user.id in user_stash.partition.data + assert user_stash.exists( + root_datasite_client.credentials.verify_key, created_user.id + ) def test_userstash_set_duplicate( root_datasite_client, user_stash: UserStash, guest_user: User ) -> None: - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_ok() + _ = user_stash.set(root_datasite_client.credentials.verify_key, guest_user).unwrap() + original_count = len(user_stash._data) - original_count = len(user_stash.partition.data) + with pytest.raises(SyftException) as exc: + _ = user_stash.set( + root_datasite_client.credentials.verify_key, guest_user + ).unwrap() + assert exc.public_message - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_err() - exc = result.err() - assert type(exc) == SyftException - assert exc.public_message - - assert len(user_stash.partition.data) == original_count + assert len(user_stash._data) == original_count def test_userstash_get_by_uid( @@ -171,11 +171,9 @@ def test_userstash_get_by_role( # prepare: add mock data user = add_mock_user(root_datasite_client, user_stash, guest_user) - result = user_stash.get_by_role( + searched_user = user_stash.get_by_role( root_datasite_client.credentials.verify_key, role=ServiceRole.GUEST - ) - assert result.is_ok() - searched_user = result.ok() + ).unwrap() assert user == searched_user diff --git a/packages/syft/tests/syft/users/user_test.py b/packages/syft/tests/syft/users/user_test.py index 1a9830ed48c..4c727b4f1fe 100644 --- a/packages/syft/tests/syft/users/user_test.py +++ b/packages/syft/tests/syft/users/user_test.py @@ -1,5 +1,6 @@ # stdlib from secrets import token_hex +import time # third party from faker import Faker @@ -388,6 +389,9 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non admin_client = get_mock_client(worker.root_client, ServiceRole.ADMIN) assert admin_client.account.role == ServiceRole.ADMIN + # wait for the user to be created for sorting purposes + time.sleep(0.01) + admin_client.register( name="Sheldon Cooper", email="sheldon@caltech.edu", @@ -424,6 +428,7 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non with pytest.raises(SyftException): ds_client.account.update(role="guest") + with pytest.raises(SyftException): ds_client.account.update(role="data_scientist") # now we set sheldon's role to admin. Only now he can change his role diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index ce0f027a101..f52772038cf 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -16,16 +16,17 @@ from syft.server.credentials import SyftVerifyKey from syft.server.worker import Worker from syft.service.action.action_object import ActionObject -from syft.service.action.action_store import DictActionStore +from syft.service.action.action_store import ActionObjectStash from syft.service.context import AuthedServiceContext from syft.service.queue.queue_stash import QueueItem from syft.service.response import SyftError from syft.service.user.user import User from syft.service.user.user import UserCreate from syft.service.user.user import UserView +from syft.service.user.user_stash import UserStash +from syft.store.db.sqlite import SQLiteDBManager from syft.types.errors import SyftException from syft.types.result import Ok -from syft.types.uid import UID test_signing_key_string = ( "b7803e90a6f3f4330afbd943cef3451c716b338b17a9cf40a0a309bc38bc366d" @@ -75,30 +76,42 @@ def test_signing_key() -> None: assert test_verify_key == test_verify_key_2 -def test_action_store() -> None: +@pytest.fixture( + scope="function", + params=[ + "tODOsqlite_address", + # "TODOpostgres_address", # will be used when we have a postgres CI tests + ], +) +def action_object_stash() -> ActionObjectStash: + root_verify_key = SyftVerifyKey.from_string(test_verify_key_string) + db_manager = SQLiteDBManager.random(root_verify_key=root_verify_key) + stash = ActionObjectStash(store=db_manager) + _ = UserStash(store=db_manager) + stash.db.init_tables() + yield stash + + +def test_action_store(action_object_stash: ActionObjectStash) -> None: test_signing_key = SyftSigningKey.from_string(test_signing_key_string) - action_store = DictActionStore(server_uid=UID()) - uid = UID() + test_verify_key = test_signing_key.verify_key raw_data = np.array([1, 2, 3]) test_object = ActionObject.from_obj(raw_data) + uid = test_object.id - set_result = action_store.set( + action_object_stash.set_or_update( uid=uid, - credentials=test_signing_key, + credentials=test_verify_key, syft_object=test_object, has_result_read_permission=True, - ) - assert set_result.is_ok() - test_object_result = action_store.get(uid=uid, credentials=test_signing_key) - assert test_object_result.is_ok() - assert (test_object == test_object_result.ok()).all() + ).unwrap() + from_stash = action_object_stash.get(uid=uid, credentials=test_verify_key).unwrap() + assert (test_object == from_stash).all() test_verift_key_2 = SyftVerifyKey.from_string(test_verify_key_string_2) - test_object_result_fail = action_store.get(uid=uid, credentials=test_verift_key_2) - assert test_object_result_fail.is_err() - exc = test_object_result_fail.err() - assert type(exc) == SyftException - assert "denied" in exc.public_message + with pytest.raises(SyftException) as exc: + action_object_stash.get(uid=uid, credentials=test_verift_key_2).unwrap() + assert "denied" in exc.public_message def test_user_transform() -> None: @@ -223,14 +236,6 @@ def post_add(context: Any, name: str, new_result: Any) -> Any: action_object.syft_post_hooks__["__add__"] = [] -def test_worker_serde(worker) -> None: - ser = sy.serialize(worker, to_bytes=True) - de = sy.deserialize(ser, from_bytes=True) - - assert de.signing_key == worker.signing_key - assert de.id == worker.id - - @pytest.fixture(params=[0]) def worker_with_proc(request): worker = Worker( diff --git a/packages/syftcli/manifest.yml b/packages/syftcli/manifest.yml index c17457f6e6b..4b8754f154f 100644 --- a/packages/syftcli/manifest.yml +++ b/packages/syftcli/manifest.yml @@ -6,7 +6,7 @@ dockerTag: 0.9.2-beta.3 images: - docker.io/openmined/syft-frontend:0.9.2-beta.3 - docker.io/openmined/syft-backend:0.9.2-beta.3 - - docker.io/library/mongo:7.0.4 + - docker.io/library/postgres:16.1 - docker.io/traefik:v2.11.0 configFiles: diff --git a/scripts/dev_tools.sh b/scripts/dev_tools.sh index 20a74b597e0..c23b56c05b9 100755 --- a/scripts/dev_tools.sh +++ b/scripts/dev_tools.sh @@ -23,15 +23,13 @@ function docker_list_exposed_ports() { if [[ -z "$1" ]]; then # list db, redis, rabbitmq, and seaweedfs ports - docker_list_exposed_ports "db\|seaweedfs\|mongo" + docker_list_exposed_ports "db\|seaweedfs" else PORT=$1 if docker ps | grep ":${PORT}" | grep -q 'redis'; then ${command} redis://127.0.0.1:${PORT} elif docker ps | grep ":${PORT}" | grep -q 'postgres'; then ${command} postgresql://postgres:changethis@127.0.0.1:${PORT}/app - elif docker ps | grep ":${PORT}" | grep -q 'mongo'; then - ${command} mongodb://root:example@127.0.0.1:${PORT} else ${command} http://localhost:${PORT} fi diff --git a/scripts/reset_k8s.sh b/scripts/reset_k8s.sh index d0d245be6f2..033cb24ed31 100755 --- a/scripts/reset_k8s.sh +++ b/scripts/reset_k8s.sh @@ -1,22 +1,49 @@ #!/bin/bash -# WARNING: this will drop the 'app' database in your mongo-0 instance in the syft namespace echo $1 -# Dropping the database on mongo-0 -if [ -z $1 ]; then - MONGO_POD_NAME="mongo-0" -else - MONGO_POD_NAME=$1 -fi +# Default pod name +DEFAULT_POD_NAME="postgres-0" + +# Use the provided pod name or the default +POSTGRES_POD_NAME=${1:-$DEFAULT_POD_NAME} + +# SQL commands to reset all tables +RESET_COMMAND=" +DO \$\$ +DECLARE + r RECORD; +BEGIN + -- Disable all triggers + SET session_replication_role = 'replica'; + + -- Truncate all tables in the current schema + FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = current_schema()) LOOP + EXECUTE 'TRUNCATE TABLE ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + + -- Re-enable all triggers + SET session_replication_role = 'origin'; +END \$\$; + +-- Reset all sequences +DO \$\$ +DECLARE + r RECORD; +BEGIN + FOR r IN (SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = current_schema()) LOOP + EXECUTE 'ALTER SEQUENCE ' || quote_ident(r.sequence_name) || ' RESTART WITH 1'; + END LOOP; +END \$\$; +" -DROPCMD="<&1 +echo "All tables in $POSTGRES_POD_NAME have been reset." # Resetting the backend pod BACKEND_POD=$(kubectl get pods -n syft -o jsonpath="{.items[*].metadata.name}" | tr ' ' '\n' | grep -E ".*backend.*") diff --git a/scripts/reset_mongo.sh b/scripts/reset_mongo.sh deleted file mode 100755 index ac1641f68e4..00000000000 --- a/scripts/reset_mongo.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# WARNING: this will drop the app database in all your mongo dbs -echo $1 - -if [ -z $1 ]; then - MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo) -else - MONGO_CONTAINER_NAME=$1 -fi - -DROPCMD="<&1 \ No newline at end of file diff --git a/scripts/reset_network.sh b/scripts/reset_network.sh deleted file mode 100755 index ce5f863ff14..00000000000 --- a/scripts/reset_network.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo) -DROPCMD="<&1 - -# flush the worker queue -. ${BASH_SOURCE%/*}/flush_queue.sh - -# reset docker service to clear out weird network issues -sudo service docker restart - -# make sure all containers start -. ${BASH_SOURCE%/*}/../packages/grid/scripts/containers.sh diff --git a/tox.ini b/tox.ini index 958dc2a8772..6e37f408ded 100644 --- a/tox.ini +++ b/tox.ini @@ -467,7 +467,7 @@ commands = ; sleep 30 # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -721,12 +721,12 @@ commands = sleep 30 # wait for test gateway 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft # wait for test datasite 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -813,7 +813,7 @@ commands = sleep 30 # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -887,7 +887,7 @@ allowlist_externals = setenv = CLUSTER_NAME = {env:CLUSTER_NAME:syft} CLUSTER_HTTP_PORT = {env:SERVER_PORT:8080} -; Usage for posargs: names of the relevant services among {frontend backend proxy mongo seaweedfs registry} +; Usage for posargs: names of the relevant services among {frontend backend proxy postgres seaweedfs registry} commands = bash -c "env; date; k3d version" @@ -906,7 +906,7 @@ commands = fi" # Mongo - bash -c "if echo '{posargs}' | grep -q 'mongo'; then echo 'Checking readiness of Mongo'; ./scripts/wait_for.sh service mongo --context k3d-$CLUSTER_NAME --namespace syft; fi" + bash -c "if echo '{posargs}' | grep -q 'postgres'; then echo 'Checking readiness of Postgres'; ./scripts/wait_for.sh service postgres --context k3d-$CLUSTER_NAME --namespace syft; fi" # Proxy bash -c "if echo '{posargs}' | grep -q 'proxy'; then echo 'Checking readiness of proxy'; ./scripts/wait_for.sh service proxy --context k3d-$CLUSTER_NAME --namespace syft; fi" @@ -950,7 +950,7 @@ commands = echo "Installing local helm charts"; \ if [[ "{posargs}" == "override" ]]; then \ echo "Overriding resourcesPreset"; \ - helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ + helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ else \ helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \ fi \ @@ -960,14 +960,14 @@ commands = helm repo update openmined; \ if [[ "{posargs}" == "override" ]]; then \ echo "Overriding resourcesPreset"; \ - helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ + helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ else \ helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \ fi \ fi' ; wait for everything else to be loaded - tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry + tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry # Run Notebook tests tox -e e2e.test.notebook @@ -1234,7 +1234,9 @@ setenv= DEVSPACE_PROFILE={env:DEVSPACE_PROFILE} allowlist_externals = tox + bash commands = + bash -c "CLUSTER_NAME=${CLUSTER_NAME} tox -e dev.k8s.destroy" tox -e dev.k8s.start tox -e dev.k8s.{posargs:deploy} @@ -1460,7 +1462,7 @@ commands = ' ; wait for everything else to be loaded - tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry + tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry bash -c 'python -c "import syft as sy; print(\"Migrating from syft version:\", sy.__version__)"' @@ -1541,7 +1543,7 @@ commands = sleep 30 ; # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft