diff --git a/.github/workflows/container-scan.yml b/.github/workflows/container-scan.yml
index 303eb11bc40..211a8022f62 100644
--- a/.github/workflows/container-scan.yml
+++ b/.github/workflows/container-scan.yml
@@ -224,7 +224,7 @@ jobs:
name: syft.sbom.json
path: syft.sbom.json
- scan-mongo-latest-trivy:
+ scan-postgres-latest-trivy:
permissions:
contents: read # for actions/checkout to fetch code
security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
@@ -238,24 +238,24 @@ jobs:
continue-on-error: true
uses: aquasecurity/trivy-action@master
with:
- image-ref: "mongo:7.0.0"
+ image-ref: "postgres:16.1"
format: "cyclonedx"
- output: "mongo-trivy-results.sbom.json"
+ output: "postgres-trivy-results.sbom.json"
timeout: "10m0s"
#Upload SBOM to GitHub Artifacts
- name: Upload SBOM to GitHub Artifacts
uses: actions/upload-artifact@v4
with:
- name: mongo-trivy-results.sbom.json
- path: mongo-trivy-results.sbom.json
+ name: postgres-trivy-results.sbom.json
+ path: postgres-trivy-results.sbom.json
#Generate sarif file
- name: Run Trivy vulnerability scanner
continue-on-error: true
uses: aquasecurity/trivy-action@master
with:
- image-ref: "mongo:7.0.0"
+ image-ref: "postgres:16.1"
format: "sarif"
output: "trivy-results.sarif"
timeout: "10m0s"
@@ -266,7 +266,7 @@ jobs:
with:
sarif_file: "trivy-results.sarif"
- scan-mongo-latest-snyk:
+ scan-postgres-latest-snyk:
permissions:
contents: read # for actions/checkout to fetch code
security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
@@ -281,7 +281,7 @@ jobs:
# This is where you will need to introduce the Snyk API token created with your Snyk account
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
with:
- image: mongo:7.0.0
+ image: postgres:16.1
args: --sarif-file-output=snyk-code.sarif
# Replace any "undefined" security severity values with 0. The undefined value is used in the case
diff --git a/.isort.cfg b/.isort.cfg
index aeb09bb8f36..26309a07039 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -20,4 +20,4 @@ import_heading_localfolder=relative
ignore_comments=False
force_grid_wrap=True
honor_noqa=True
-skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/*,packages/syft/tests/mongomock/*
\ No newline at end of file
+skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/*
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3487d8d0915..e1c50cd3b96 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,41 +3,36 @@ repos:
rev: v4.5.0
hooks:
- id: check-ast
- exclude: ^(packages/syft/tests/mongomock)
always_run: true
- id: trailing-whitespace
always_run: true
- exclude: ^(docs/|.+\.md|.bumpversion.cfg|packages/syft/tests/mongomock)
+ exclude: ^(docs/|.+\.md|.bumpversion.cfg)
- id: check-docstring-first
always_run: true
- exclude: ^(packages/syft/tests/mongomock)
- id: check-json
always_run: true
- exclude: ^(packages/grid/frontend/|packages/syft/tests/mongomock|.vscode)
+ exclude: ^(packages/grid/frontend/|.vscode)
- id: check-added-large-files
always_run: true
exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif)
- id: check-yaml
always_run: true
- exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/|packages/syft/tests/mongomock)
+ exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/)
- id: check-merge-conflict
always_run: true
args: ["--assume-in-merge"]
- id: check-executables-have-shebangs
always_run: true
- exclude: ^(packages/syft/tests/mongomock)
- id: debug-statements
always_run: true
- exclude: ^(packages/syft/tests/mongomock)
- id: name-tests-test
always_run: true
- exclude: ^(.*/tests/utils/)|^(.*fixtures.py|packages/syft/tests/mongomock)|^(tests/scenarios/bigquery/helpers)
+ exclude: ^(.*/tests/utils/)|^(.*fixtures.py)|^(tests/scenarios/bigquery/helpers)
- id: requirements-txt-fixer
always_run: true
- exclude: "packages/syft/tests/mongomock"
- id: mixed-line-ending
args: ["--fix=lf"]
- exclude: '\.bat|\.csv|\.ps1$|packages/syft/tests/mongomock'
+ exclude: '\.bat|\.csv|\.ps1$'
- repo: https://github.com/MarcoGorelli/absolufy-imports # This repository has been archived by the owner on Aug 15, 2023. It is now read-only.
rev: v0.3.1
@@ -88,7 +83,6 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- exclude: packages/syft/tests/mongomock
types_or: [python, pyi, jupyter]
- id: ruff-format
types_or: [python, pyi, jupyter]
@@ -178,7 +172,7 @@ repos:
rev: "v3.0.0-alpha.9-for-vscode"
hooks:
- id: prettier
- exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode)
+ exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|.vscode)
# - repo: meta
# hooks:
diff --git a/.vscode/launch.json b/.vscode/launch.json
index bb5d6e9c00a..7e30fe06537 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,6 +4,13 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
+ {
+ "name": "Python Debugger: Current File",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal"
+ },
{
"name": "Syft Debugger",
"type": "debugpy",
diff --git a/docs/source/api_reference/syft.store.mongo_client.rst b/docs/source/api_reference/syft.store.mongo_client.rst
deleted file mode 100644
index a21d43700aa..00000000000
--- a/docs/source/api_reference/syft.store.mongo_client.rst
+++ /dev/null
@@ -1,31 +0,0 @@
-syft.store.mongo\_client
-========================
-
-.. automodule:: syft.store.mongo_client
-
-
-
-
-
-
-
-
-
-
-
- .. rubric:: Classes
-
- .. autosummary::
-
- MongoClient
- MongoClientCache
- MongoStoreClientConfig
-
-
-
-
-
-
-
-
-
diff --git a/docs/source/api_reference/syft.store.mongo_codecs.rst b/docs/source/api_reference/syft.store.mongo_codecs.rst
deleted file mode 100644
index 1d91b779e95..00000000000
--- a/docs/source/api_reference/syft.store.mongo_codecs.rst
+++ /dev/null
@@ -1,35 +0,0 @@
-syft.store.mongo\_codecs
-========================
-
-.. automodule:: syft.store.mongo_codecs
-
-
-
-
-
-
-
- .. rubric:: Functions
-
- .. autosummary::
-
- fallback_syft_encoder
-
-
-
-
-
- .. rubric:: Classes
-
- .. autosummary::
-
- SyftMongoBinaryDecoder
-
-
-
-
-
-
-
-
-
diff --git a/docs/source/api_reference/syft.store.mongo_document_store.rst b/docs/source/api_reference/syft.store.mongo_document_store.rst
deleted file mode 100644
index 30fdb6bc6ca..00000000000
--- a/docs/source/api_reference/syft.store.mongo_document_store.rst
+++ /dev/null
@@ -1,40 +0,0 @@
-syft.store.mongo\_document\_store
-=================================
-
-.. automodule:: syft.store.mongo_document_store
-
-
-
-
-
-
-
- .. rubric:: Functions
-
- .. autosummary::
-
- from_mongo
- syft_obj_to_mongo
- to_mongo
-
-
-
-
-
- .. rubric:: Classes
-
- .. autosummary::
-
- MongoBsonObject
- MongoDocumentStore
- MongoStoreConfig
- MongoStorePartition
-
-
-
-
-
-
-
-
-
diff --git a/docs/source/api_reference/syft.store.rst b/docs/source/api_reference/syft.store.rst
index b21cf230488..e83e8699025 100644
--- a/docs/source/api_reference/syft.store.rst
+++ b/docs/source/api_reference/syft.store.rst
@@ -32,8 +32,5 @@
syft.store.kv_document_store
syft.store.linked_obj
syft.store.locks
- syft.store.mongo_client
- syft.store.mongo_codecs
- syft.store.mongo_document_store
syft.store.sqlite_document_store
diff --git a/notebooks/api/0.8/12-custom-api-endpoint.ipynb b/notebooks/api/0.8/12-custom-api-endpoint.ipynb
index aa60e30dd87..c58c78c1795 100644
--- a/notebooks/api/0.8/12-custom-api-endpoint.ipynb
+++ b/notebooks/api/0.8/12-custom-api-endpoint.ipynb
@@ -657,6 +657,9 @@
"# stdlib\n",
"import time\n",
"\n",
+ "# syft absolute\n",
+ "from syft.service.job.job_stash import JobStatus\n",
+ "\n",
"# Iterate over the Jobs waiting them to finish their pipelines.\n",
"job_pool = [\n",
" (log_call_job, \"Logging Private Function Call\"),\n",
@@ -665,13 +668,19 @@
"]\n",
"for job, expected_log in job_pool:\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
- " while updated_job.status.value != \"completed\":\n",
+ " while updated_job.status in {JobStatus.CREATED, JobStatus.PROCESSING}:\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
" time.sleep(1)\n",
- " # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n",
- " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n",
- " _print=False\n",
- " )"
+ "\n",
+ " assert (\n",
+ " updated_job.status == JobStatus.COMPLETED\n",
+ " ), f\"Job {updated_job.id} exited with status {updated_job.status} and result {updated_job.result}\"\n",
+ " if updated_job.status == JobStatus.COMPLETED:\n",
+ " print(f\"Job {updated_job.id} completed\")\n",
+ " # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n",
+ " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n",
+ " _print=False\n",
+ " )"
]
},
{
@@ -683,6 +692,11 @@
}
],
"metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
"language_info": {
"codemirror_mode": {
"name": "ipython",
@@ -693,7 +707,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.5"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb
index c64e6c40f4a..a92f7987e68 100644
--- a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb
+++ b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb
@@ -78,7 +78,7 @@
"If you want to deploy your Kubernetes cluster in a resource-constrained environment, use the following flags to override the default configurations. Please note that you will need at least 1 CPU and 2 GB of RAM on Docker, and some tests may not work in such low-resource environments:\n",
"\n",
"```sh\n",
- "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n",
+ "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n",
"```\n",
"\n",
"\n",
@@ -89,7 +89,7 @@
"If you would like to set your own default password even for the production style deployment, use the following command:\n",
"\n",
"```sh\n",
- "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set mongo.secret.rootPassword=\"example\"\n",
+ "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set postgres.secret.rootPassword=\"example\"\n",
"```\n",
"\n"
]
diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile
index c51ba31c8fd..984d2f174f4 100644
--- a/packages/grid/backend/backend.dockerfile
+++ b/packages/grid/backend/backend.dockerfile
@@ -9,7 +9,7 @@ ARG TORCH_VERSION="2.2.2"
# ==================== [BUILD STEP] Python Dev Base ==================== #
-FROM cgr.dev/chainguard/wolfi-base as syft_deps
+FROM cgr.dev/chainguard/wolfi-base AS syft_deps
ARG PYTHON_VERSION
ARG UV_VERSION
@@ -45,7 +45,7 @@ RUN --mount=type=cache,target=/root/.cache,sharing=locked \
# ==================== [Final] Setup Syft Server ==================== #
-FROM cgr.dev/chainguard/wolfi-base as backend
+FROM cgr.dev/chainguard/wolfi-base AS backend
ARG PYTHON_VERSION
ARG UV_VERSION
@@ -84,9 +84,10 @@ ENV \
DEFAULT_ROOT_EMAIL="info@openmined.org" \
DEFAULT_ROOT_PASSWORD="changethis" \
STACK_API_KEY="changeme" \
- MONGO_HOST="localhost" \
- MONGO_PORT="27017" \
- MONGO_USERNAME="root" \
- MONGO_PASSWORD="example"
+ POSTGRESQL_DBNAME="syftdb_postgres" \
+ POSTGRESQL_HOST="localhost" \
+ POSTGRESQL_PORT="5432" \
+ POSTGRESQL_USERNAME="syft_postgres" \
+ POSTGRESQL_PASSWORD="example"
CMD ["bash", "./grid/start.sh"]
diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py
index 63bda939c29..e92d6783ae7 100644
--- a/packages/grid/backend/grid/core/config.py
+++ b/packages/grid/backend/grid/core/config.py
@@ -126,10 +126,11 @@ def get_emails_enabled(self) -> Self:
# NETWORK_CHECK_INTERVAL: int = int(os.getenv("NETWORK_CHECK_INTERVAL", 60))
# DATASITE_CHECK_INTERVAL: int = int(os.getenv("DATASITE_CHECK_INTERVAL", 60))
CONTAINER_HOST: str = str(os.getenv("CONTAINER_HOST", "docker"))
- MONGO_HOST: str = str(os.getenv("MONGO_HOST", ""))
- MONGO_PORT: int = int(os.getenv("MONGO_PORT", 27017))
- MONGO_USERNAME: str = str(os.getenv("MONGO_USERNAME", ""))
- MONGO_PASSWORD: str = str(os.getenv("MONGO_PASSWORD", ""))
+ POSTGRESQL_DBNAME: str = str(os.getenv("POSTGRESQL_DBNAME", ""))
+ POSTGRESQL_HOST: str = str(os.getenv("POSTGRESQL_HOST", ""))
+ POSTGRESQL_PORT: int = int(os.getenv("POSTGRESQL_PORT", 5432))
+ POSTGRESQL_USERNAME: str = str(os.getenv("POSTGRESQL_USERNAME", ""))
+ POSTGRESQL_PASSWORD: str = str(os.getenv("POSTGRESQL_PASSWORD", ""))
DEV_MODE: bool = True if os.getenv("DEV_MODE", "false").lower() == "true" else False
# ZMQ stuff
QUEUE_PORT: int = int(os.getenv("QUEUE_PORT", 5556))
@@ -137,7 +138,7 @@ def get_emails_enabled(self) -> Self:
True if os.getenv("CREATE_PRODUCER", "false").lower() == "true" else False
)
N_CONSUMERS: int = int(os.getenv("N_CONSUMERS", 1))
- SQLITE_PATH: str = os.path.expandvars("$HOME/data/db/")
+ SQLITE_PATH: str = os.path.expandvars("/tmp/data/db")
SINGLE_CONTAINER_MODE: bool = str_to_bool(os.getenv("SINGLE_CONTAINER_MODE", False))
CONSUMER_SERVICE_NAME: str | None = os.getenv("CONSUMER_SERVICE_NAME")
INMEMORY_WORKERS: bool = str_to_bool(os.getenv("INMEMORY_WORKERS", True))
diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py
index 3f401d7e349..7d8d011de5d 100644
--- a/packages/grid/backend/grid/core/server.py
+++ b/packages/grid/backend/grid/core/server.py
@@ -1,3 +1,6 @@
+# stdlib
+from pathlib import Path
+
# syft absolute
from syft.abstract_server import ServerType
from syft.server.datasite import Datasite
@@ -14,10 +17,8 @@
from syft.service.queue.zmq_client import ZMQQueueConfig
from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig
from syft.store.blob_storage.seaweedfs import SeaweedFSConfig
-from syft.store.mongo_client import MongoStoreClientConfig
-from syft.store.mongo_document_store import MongoStoreConfig
-from syft.store.sqlite_document_store import SQLiteStoreClientConfig
-from syft.store.sqlite_document_store import SQLiteStoreConfig
+from syft.store.db.postgres import PostgresDBConfig
+from syft.store.db.sqlite import SQLiteDBConfig
from syft.types.uid import UID
# server absolute
@@ -36,23 +37,26 @@ def queue_config() -> ZMQQueueConfig:
return queue_config
-def mongo_store_config() -> MongoStoreConfig:
- mongo_client_config = MongoStoreClientConfig(
- hostname=settings.MONGO_HOST,
- port=settings.MONGO_PORT,
- username=settings.MONGO_USERNAME,
- password=settings.MONGO_PASSWORD,
- )
-
- return MongoStoreConfig(client_config=mongo_client_config)
-
+def sql_store_config() -> SQLiteDBConfig:
+ # Check if the directory exists, and create it if it doesn't
+ sqlite_path = Path(settings.SQLITE_PATH)
+ if not sqlite_path.exists():
+ sqlite_path.mkdir(parents=True, exist_ok=True)
-def sql_store_config() -> SQLiteStoreConfig:
- client_config = SQLiteStoreClientConfig(
+ return SQLiteDBConfig(
filename=f"{UID.from_string(get_server_uid_env())}.sqlite",
path=settings.SQLITE_PATH,
)
- return SQLiteStoreConfig(client_config=client_config)
+
+
+def postgresql_store_config() -> PostgresDBConfig:
+ return PostgresDBConfig(
+ host=settings.POSTGRESQL_HOST,
+ port=settings.POSTGRESQL_PORT,
+ user=settings.POSTGRESQL_USERNAME,
+ password=settings.POSTGRESQL_PASSWORD,
+ database=settings.POSTGRESQL_DBNAME,
+ )
def seaweedfs_config() -> SeaweedFSConfig:
@@ -87,20 +91,19 @@ def seaweedfs_config() -> SeaweedFSConfig:
worker_class = worker_classes[server_type]
single_container_mode = settings.SINGLE_CONTAINER_MODE
-store_config = sql_store_config() if single_container_mode else mongo_store_config()
+db_config = sql_store_config() if single_container_mode else postgresql_store_config()
+
blob_storage_config = None if single_container_mode else seaweedfs_config()
queue_config = queue_config()
worker: Server = worker_class(
name=server_name,
server_side_type=server_side_type,
- action_store_config=store_config,
- document_store_config=store_config,
enable_warnings=enable_warnings,
blob_storage_config=blob_storage_config,
local_db=single_container_mode,
queue_config=queue_config,
- migrate=True,
+ migrate=False,
in_memory_workers=settings.INMEMORY_WORKERS,
smtp_username=settings.SMTP_USERNAME,
smtp_password=settings.SMTP_PASSWORD,
@@ -109,4 +112,5 @@ def seaweedfs_config() -> SeaweedFSConfig:
smtp_host=settings.SMTP_HOST,
association_request_auto_approval=settings.ASSOCIATION_REQUEST_AUTO_APPROVAL,
background_tasks=True,
+ db_config=db_config,
)
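
For reference, a minimal sketch of the two storage configs constructed above, filled in with the defaults introduced in packages/grid/default.env (the literal values and the explicit SQLite filename are illustrative only; the real code derives the filename from the server UID):

```python
from syft.store.db.postgres import PostgresDBConfig
from syft.store.db.sqlite import SQLiteDBConfig

# single-container mode: file-backed SQLite under SQLITE_PATH
sqlite_config = SQLiteDBConfig(filename="server.sqlite", path="/tmp/data/db")

# kubernetes mode: Postgres reachable via the "postgres" service
postgres_config = PostgresDBConfig(
    host="postgres",
    port=5432,
    user="syft_postgres",
    password="example",
    database="syftdb_postgres",
)
```
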
diff --git a/packages/grid/default.env b/packages/grid/default.env
index e1bc5c42557..49697538cd8 100644
--- a/packages/grid/default.env
+++ b/packages/grid/default.env
@@ -77,14 +77,6 @@ KANIKO_VERSION="v1.23.2"
# Jax
JAX_ENABLE_X64=True
-# Mongo
-MONGO_IMAGE=mongo
-MONGO_VERSION="7.0.8"
-MONGO_HOST=mongo
-MONGO_PORT=27017
-MONGO_USERNAME=root
-MONGO_PASSWORD=example
-
# Redis
REDIS_PORT=6379
REDIS_STORE_DB_ID=0
@@ -110,4 +102,13 @@ ENABLE_SIGNUP=False
DOCKER_IMAGE_ENCLAVE_ATTESTATION=openmined/syft-enclave-attestation
# Rathole Config
-RATHOLE_PORT=2333
\ No newline at end of file
+RATHOLE_PORT=2333
+
+# PostgreSQL Config
+# POSTGRESQL_IMAGE=postgres
+# export POSTGRESQL_VERSION="15"
+POSTGRESQL_DBNAME=syftdb_postgres
+POSTGRESQL_HOST=postgres
+POSTGRESQL_PORT=5432
+POSTGRESQL_USERNAME=syft_postgres
+POSTGRESQL_PASSWORD=example
diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml
index 08121fcaaa6..b9a31457d4e 100644
--- a/packages/grid/devspace.yaml
+++ b/packages/grid/devspace.yaml
@@ -79,12 +79,12 @@ deployments:
- ./helm/examples/dev/base.yaml
dev:
- mongo:
+ postgres:
labelSelector:
app.kubernetes.io/name: syft
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
ports:
- - port: "27017"
+ - port: "5432"
seaweedfs:
labelSelector:
app.kubernetes.io/name: syft
@@ -94,6 +94,7 @@ dev:
- port: "8888" # filer
- port: "8333" # S3
- port: "4001" # mount azure
+ - port: "5432" # mount postgres
backend:
labelSelector:
app.kubernetes.io/name: syft
@@ -205,8 +206,8 @@ profiles:
path: dev.seaweedfs
# Port Re-Mapping
- op: replace
- path: dev.mongo.ports[0].port
- value: 27018:27017
+ path: dev.postgres.ports[0].port
+ value: 5433:5432
- op: replace
path: dev.backend.ports[0].port
value: 5679:5678
@@ -268,8 +269,8 @@ profiles:
value: ./helm/examples/dev/enclave.yaml
# Port Re-Mapping
- op: replace
- path: dev.mongo.ports[0].port
- value: 27019:27017
+ path: dev.postgres.ports[0].port
+ value: 5434:5432
- op: replace
path: dev.backend.ports[0].port
value: 5680:5678
diff --git a/packages/grid/helm/examples/azure/azure.high.yaml b/packages/grid/helm/examples/azure/azure.high.yaml
index 3234a62b757..4733fed4cc7 100644
--- a/packages/grid/helm/examples/azure/azure.high.yaml
+++ b/packages/grid/helm/examples/azure/azure.high.yaml
@@ -38,5 +38,5 @@ registry:
frontend:
resourcesPreset: medium
-mongo:
+postgres:
resourcesPreset: large
diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml
index 4999ae40aed..3fc1ad5c4da 100644
--- a/packages/grid/helm/examples/dev/base.yaml
+++ b/packages/grid/helm/examples/dev/base.yaml
@@ -21,7 +21,7 @@ server:
secret:
defaultRootPassword: changethis
-mongo:
+postgres:
resourcesPreset: null
resources: null
diff --git a/packages/grid/helm/examples/gcp/gcp.high.yaml b/packages/grid/helm/examples/gcp/gcp.high.yaml
index efdbbe72e68..2a430807fac 100644
--- a/packages/grid/helm/examples/gcp/gcp.high.yaml
+++ b/packages/grid/helm/examples/gcp/gcp.high.yaml
@@ -97,7 +97,7 @@ frontend:
# =================================================================================
-mongo:
+postgres:
resourcesPreset: large
# =================================================================================
diff --git a/packages/grid/helm/examples/gcp/gcp.low.yaml b/packages/grid/helm/examples/gcp/gcp.low.yaml
index 94cfc324b0b..8e9e3e7ba35 100644
--- a/packages/grid/helm/examples/gcp/gcp.low.yaml
+++ b/packages/grid/helm/examples/gcp/gcp.low.yaml
@@ -97,7 +97,7 @@ frontend:
# =================================================================================
-mongo:
+postgres:
resourcesPreset: large
# =================================================================================
diff --git a/packages/grid/helm/examples/gcp/gcp.nosync.yaml b/packages/grid/helm/examples/gcp/gcp.nosync.yaml
index 02935edfd8f..8e622be5254 100644
--- a/packages/grid/helm/examples/gcp/gcp.nosync.yaml
+++ b/packages/grid/helm/examples/gcp/gcp.nosync.yaml
@@ -67,7 +67,7 @@ frontend:
# =================================================================================
-mongo:
+postgres:
resourcesPreset: large
# =================================================================================
diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml
index 3293056ba2a..2d1a6880c33 100644
--- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml
+++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml
@@ -99,18 +99,20 @@ spec:
- name: REVERSE_TUNNEL_ENABLED
value: "true"
{{- end }}
- # MongoDB
- - name: MONGO_PORT
- value: {{ .Values.mongo.port | quote }}
- - name: MONGO_HOST
- value: "mongo"
- - name: MONGO_USERNAME
- value: {{ .Values.mongo.username | quote }}
- - name: MONGO_PASSWORD
+ # Postgres
+ - name: POSTGRESQL_PORT
+ value: {{ .Values.postgres.port | quote }}
+ - name: POSTGRESQL_HOST
+ value: "postgres"
+ - name: POSTGRESQL_USERNAME
+ value: {{ .Values.postgres.username | quote }}
+ - name: POSTGRESQL_PASSWORD
valueFrom:
secretKeyRef:
- name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }}
+ name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }}
key: rootPassword
+ - name: POSTGRESQL_DBNAME
+ value: {{ .Values.postgres.dbname | quote }}
# SMTP
- name: SMTP_HOST
value: {{ .Values.server.smtp.host | quote }}
diff --git a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml b/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml
deleted file mode 100644
index 91060b90a9b..00000000000
--- a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml
+++ /dev/null
@@ -1,69 +0,0 @@
-apiVersion: apps/v1
-kind: StatefulSet
-metadata:
- name: mongo
- labels:
- {{- include "common.labels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
-spec:
- replicas: 1
- updateStrategy:
- type: RollingUpdate
- selector:
- matchLabels:
- {{- include "common.selectorLabels" . | nindent 6 }}
- app.kubernetes.io/component: mongo
- serviceName: mongo-headless
- podManagementPolicy: OrderedReady
- template:
- metadata:
- labels:
- {{- include "common.labels" . | nindent 8 }}
- app.kubernetes.io/component: mongo
- {{- if .Values.mongo.podLabels }}
- {{- toYaml .Values.mongo.podLabels | nindent 8 }}
- {{- end }}
- {{- if .Values.mongo.podAnnotations }}
- annotations: {{- toYaml .Values.mongo.podAnnotations | nindent 8 }}
- {{- end }}
- spec:
- {{- if .Values.mongo.nodeSelector }}
- nodeSelector: {{- .Values.mongo.nodeSelector | toYaml | nindent 8 }}
- {{- end }}
- containers:
- - name: mongo-container
- image: mongo:7
- imagePullPolicy: Always
- resources: {{ include "common.resources.set" (dict "resources" .Values.mongo.resources "preset" .Values.mongo.resourcesPreset) | nindent 12 }}
- env:
- - name: MONGO_INITDB_ROOT_USERNAME
- value: {{ .Values.mongo.username | required "mongo.username is required" | quote }}
- - name: MONGO_INITDB_ROOT_PASSWORD
- valueFrom:
- secretKeyRef:
- name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }}
- key: rootPassword
- {{- if .Values.mongo.env }}
- {{- toYaml .Values.mongo.env | nindent 12 }}
- {{- end }}
- volumeMounts:
- - mountPath: /data/db
- name: mongo-data
- readOnly: false
- subPath: ''
- ports:
- - name: mongo-port
- containerPort: 27017
- terminationGracePeriodSeconds: 5
- volumeClaimTemplates:
- - metadata:
- name: mongo-data
- labels:
- {{- include "common.volumeLabels" . | nindent 8 }}
- app.kubernetes.io/component: mongo
- spec:
- accessModes:
- - ReadWriteOnce
- resources:
- requests:
- storage: {{ .Values.mongo.storageSize | quote }}
diff --git a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml
similarity index 57%
rename from packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml
rename to packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml
index 7cb97ee3592..4855a7868ff 100644
--- a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml
+++ b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml
@@ -1,15 +1,15 @@
apiVersion: v1
kind: Service
metadata:
- name: mongo-headless
+ name: postgres-headless
labels:
{{- include "common.labels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
spec:
clusterIP: None
ports:
- - name: mongo
- port: 27017
+ - name: postgres
+ port: 5432
selector:
{{- include "common.selectorLabels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
\ No newline at end of file
diff --git a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml
similarity index 69%
rename from packages/grid/helm/syft/templates/mongo/mongo-secret.yaml
rename to packages/grid/helm/syft/templates/postgres/postgres-secret.yaml
index 02c58d276ca..63a990c0d9a 100644
--- a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml
+++ b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml
@@ -1,17 +1,17 @@
-{{- $secretName := "mongo-secret" }}
+{{- $secretName := "postgres-secret" }}
apiVersion: v1
kind: Secret
metadata:
name: {{ $secretName }}
labels:
{{- include "common.labels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
type: Opaque
data:
rootPassword: {{ include "common.secrets.set" (dict
"secret" $secretName
"key" "rootPassword"
"randomDefault" .Values.global.randomizedSecrets
- "default" .Values.mongo.secret.rootPassword
+ "default" .Values.postgres.secret.rootPassword
"context" $)
- }}
+ }}
\ No newline at end of file
diff --git a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml
similarity index 57%
rename from packages/grid/helm/syft/templates/mongo/mongo-service.yaml
rename to packages/grid/helm/syft/templates/postgres/postgres-service.yaml
index a789f4e8f86..9cd8b156bdd 100644
--- a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml
+++ b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml
@@ -1,17 +1,17 @@
apiVersion: v1
kind: Service
metadata:
- name: mongo
+ name: postgres
labels:
{{- include "common.labels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
spec:
type: ClusterIP
selector:
{{- include "common.selectorLabels" . | nindent 4 }}
- app.kubernetes.io/component: mongo
+ app.kubernetes.io/component: postgres
ports:
- - name: mongo
- port: 27017
+ - name: postgres
+ port: 5432
protocol: TCP
- targetPort: 27017
+ targetPort: 5432
\ No newline at end of file
diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefulset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefulset.yaml
new file mode 100644
index 00000000000..986031b17e9
--- /dev/null
+++ b/packages/grid/helm/syft/templates/postgres/postgres-statefulset.yaml
@@ -0,0 +1,72 @@
+apiVersion: apps/v1
+kind: StatefulSet
+metadata:
+ name: postgres
+ labels:
+ {{- include "common.labels" . | nindent 4 }}
+ app.kubernetes.io/component: postgres
+spec:
+ replicas: 1
+ updateStrategy:
+ type: RollingUpdate
+ selector:
+ matchLabels:
+ {{- include "common.selectorLabels" . | nindent 6 }}
+ app.kubernetes.io/component: postgres
+ serviceName: postgres-headless
+ podManagementPolicy: OrderedReady
+ template:
+ metadata:
+ labels:
+ {{- include "common.labels" . | nindent 8 }}
+ app.kubernetes.io/component: postgres
+ {{- if .Values.postgres.podLabels }}
+ {{- toYaml .Values.postgres.podLabels | nindent 8 }}
+ {{- end }}
+ {{- if .Values.postgres.podAnnotations }}
+ annotations: {{- toYaml .Values.postgres.podAnnotations | nindent 8 }}
+ {{- end }}
+ spec:
+ {{- if .Values.postgres.nodeSelector }}
+ nodeSelector: {{- .Values.postgres.nodeSelector | toYaml | nindent 8 }}
+ {{- end }}
+ containers:
+ - name: postgres-container
+ image: postgres:16.1
+ imagePullPolicy: Always
+ resources: {{ include "common.resources.set" (dict "resources" .Values.postgres.resources "preset" .Values.postgres.resourcesPreset) | nindent 12 }}
+ env:
+ - name: POSTGRES_USER
+ value: {{ .Values.postgres.username | required "postgres.username is required" | quote }}
+ - name: POSTGRES_PASSWORD
+ valueFrom:
+ secretKeyRef:
+ name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }}
+ key: rootPassword
+ - name: POSTGRES_DB
+ value: {{ .Values.postgres.dbname | required "postgres.dbname is required" | quote }}
+ {{- if .Values.postgres.env }}
+ {{- toYaml .Values.postgres.env | nindent 12 }}
+ {{- end }}
+ volumeMounts:
+            - mountPath: /var/lib/postgresql/data
+ name: postgres-data
+ readOnly: false
+ subPath: ''
+ ports:
+ - name: postgres-port
+ containerPort: 5432
+ terminationGracePeriodSeconds: 5
+ volumeClaimTemplates:
+ - metadata:
+ name: postgres-data
+ labels:
+ {{- include "common.volumeLabels" . | nindent 8 }}
+ app.kubernetes.io/component: postgres
+ spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: {{ .Values.postgres.storageSize | quote }}
+
diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml
index c1cec54b441..ba3eebffde8 100644
--- a/packages/grid/helm/syft/values.yaml
+++ b/packages/grid/helm/syft/values.yaml
@@ -12,10 +12,12 @@ global:
# =================================================================================
-mongo:
- # MongoDB config
- port: 27017
- username: root
+postgres:
+  # Postgres config
+ port: 5432
+ username: syft_postgres
+ dbname: syftdb_postgres
+ host: postgres
# Extra environment vars
env: null
@@ -35,12 +37,11 @@ mongo:
storageSize: 5Gi
-  # Mongo secret name. Override this if you want to use a self-managed secret.
+  # Postgres secret name. Override this if you want to use a self-managed secret.
- secretKeyName: mongo-secret
+ secretKeyName: postgres-secret
# default/custom secret raw values
secret:
- rootPassword: null
-
+ rootPassword: null
# =================================================================================
frontend:
diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg
index 547445d20bb..d4ee2cff521 100644
--- a/packages/syft/setup.cfg
+++ b/packages/syft/setup.cfg
@@ -31,11 +31,10 @@ syft =
boto3==1.34.56
forbiddenfruit==0.1.4
packaging>=23.0
- pyarrow==15.0.0
+ pyarrow==17.0.0
pycapnp==2.0.0
pydantic[email]==2.6.0
pydantic-settings==2.2.1
- pymongo==4.6.3
pynacl==1.5.0
pyzmq>=23.2.1,<=25.1.1
requests==2.32.3
@@ -66,8 +65,12 @@ syft =
jinja2==3.1.4
tenacity==8.3.0
nh3==0.2.17
+ psycopg[binary]==3.1.19
+ psycopg[pool]==3.1.19
ipython<8.27.0
dynaconf==3.2.6
+ sqlalchemy==2.0.32
+ psycopg2-binary==2.9.9
install_requires =
%(syft)s
@@ -111,9 +114,9 @@ telemetry =
opentelemetry-instrumentation==0.48b0
opentelemetry-instrumentation-requests==0.48b0
opentelemetry-instrumentation-fastapi==0.48b0
- opentelemetry-instrumentation-pymongo==0.48b0
opentelemetry-instrumentation-botocore==0.48b0
opentelemetry-instrumentation-logging==0.48b0
+ opentelemetry-instrumentation-sqlalchemy==0.48b0
; opentelemetry-instrumentation-asyncio==0.48b0
; opentelemetry-instrumentation-sqlite3==0.48b0
; opentelemetry-instrumentation-threading==0.48b0
diff --git a/packages/syft/src/syft/abstract_server.py b/packages/syft/src/syft/abstract_server.py
index c222cf4ea5a..3b7885f0a0e 100644
--- a/packages/syft/src/syft/abstract_server.py
+++ b/packages/syft/src/syft/abstract_server.py
@@ -5,6 +5,7 @@
# relative
from .serde.serializable import serializable
+from .store.db.db import DBConfig
from .types.uid import UID
if TYPE_CHECKING:
@@ -41,6 +42,7 @@ class AbstractServer:
server_side_type: ServerSideType | None
in_memory_workers: bool
services: "ServiceRegistry"
+ db_config: DBConfig
def get_service(self, path_or_func: str | Callable) -> "AbstractService":
raise NotImplementedError
diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py
index a5aa64ca736..85fe33545fa 100644
--- a/packages/syft/src/syft/client/api.py
+++ b/packages/syft/src/syft/client/api.py
@@ -29,8 +29,7 @@
from ..serde.serializable import serializable
from ..serde.serialize import _serialize
from ..serde.signature import Signature
-from ..serde.signature import signature_remove_context
-from ..serde.signature import signature_remove_self
+from ..serde.signature import signature_remove
from ..server.credentials import SyftSigningKey
from ..server.credentials import SyftVerifyKey
from ..service.context import AuthedServiceContext
@@ -738,10 +737,17 @@ def __getattr__(self, name: str) -> Any:
)
def __getitem__(self, key: str | int) -> Any:
+ if hasattr(self, "get_index"):
+ return self.get_index(key)
if hasattr(self, "get_all"):
return self.get_all()[key]
raise NotImplementedError
+ def __iter__(self) -> Any:
+ if hasattr(self, "get_all"):
+ return iter(self.get_all())
+ raise NotImplementedError
+
def _repr_html_(self) -> Any:
if self.path == "settings":
return self.get()._repr_html_()
@@ -1117,9 +1123,10 @@ def build_endpoint_tree(
api_module = APIModule(path="", refresh_callback=self.refresh_api_callback)
for v in endpoints.values():
signature = v.signature
+ args_to_remove = ["context"]
if not v.has_self:
- signature = signature_remove_self(signature)
- signature = signature_remove_context(signature)
+ args_to_remove.append("self")
+ signature = signature_remove(signature, args_to_remove)
if isinstance(v, APIEndpoint):
endpoint_function = generate_remote_function(
self,
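
The new __getitem__/__iter__ fallbacks make service modules indexable and iterable. A hedged sketch of how that reads from a client (the launch/login call and the development credentials are assumptions for illustration):

```python
import syft as sy

server = sy.orchestra.launch(name="dev-server", reset=True)
client = server.login(email="info@openmined.org", password="changethis")

jobs = client.api.services.job  # an APIModule backed by get_all()
first_job = jobs[0]             # __getitem__ falls back to get_all()[key]
for job in jobs:                # __iter__ iterates get_all()
    print(job.id, job.status)
```
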
diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py
index 692bd017565..d6fd164f44f 100644
--- a/packages/syft/src/syft/client/client.py
+++ b/packages/syft/src/syft/client/client.py
@@ -605,6 +605,49 @@ def login(
result = post_process_result(result, unwrap_on_success=True)
return result
+ def forgot_password(
+ self,
+ email: str,
+ ) -> SyftSigningKey | None:
+ credentials = {"email": email}
+ if self.proxy_target_uid:
+ obj = forward_message_to_proxy(
+ self.make_call,
+ proxy_target_uid=self.proxy_target_uid,
+ path="forgot_password",
+ kwargs=credentials,
+ )
+ else:
+ response = self.server.services.user.forgot_password(
+ context=ServerServiceContext(server=self.server), email=email
+ )
+ obj = post_process_result(response, unwrap_on_success=True)
+
+ return obj
+
+ def reset_password(
+ self,
+ token: str,
+ new_password: str,
+ ) -> SyftSigningKey | None:
+ payload = {"token": token, "new_password": new_password}
+ if self.proxy_target_uid:
+ obj = forward_message_to_proxy(
+ self.make_call,
+ proxy_target_uid=self.proxy_target_uid,
+ path="reset_password",
+ kwargs=payload,
+ )
+ else:
+ response = self.server.services.user.reset_password(
+ context=ServerServiceContext(server=self.server),
+ token=token,
+ new_password=new_password,
+ )
+ obj = post_process_result(response, unwrap_on_success=True)
+
+ return obj
+
def register(self, new_user: UserCreate) -> SyftSigningKey | None:
if self.proxy_target_uid:
response = forward_message_to_proxy(
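
A hedged sketch of the new password-reset flow. How the reset token reaches the user (email/notification service) is outside this hunk, the assumption that these calls are surfaced on a guest client mirrors how login/register are exposed, and the token value is purely hypothetical:

```python
import syft as sy

guest_client = sy.login_as_guest(url="http://localhost:8080")
guest_client.forgot_password(email="info@openmined.org")

# ... the user receives a reset token out of band ...
guest_client.reset_password(token="<token-from-email>", new_password="a-new-password")
```
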
diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py
index 1cbdb44c488..6410c990eac 100644
--- a/packages/syft/src/syft/custom_worker/config.py
+++ b/packages/syft/src/syft/custom_worker/config.py
@@ -14,6 +14,7 @@
# relative
from ..serde.serializable import serializable
+from ..serde.serialize import _serialize
from ..service.response import SyftSuccess
from ..types.base import SyftBaseModel
from ..types.errors import SyftException
@@ -83,6 +84,10 @@ def merged_custom_cmds(self, sep: str = ";") -> str:
class WorkerConfig(SyftBaseModel):
pass
+ def hash(self) -> str:
+ _bytes = _serialize(self, to_bytes=True, for_hashing=True)
+ return sha256(_bytes).digest().hex()
+
@serializable(canonical_name="CustomWorkerConfig", version=1)
class CustomWorkerConfig(WorkerConfig):
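
The new WorkerConfig.hash() gives any worker config a stable content hash (sha256 over its for_hashing serialization). A small, hedged example; the DockerWorkerConfig constructor and the dockerfile string are assumptions for illustration:

```python
from syft.custom_worker.config import DockerWorkerConfig

config = DockerWorkerConfig(dockerfile="FROM cgr.dev/chainguard/wolfi-base\n")
print(config.hash())  # hex sha256; identical configs produce identical hashes
```
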
diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py
index 0d295b81982..efed6023ab8 100644
--- a/packages/syft/src/syft/orchestra.py
+++ b/packages/syft/src/syft/orchestra.py
@@ -184,6 +184,7 @@ def deploy_to_python(
debug: bool = False,
migrate: bool = False,
consumer_type: ConsumerType | None = None,
+ db_url: str | None = None,
) -> ServerHandle:
worker_classes = {
ServerType.DATASITE: Datasite,
@@ -216,6 +217,7 @@ def deploy_to_python(
"migrate": migrate,
"deployment_type": deployment_type_enum,
"consumer_type": consumer_type,
+ "db_url": db_url,
}
if port:
@@ -329,6 +331,7 @@ def launch(
migrate: bool = False,
from_state_folder: str | Path | None = None,
consumer_type: ConsumerType | None = None,
+ db_url: str | None = None,
) -> ServerHandle:
if from_state_folder is not None:
with open(f"{from_state_folder}/config.json") as f:
@@ -378,11 +381,12 @@ def launch(
debug=debug,
migrate=migrate,
consumer_type=consumer_type,
+ db_url=db_url,
)
display(
SyftInfo(
message=f"You have launched a development server at http://{host}:{server_handle.port}."
- + "It is intended only for local use."
+ + " It is intended only for local use."
)
)
return server_handle
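
The db_url argument is threaded through deploy_to_python into the server kwargs. A hedged example of launching a Python server against the Postgres defaults from this change set (the connection-string format is assumed to be a standard psycopg/SQLAlchemy URL):

```python
import syft as sy

server = sy.orchestra.launch(
    name="postgres-backed-server",
    db_url="postgresql://syft_postgres:example@localhost:5432/syftdb_postgres",
)
client = server.login(email="info@openmined.org", password="changethis")
```
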
diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py
index 1ea9d1ae203..0c848585119 100644
--- a/packages/syft/src/syft/protocol/data_protocol.py
+++ b/packages/syft/src/syft/protocol/data_protocol.py
@@ -81,8 +81,14 @@ def handle_annotation_repr_(annotation: type) -> str:
"""Handle typing representation."""
origin = typing.get_origin(annotation)
args = typing.get_args(annotation)
+
+ def get_annotation_repr_for_arg(arg: type) -> str:
+ if hasattr(arg, "__canonical_name__"):
+ return arg.__canonical_name__
+ return getattr(arg, "__name__", str(arg))
+
if origin and args:
- args_repr = ", ".join(getattr(arg, "__name__", str(arg)) for arg in args)
+ args_repr = ", ".join(get_annotation_repr_for_arg(arg) for arg in args)
origin_repr = getattr(origin, "__name__", str(origin))
# Handle typing.Union and types.UnionType
diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json
index 57e224f83e5..023833d9887 100644
--- a/packages/syft/src/syft/protocol/protocol_version.json
+++ b/packages/syft/src/syft/protocol/protocol_version.json
@@ -8,6 +8,76 @@
"3": {
"version": 3,
"hash": "85e2f6d9a6a12e6029b153be4fd2b4efae3bb2b48e617c08b447c5db01bea6c4",
+      "action": "add"
+    }
+  },
+ "Notification": {
+ "2": {
+ "version": 2,
+ "hash": "812d3a612422fb1cf53caa13ec34a7bdfcf033a7c24b7518f527af144cb45f3c",
+ "action": "add"
+ }
+ },
+ "SyftWorkerImage": {
+ "2": {
+ "version": 2,
+ "hash": "afd3a69719cd6d08b1121676ca8d80ca37be96ee5ed5893dc73733fbf47fd035",
+ "action": "add"
+ }
+ },
+ "WorkerSettings": {
+ "2": {
+ "version": 2,
+ "hash": "91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa",
+ "action": "add"
+ }
+ },
+ "MongoDict": {
+ "1": {
+ "version": 1,
+ "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b",
+ "action": "remove"
+ }
+ },
+ "MongoStoreConfig": {
+ "1": {
+ "version": 1,
+ "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036",
+ "action": "remove"
+ }
+ },
+ "JobItem": {
+ "1": {
+ "version": 1,
+ "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6",
+ "action": "remove"
+ }
+ },
+ "DictStoreConfig": {
+ "1": {
+ "version": 1,
+ "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe",
+ "action": "remove"
+ }
+ },
+ "QueueItem": {
+ "2": {
+ "version": 2,
+ "hash": "1d8615f6daabcd2a285b2f36fd7bef1df76cdd119dd49c02069c50fd1b9c3ff4",
+ "action": "add"
+ }
+ },
+ "ActionQueueItem": {
+ "2": {
+ "version": 2,
+ "hash": "bfda6ef87e4045d663324bb91a215ea06e1f173aec1fb4d9ddd337cdc1f0787f",
+ "action": "add"
+ }
+ },
+ "APIEndpointQueueItem": {
+ "2": {
+ "version": 2,
+ "hash": "3a46370205152fa23a7d2bfa47130dbf2e2bc7ef31f6d3fe4c92fd8d683770b5",
"action": "add"
}
}
diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py
new file mode 100644
index 00000000000..ee86241716c
--- /dev/null
+++ b/packages/syft/src/syft/serde/json_serde.py
@@ -0,0 +1,452 @@
+# stdlib
+import base64
+from collections.abc import Callable
+from dataclasses import dataclass
+from enum import Enum
+import json
+import typing
+from typing import Any
+from typing import Generic
+from typing import TypeVar
+from typing import Union
+from typing import get_args
+from typing import get_origin
+
+# third party
+import pydantic
+
+# syft absolute
+import syft as sy
+
+# relative
+from ..server.credentials import SyftSigningKey
+from ..server.credentials import SyftVerifyKey
+from ..types.datetime import DateTime
+from ..types.syft_object import BaseDateTime
+from ..types.syft_object_registry import SyftObjectRegistry
+from ..types.uid import LineageID
+from ..types.uid import UID
+from .recursive import DEFAULT_EXCLUDE_ATTRS
+
+T = TypeVar("T")
+
+JSON_CANONICAL_NAME_FIELD = "__canonical_name__"
+JSON_VERSION_FIELD = "__version__"
+JSON_DATA_FIELD = "data"
+
+JsonPrimitive = str | int | float | bool | None
+Json = JsonPrimitive | list["Json"] | dict[str, "Json"]
+
+
+def _noop_fn(obj: Any) -> Any:
+ return obj
+
+
+@dataclass
+class JSONSerde(Generic[T]):
+ klass: type[T]
+ serialize_fn: Callable[[T], Json]
+ deserialize_fn: Callable[[Json], T]
+
+ def serialize(self, obj: T) -> Json:
+ return self.serialize_fn(obj)
+
+ def deserialize(self, obj: Json) -> T:
+ return self.deserialize_fn(obj)
+
+
+JSON_SERDE_REGISTRY: dict[type[T], JSONSerde[T]] = {}
+
+
+def register_json_serde(
+ type_: type[T],
+ serialize: Callable[[T], Json] | None = None,
+ deserialize: Callable[[Json], T] | None = None,
+) -> None:
+ if type_ in JSON_SERDE_REGISTRY:
+ raise ValueError(f"Type {type_} is already registered")
+
+ if serialize is None:
+ serialize = _noop_fn
+
+ if deserialize is None:
+ deserialize = _noop_fn
+
+ JSON_SERDE_REGISTRY[type_] = JSONSerde(
+ klass=type_,
+ serialize_fn=serialize,
+ deserialize_fn=deserialize,
+ )
+
+
+# Standard JSON primitives
+register_json_serde(int)
+register_json_serde(str)
+register_json_serde(bool)
+register_json_serde(float)
+register_json_serde(type(None))
+register_json_serde(pydantic.EmailStr)
+
+# Syft primitives
+register_json_serde(UID, lambda uid: uid.no_dash, lambda s: UID(s))
+register_json_serde(LineageID, lambda uid: uid.no_dash, lambda s: LineageID(s))
+register_json_serde(
+ DateTime, lambda dt: dt.utc_timestamp, lambda f: DateTime(utc_timestamp=f)
+)
+register_json_serde(
+ BaseDateTime, lambda dt: dt.utc_timestamp, lambda f: BaseDateTime(utc_timestamp=f)
+)
+register_json_serde(SyftVerifyKey, lambda key: str(key), SyftVerifyKey.from_string)
+register_json_serde(SyftSigningKey, lambda key: str(key), SyftSigningKey.from_string)
+
+
+def _validate_json(value: T) -> T:
+ # Throws TypeError if value is not JSON-serializable
+ json.dumps(value)
+ return value
+
+
+def _is_optional_annotation(annotation: Any) -> bool:
+ try:
+ return annotation | None == annotation
+ except TypeError:
+ return False
+
+
+def _is_annotated_type(annotation: Any) -> bool:
+ return get_origin(annotation) == typing.Annotated
+
+
+def _unwrap_optional_annotation(annotation: Any) -> Any:
+ """Return the type anntation with None type removed, if it is present.
+
+ Args:
+ annotation (Any): type annotation
+
+ Returns:
+ Any: type annotation without None type
+ """
+ if _is_optional_annotation(annotation):
+ args = get_args(annotation)
+ return Union[tuple(arg for arg in args if arg is not type(None))] # noqa
+ return annotation
+
+
+def _unwrap_annotated(annotation: Any) -> Any:
+ # Convert Annotated[T, ...] to T
+ return get_args(annotation)[0]
+
+
+def _unwrap_type_annotation(annotation: Any) -> Any:
+ """
+ recursively unwrap type annotations, removing Annotated and Optional types
+ """
+ if _is_annotated_type(annotation):
+ res = _unwrap_annotated(annotation)
+ return _unwrap_type_annotation(res)
+ elif _is_optional_annotation(annotation):
+ res = _unwrap_optional_annotation(annotation)
+ return _unwrap_type_annotation(res)
+ return annotation
+
+
+def _annotation_issubclass(annotation: Any, cls: type) -> bool:
+ # issubclass throws TypeError if annotation is not a valid type (eg Union)
+ try:
+ return issubclass(annotation, cls)
+ except TypeError:
+ return False
+
+
+def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]:
+ canonical_name, version = SyftObjectRegistry.get_canonical_name_version(obj)
+ serde_attributes = SyftObjectRegistry.get_serde_properties(canonical_name, version)
+ exclude_attrs = serde_attributes[4]
+
+ result: dict[str, Json] = {
+ JSON_CANONICAL_NAME_FIELD: canonical_name,
+ JSON_VERSION_FIELD: version,
+ }
+
+ all_exclude_attrs = set(exclude_attrs) | DEFAULT_EXCLUDE_ATTRS
+
+ for key, type_ in obj.model_fields.items():
+ if key in all_exclude_attrs:
+ continue
+ result[key] = serialize_json(getattr(obj, key), type_.annotation)
+
+ result = _add_searchable_and_unique_attrs(obj, result, raise_errors=False)
+
+ return result
+
+
+def get_property_return_type(obj: Any, attr_name: str) -> Any:
+ """
+ Get the return type annotation of a @property.
+ """
+ cls = type(obj)
+ attr = getattr(cls, attr_name, None)
+
+ if isinstance(attr, property):
+ return attr.fget.__annotations__.get("return", None)
+
+ return None
+
+
+def _add_searchable_and_unique_attrs(
+ obj: pydantic.BaseModel, obj_dict: dict[str, Json], raise_errors: bool = True
+) -> dict[str, Json]:
+ """
+ Add searchable attrs and unique attrs to the serialized object dict, if they are not already present.
+ Needed for adding non-field attributes (like @property)
+
+ Args:
+ obj (pydantic.BaseModel): Object to serialize.
+ obj_dict (dict[str, Json]): Serialized object dict. Should contain the object's fields.
+ raise_errors (bool, optional): Raise errors if an attribute cannot be accessed.
+ If False, the attribute will be skipped. Defaults to True.
+
+ Raises:
+ Exception: Any exception raised when accessing an attribute.
+
+ Returns:
+ dict[str, Json]: Serialized object dict including searchable attributes.
+ """
+ searchable_attrs: list[str] = getattr(obj, "__attr_searchable__", [])
+ unique_attrs: list[str] = getattr(obj, "__attr_unique__", [])
+
+ attrs_to_add = set(searchable_attrs) | set(unique_attrs)
+ for attr in attrs_to_add:
+ if attr not in obj_dict:
+ try:
+ value = getattr(obj, attr)
+ except Exception as e:
+ if raise_errors:
+ raise e
+ else:
+ continue
+ property_annotation = get_property_return_type(obj, attr)
+ obj_dict[attr] = serialize_json(
+ value, validate=False, annotation=property_annotation
+ )
+
+ return obj_dict
+
+
+def _deserialize_pydantic_from_json(
+ obj_dict: dict[str, Json],
+) -> pydantic.BaseModel:
+ try:
+ canonical_name = obj_dict[JSON_CANONICAL_NAME_FIELD]
+ version = obj_dict[JSON_VERSION_FIELD]
+ obj_type = SyftObjectRegistry.get_serde_class(canonical_name, version)
+
+ result = {}
+ for key, type_ in obj_type.model_fields.items():
+ if key not in obj_dict:
+ continue
+ result[key] = deserialize_json(obj_dict[key], type_.annotation)
+
+ return obj_type.model_validate(result)
+ except Exception as e:
+ print(f"Failed to deserialize Pydantic model: {e}")
+ print(json.dumps(obj_dict, indent=2))
+ raise ValueError(f"Failed to deserialize Pydantic model: {e}")
+
+
+def _is_serializable_iterable(annotation: Any) -> bool:
+ # we can only serialize typed iterables without Union/Any
+ # NOTE optional is allowed
+
+ # 1. check if it is an iterable
+ if get_origin(annotation) not in {list, tuple, set, frozenset}:
+ return False
+
+ # 2. check if iterable annotation is serializable
+ args = get_args(annotation)
+ if len(args) != 1:
+ return False
+
+ inner_type = _unwrap_type_annotation(args[0])
+ return inner_type in JSON_SERDE_REGISTRY or _annotation_issubclass(
+ inner_type, pydantic.BaseModel
+ )
+
+
+def _serialize_iterable_to_json(value: Any, annotation: Any) -> Json:
+ # No need to validate in recursive calls
+ return [serialize_json(v, validate=False) for v in value]
+
+
+def _deserialize_iterable_from_json(value: Json, annotation: Any) -> Any:
+ if not isinstance(value, list):
+ raise ValueError(f"Cannot deserialize {type(value)} to {annotation}")
+
+ annotation = _unwrap_type_annotation(annotation)
+
+ if not _is_serializable_iterable(annotation):
+ raise ValueError(f"Cannot deserialize {annotation} from JSON")
+
+ inner_type = _unwrap_type_annotation(get_args(annotation)[0])
+ return [deserialize_json(v, inner_type) for v in value]
+
+
+def _is_serializable_mapping(annotation: Any) -> bool:
+ """
+ Mapping is serializable if:
+ - it is a dict
+ - the key type is str
+ - the value type is serializable and not a Union
+ """
+ if get_origin(annotation) != dict:
+ return False
+
+ args = get_args(annotation)
+ if len(args) != 2:
+ return False
+
+ key_type, value_type = args
+ # JSON only allows string keys
+    if key_type is not str:
+ return False
+
+ # check if value type is serializable
+ value_type = _unwrap_type_annotation(value_type)
+ return value_type in JSON_SERDE_REGISTRY or _annotation_issubclass(
+ value_type, pydantic.BaseModel
+ )
+
+
+def _serialize_mapping_to_json(value: Any, annotation: Any) -> Json:
+ _, value_type = get_args(annotation)
+ # No need to validate in recursive calls
+ return {k: serialize_json(v, value_type, validate=False) for k, v in value.items()}
+
+
+def _deserialize_mapping_from_json(value: Json, annotation: Any) -> Any:
+ if not isinstance(value, dict):
+ raise ValueError(f"Cannot deserialize {type(value)} to {annotation}")
+
+ annotation = _unwrap_type_annotation(annotation)
+
+ if not _is_serializable_mapping(annotation):
+ raise ValueError(f"Cannot deserialize {annotation} from JSON")
+
+ _, value_type = get_args(annotation)
+ return {k: deserialize_json(v, value_type) for k, v in value.items()}
+
+
+def _serialize_to_json_bytes(obj: Any) -> str:
+ obj_bytes = sy.serialize(obj, to_bytes=True)
+ return base64.b64encode(obj_bytes).decode("utf-8")
+
+
+def _deserialize_from_json_bytes(obj: str) -> Any:
+ obj_bytes = base64.b64decode(obj)
+ return sy.deserialize(obj_bytes, from_bytes=True)
+
+
+def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> Json:
+ """
+ Serialize a value to a JSON-serializable object, using the schema defined by the
+ provided annotation.
+
+ Serialization is always done according to the annotation, as the same annotation
+ is used for deserialization. If the annotation is not provided or is ambiguous,
+ the JSON serialization will fall back to serializing bytes. Examples:
+ - int, `list[int]` are strictly typed
+ - `str | int`, `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes
+ - Optional types (like int | None) are serialized to the not-None type
+
+ The function chooses the appropriate serialization method in the following order:
+ 1. Method registered in `JSON_SERDE_REGISTRY` for the annotation type.
+ 2. Pydantic model serialization, including all `SyftObjects`.
+ 3. Iterable serialization, if the annotation is a strict iterable (e.g., `list[int]`).
+ 4. Mapping serialization, if the annotation is a strictly typed mapping with string keys.
+ 5. Serialize the object to bytes and encode it as base64.
+
+ Args:
+ value (Any): Value to serialize.
+ annotation (Any, optional): Type annotation for the value. Defaults to None.
+
+ Returns:
+ Json: JSON-serializable object.
+ """
+ if annotation is None:
+ annotation = type(value)
+
+ if value is None:
+ return None
+
+ # Remove None type from annotation if it is present.
+ annotation = _unwrap_type_annotation(annotation)
+
+ if annotation in JSON_SERDE_REGISTRY:
+ result = JSON_SERDE_REGISTRY[annotation].serialize(value)
+ elif _annotation_issubclass(annotation, pydantic.BaseModel):
+ result = _serialize_pydantic_to_json(value)
+ elif _annotation_issubclass(annotation, Enum):
+ result = value.name
+
+ # JSON recursive types
+ # only strictly annotated iterables and mappings are supported
+ # example: list[int] is supported, but not list[int | str]
+ elif _is_serializable_iterable(annotation):
+ result = _serialize_iterable_to_json(value, annotation)
+ elif _is_serializable_mapping(annotation):
+ result = _serialize_mapping_to_json(value, annotation)
+ else:
+ result = _serialize_to_json_bytes(value)
+
+ if validate:
+ _validate_json(result)
+
+ return result
+
+
+def deserialize_json(value: Json, annotation: Any = None) -> Any:
+ """Deserialize a JSON-serializable object to a value, using the schema defined by the
+ provided annotation. Inverse of `serialize_json`.
+
+ Args:
+ value (Json): JSON-serializable object.
+ annotation (Any): Type annotation for the value.
+
+ Returns:
+ Any: Deserialized value.
+ """
+ if (
+ isinstance(value, dict)
+ and JSON_CANONICAL_NAME_FIELD in value
+ and JSON_VERSION_FIELD in value
+ ):
+ return _deserialize_pydantic_from_json(value)
+
+ if value is None:
+ return None
+
+ # Remove None type from annotation if it is present.
+ if annotation is None:
+ raise ValueError("Annotation is required for deserialization")
+
+ annotation = _unwrap_type_annotation(annotation)
+
+ if annotation in JSON_SERDE_REGISTRY:
+ return JSON_SERDE_REGISTRY[annotation].deserialize(value)
+ elif _annotation_issubclass(annotation, pydantic.BaseModel):
+ return _deserialize_pydantic_from_json(value)
+ elif _annotation_issubclass(annotation, Enum):
+ return annotation[value]
+ elif isinstance(value, list):
+ return _deserialize_iterable_from_json(value, annotation)
+ elif isinstance(value, dict):
+ return _deserialize_mapping_from_json(value, annotation)
+ elif isinstance(value, str):
+ return _deserialize_from_json_bytes(value)
+ else:
+ raise ValueError(f"Cannot deserialize {value} to {annotation}")
+
+
+def is_json_primitive(value: Any) -> bool:
+ serialized = serialize_json(value, validate=False)
+ return isinstance(serialized, JsonPrimitive) # type: ignore
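
A short usage sketch for the new JSON serde, using only names defined in the module above; the UID round-trip relies on the registered UID serde and a strictly typed list annotation:

```python
from syft.serde.json_serde import deserialize_json, serialize_json
from syft.types.uid import UID

uid = UID()
payload = serialize_json(uid)              # UID -> its no_dash string
assert deserialize_json(payload, UID) == uid

uids = [UID(), UID()]
payload = serialize_json(uids, list[UID])  # strictly typed iterable -> list of strings
assert deserialize_json(payload, list[UID]) == uids

# Ambiguous annotations (e.g. list[str | int]) fall back to base64-encoded syft bytes.
```
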
diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py
index ed0379ae51b..d5694b4efe6 100644
--- a/packages/syft/src/syft/serde/recursive.py
+++ b/packages/syft/src/syft/serde/recursive.py
@@ -25,6 +25,7 @@
recursive_scheme = get_capnp_schema("recursive_serde.capnp").RecursiveSerde
SPOOLED_FILE_MAX_SIZE_SERDE = 50 * (1024**2) # 50MB
+DEFAULT_EXCLUDE_ATTRS: set[str] = {"syft_pre_hooks__", "syft_post_hooks__"}
def get_types(cls: type, keys: list[str] | None = None) -> list[type] | None:
@@ -192,9 +193,7 @@ def recursive_serde_register(
attribute_list.update(["value"])
exclude_attrs = [] if exclude_attrs is None else exclude_attrs
- attribute_list = (
- attribute_list - set(exclude_attrs) - {"syft_pre_hooks__", "syft_post_hooks__"}
- )
+ attribute_list = attribute_list - set(exclude_attrs) - DEFAULT_EXCLUDE_ATTRS
if inheritable_attrs and attribute_list and not is_pydantic:
# only set __syft_serializable__ for non-pydantic classes because
diff --git a/packages/syft/src/syft/serde/signature.py b/packages/syft/src/syft/serde/signature.py
index 23b0a556fca..0887d148367 100644
--- a/packages/syft/src/syft/serde/signature.py
+++ b/packages/syft/src/syft/serde/signature.py
@@ -86,6 +86,15 @@ def signature_remove_context(signature: Signature) -> Signature:
)
+def signature_remove(signature: Signature, args: list[str]) -> Signature:
+ params = dict(signature.parameters)
+ for arg in args:
+ params.pop(arg, None)
+ return Signature(
+ list(params.values()), return_annotation=signature.return_annotation
+ )
+
+
def get_str_signature_from_docstring(doc: str, callable_name: str) -> str | None:
if not doc or callable_name not in doc:
return None
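
signature_remove drops the named parameters from an inspect.Signature while keeping the return annotation, alongside the existing signature_remove_context helper. A small sketch of the intended behavior:

    from inspect import signature

    def query(context, uid: str, page: int = 0) -> list: ...

    stripped = signature_remove(signature(query), ["context", "page"])
    # stripped == <Signature (uid: str) -> list>
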
diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py
index 89cc5ffaab5..6fadf6261f9 100644
--- a/packages/syft/src/syft/serde/third_party.py
+++ b/packages/syft/src/syft/serde/third_party.py
@@ -18,7 +18,6 @@
import pyarrow.parquet as pq
import pydantic
from pydantic._internal._model_construction import ModelMetaclass
-from pymongo.collection import Collection
# relative
from ..types.dicttuple import DictTuple
@@ -58,11 +57,6 @@
# exceptions
recursive_serde_register(cls=TypeError, canonical_name="TypeError", version=1)
-# mongo collection
-recursive_serde_register_type(
- Collection, canonical_name="pymongo_collection", version=1
-)
-
def serialize_dataframe(df: DataFrame) -> bytes:
table = pa.Table.from_pandas(df)
diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py
index 1285e787428..c3d25b10440 100644
--- a/packages/syft/src/syft/server/server.py
+++ b/packages/syft/src/syft/server/server.py
@@ -18,6 +18,7 @@
from time import sleep
import traceback
from typing import Any
+from typing import TypeVar
from typing import cast
# third party
@@ -39,10 +40,6 @@
from ..protocol.data_protocol import get_data_protocol
from ..service.action.action_object import Action
from ..service.action.action_object import ActionObject
-from ..service.action.action_store import ActionStore
-from ..service.action.action_store import DictActionStore
-from ..service.action.action_store import MongoActionStore
-from ..service.action.action_store import SQLiteActionStore
from ..service.code.user_code_stash import UserCodeStash
from ..service.context import AuthedServiceContext
from ..service.context import ServerServiceContext
@@ -55,6 +52,7 @@
from ..service.metadata.server_metadata import ServerMetadata
from ..service.network.utils import PeerHealthCheckTask
from ..service.notifier.notifier_service import NotifierService
+from ..service.output.output_service import OutputStash
from ..service.queue.base_queue import AbstractMessageHandler
from ..service.queue.base_queue import QueueConsumer
from ..service.queue.base_queue import QueueProducer
@@ -75,12 +73,10 @@
from ..service.service import UserServiceConfigRegistry
from ..service.settings.settings import ServerSettings
from ..service.settings.settings import ServerSettingsUpdate
-from ..service.settings.settings_stash import SettingsStash
from ..service.user.user import User
from ..service.user.user import UserCreate
from ..service.user.user import UserView
from ..service.user.user_roles import ServiceRole
-from ..service.user.user_stash import UserStash
from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG
from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME
from ..service.worker.utils import create_default_image
@@ -92,12 +88,17 @@
from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig
from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig
from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit
-from ..store.dict_document_store import DictStoreConfig
+from ..store.db.db import DBConfig
+from ..store.db.db import DBManager
+from ..store.db.postgres import PostgresDBConfig
+from ..store.db.postgres import PostgresDBManager
+from ..store.db.sqlite import SQLiteDBConfig
+from ..store.db.sqlite import SQLiteDBManager
+from ..store.db.stash import ObjectStash
from ..store.document_store import StoreConfig
from ..store.document_store_errors import NotFoundException
from ..store.document_store_errors import StashException
from ..store.linked_obj import LinkedObject
-from ..store.mongo_document_store import MongoStoreConfig
from ..store.sqlite_document_store import SQLiteStoreClientConfig
from ..store.sqlite_document_store import SQLiteStoreConfig
from ..types.datetime import DATETIME_FORMAT
@@ -128,6 +129,8 @@
logger = logging.getLogger(__name__)
+SyftT = TypeVar("SyftT", bound=SyftObject)
+
# if user code needs to be serded and its not available we can call this to refresh
# the code for a specific server UID and thread
CODE_RELOADER: dict[int, Callable] = {}
@@ -307,6 +310,7 @@ def __init__(
signing_key: SyftSigningKey | SigningKey | None = None,
action_store_config: StoreConfig | None = None,
document_store_config: StoreConfig | None = None,
+ db_config: DBConfig | None = None,
root_email: str | None = default_root_email,
root_username: str | None = default_root_username,
root_password: str | None = default_root_password,
@@ -336,6 +340,7 @@ def __init__(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
consumer_type: ConsumerType | None = None,
+ db_url: str | None = None,
):
# 🟡 TODO 22: change our ENV variable format and default init args to make this
# less horrible or add some convenience functions
@@ -351,6 +356,7 @@ def __init__(
self.server_side_type = ServerSideType(server_side_type)
self.client_cache: dict = {}
self.peer_client_cache: dict = {}
+ self._settings = None
if isinstance(server_type, str):
server_type = ServerType(server_type)
@@ -396,27 +402,33 @@ def __init__(
if reset:
self.remove_temp_dir()
- use_sqlite = local_db or (processes > 0 and not is_subprocess)
document_store_config = document_store_config or self.get_default_store(
- use_sqlite=use_sqlite,
store_type="Document Store",
)
action_store_config = action_store_config or self.get_default_store(
- use_sqlite=use_sqlite,
store_type="Action Store",
)
- self.init_stores(
- action_store_config=action_store_config,
- document_store_config=document_store_config,
- )
+ db_config = DBConfig.from_connection_string(db_url) if db_url else db_config
+
+ if db_config is None:
+ db_config = SQLiteDBConfig(
+ filename=f"{self.id}_json.db",
+ path=self.get_temp_dir("db"),
+ )
+
+ self.db_config = db_config
+
+ self.db = self.init_stores(db_config=self.db_config)
# construct services only after init stores
self.services: ServiceRegistry = ServiceRegistry.for_server(self)
+ self.db.init_tables(reset=reset)
+ self.action_store = self.services.action.stash
- create_admin_new( # nosec B106
+ create_root_admin_if_not_exists(
name=root_username,
email=root_email,
- password=root_password,
+ password=root_password, # nosec
server=self,
)
@@ -520,21 +532,19 @@ def runs_in_docker(self) -> bool:
and any("docker" in line for line in open(path))
)
- def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig:
- if use_sqlite:
- path = self.get_temp_dir("db")
- file_name: str = f"{self.id}.sqlite"
- if self.dev_mode:
- # leave this until the logger shows this in the notebook
- print(f"{store_type}'s SQLite DB path: {path/file_name}")
- logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}")
- return SQLiteStoreConfig(
- client_config=SQLiteStoreClientConfig(
- filename=file_name,
- path=path,
- )
+ def get_default_store(self, store_type: str) -> StoreConfig:
+ path = self.get_temp_dir("db")
+ file_name: str = f"{self.id}.sqlite"
+ # if self.dev_mode:
+ # leave this until the logger shows this in the notebook
+ # print(f"{store_type}'s SQLite DB path: {path/file_name}")
+ # logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}")
+ return SQLiteStoreConfig(
+ client_config=SQLiteStoreClientConfig(
+ filename=file_name,
+ path=path,
)
- return DictStoreConfig()
+ )
def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None:
if config is None:
@@ -648,6 +658,7 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None:
worker_stash=self.worker_stash,
)
producer.run()
+
address = producer.address
else:
port = queue_config.client_config.queue_port
@@ -754,6 +765,8 @@ def named(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
consumer_type: ConsumerType | None = None,
+ db_url: str | None = None,
+ db_config: DBConfig | None = None,
) -> Server:
uid = get_named_server_uid(name)
name_hash = hashlib.sha256(name.encode("utf8")).digest()
@@ -785,6 +798,8 @@ def named(
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
consumer_type=consumer_type,
+ db_url=db_url,
+ db_config=db_config,
)
def is_root(self, credentials: SyftVerifyKey) -> bool:
@@ -906,59 +921,36 @@ def reload_user_code() -> None:
if ti is not None:
CODE_RELOADER[ti] = reload_user_code
- def init_stores(
- self,
- document_store_config: StoreConfig,
- action_store_config: StoreConfig,
- ) -> None:
- # We add the python id of the current server in order
- # to create one connection per Server object in MongoClientCache
- # so that we avoid closing the connection from a
- # different thread through the garbage collection
- if isinstance(document_store_config, MongoStoreConfig):
- document_store_config.client_config.server_obj_python_id = id(self)
-
- self.document_store_config = document_store_config
- self.document_store = document_store_config.store_type(
- server_uid=self.id,
- root_verify_key=self.verify_key,
- store_config=document_store_config,
- )
-
- if isinstance(action_store_config, SQLiteStoreConfig):
- self.action_store: ActionStore = SQLiteActionStore(
+ def init_stores(self, db_config: DBConfig) -> DBManager:
+ if isinstance(db_config, SQLiteDBConfig):
+ db = SQLiteDBManager(
+ config=db_config,
server_uid=self.id,
- store_config=action_store_config,
root_verify_key=self.verify_key,
- document_store=self.document_store,
)
- elif isinstance(action_store_config, MongoStoreConfig):
- # We add the python id of the current server in order
- # to create one connection per Server object in MongoClientCache
- # so that we avoid closing the connection from a
- # different thread through the garbage collection
- action_store_config.client_config.server_obj_python_id = id(self)
-
- self.action_store = MongoActionStore(
+ elif isinstance(db_config, PostgresDBConfig):
+ db = PostgresDBManager( # type: ignore
+ config=db_config,
server_uid=self.id,
root_verify_key=self.verify_key,
- store_config=action_store_config,
- document_store=self.document_store,
)
else:
- self.action_store = DictActionStore(
- server_uid=self.id,
- root_verify_key=self.verify_key,
- document_store=self.document_store,
- )
+ raise SyftException(public_message=f"Unsupported DB config: {db_config}")
+
+ self.queue_stash = QueueStash(store=db)
- self.action_store_config = action_store_config
- self.queue_stash = QueueStash(store=self.document_store)
+ print(f"Using {db_config.__class__.__name__} and {db_config.connection_string}")
+
+ return db
@property
def job_stash(self) -> JobStash:
return self.services.job.stash
+ @property
+ def output_stash(self) -> OutputStash:
+ return self.services.output.stash
+
@property
def worker_stash(self) -> WorkerStash:
return self.services.worker.stash
@@ -979,6 +971,12 @@ def get_service_method(self, path_or_func: str | Callable) -> Callable:
def get_service(self, path_or_func: str | Callable) -> AbstractService:
return self.services.get_service(path_or_func)
+ @as_result(ValueError)
+ def get_stash(self, object_type: SyftT) -> ObjectStash[SyftT]:
+ if object_type not in self.services.stashes:
+ raise ValueError(f"Stash for {object_type} not found.")
+ return self.services.stashes[object_type]
+
def _get_service_method_from_path(self, path: str) -> Callable:
path_list = path.split(".")
method_name = path_list.pop()
@@ -1016,10 +1014,12 @@ def update_self(self, settings: ServerSettings) -> None:
# it should be removed once the settings are refactored and the inconsistencies between
# settings and services are resolved.
def get_settings(self) -> ServerSettings | None:
+ if self._settings:
+ return self._settings # type: ignore
if self.signing_key is None:
raise ValueError(f"{self} has no signing key")
- settings_stash = SettingsStash(store=self.document_store)
+ settings_stash = self.services.settings.stash
try:
settings = settings_stash.get_all(self.signing_key.verify_key).unwrap()
@@ -1027,6 +1027,7 @@ def get_settings(self) -> ServerSettings | None:
if len(settings) > 0:
setting = settings[0]
self.update_self(setting)
+ self._settings = setting
return setting
else:
return None
@@ -1039,7 +1040,7 @@ def settings(self) -> ServerSettings:
if self.signing_key is None:
raise ValueError(f"{self} has no signing key")
- settings_stash = SettingsStash(store=self.document_store)
+ settings_stash = self.services.settings.stash
error_msg = f"Cannot get server settings for '{self.name}'"
all_settings = settings_stash.get_all(self.signing_key.verify_key).unwrap(
@@ -1479,7 +1480,9 @@ def add_queueitem_to_queue(
result_obj.syft_server_location = self.id
result_obj.syft_client_verify_key = credentials
- if not self.services.action.store.exists(uid=action.result_id):
+ if not self.services.action.stash.exists(
+ credentials=credentials, uid=action.result_id
+ ):
self.services.action.set_result_to_store(
result_action_object=result_obj,
context=context,
@@ -1687,7 +1690,7 @@ def get_unauthed_context(
@as_result(SyftException, StashException)
def create_initial_settings(self, admin_email: str) -> ServerSettings:
- settings_stash = SettingsStash(store=self.document_store)
+ settings_stash = self.services.settings.stash
if self.signing_key is None:
logger.debug("create_initial_settings failed as there is no signing key")
@@ -1741,17 +1744,32 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings:
).unwrap()
-def create_admin_new(
+def create_root_admin_if_not_exists(
name: str,
email: str,
password: str,
- server: AbstractServer,
+ server: Server,
) -> User | None:
- user_stash = UserStash(store=server.document_store)
+ """
+    If no root admin exists:
+    - every existence check on the user stash fails, because the admin role cannot be looked up yet
+    - result: a new root admin is always created
+
+    If a root admin already exists with a different email:
+    - cause: the DEFAULT_USER_EMAIL env variable is set to a different email than the root admin in the db
+    - verify_key_exists returns True
+    - result: no new admin is created, as the server already has a root admin
+ """
+ user_stash = server.services.user.stash
+
+ email_exists = user_stash.email_exists(email=email).unwrap()
+ if email_exists:
+ logger.debug("Admin not created, a user with this email already exists")
+ return None
- user_exists = user_stash.email_exists(email=email).unwrap()
- if user_exists:
- logger.debug("Admin not created, admin already exists")
+ verify_key_exists = user_stash.verify_key_exists(server.verify_key).unwrap()
+ if verify_key_exists:
+ logger.debug("Admin not created, this server already has a root admin")
return None
create_user = UserCreate(
@@ -1766,12 +1784,12 @@ def create_admin_new(
# 🟡 TODO: change later but for now this gives the main user super user automatically
user = create_user.to(User)
user.signing_key = server.signing_key
- user.verify_key = user.signing_key.verify_key
+ user.verify_key = server.verify_key
new_user = user_stash.set(
- credentials=server.signing_key.verify_key,
+ credentials=server.verify_key,
obj=user,
- ignore_duplicates=True,
+ ignore_duplicates=False,
).unwrap()
logger.debug(f"Created admin {new_user.email}")
diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py
index d7c3555f10c..dfb7f331972 100644
--- a/packages/syft/src/syft/server/service_registry.py
+++ b/packages/syft/src/syft/server/service_registry.py
@@ -3,15 +3,15 @@
from dataclasses import dataclass
from dataclasses import field
import typing
-from typing import Any
from typing import TYPE_CHECKING
+from typing import TypeVar
# relative
from ..serde.serializable import serializable
from ..service.action.action_service import ActionService
-from ..service.action.action_store import ActionStore
from ..service.api.api_service import APIService
from ..service.attestation.attestation_service import AttestationService
+from ..service.blob_storage.remote_profile import RemoteProfileService
from ..service.blob_storage.service import BlobStorageService
from ..service.code.status_service import UserCodeStatusService
from ..service.code.user_code_service import UserCodeService
@@ -40,12 +40,17 @@
from ..service.worker.worker_image_service import SyftWorkerImageService
from ..service.worker.worker_pool_service import SyftWorkerPoolService
from ..service.worker.worker_service import WorkerService
+from ..store.db.stash import ObjectStash
+from ..types.syft_object import SyftObject
if TYPE_CHECKING:
# relative
from .server import Server
+StashT = TypeVar("StashT", bound=SyftObject)
+
+
@serializable(canonical_name="ServiceRegistry", version=1)
@dataclass
class ServiceRegistry:
@@ -79,11 +84,13 @@ class ServiceRegistry:
sync: SyncService
output: OutputService
user_code_status: UserCodeStatusService
+ remote_profile: RemoteProfileService
services: list[AbstractService] = field(default_factory=list, init=False)
service_path_map: dict[str, AbstractService] = field(
default_factory=dict, init=False
)
+ stashes: dict[StashT, ObjectStash[StashT]] = field(default_factory=dict, init=False)
@classmethod
def for_server(cls, server: "Server") -> "ServiceRegistry":
@@ -95,6 +102,11 @@ def __post_init__(self) -> None:
self.services.append(service)
self.service_path_map[service_cls.__name__.lower()] = service
+        # TODO ActionService now has the same stash, but its interface is still different. Fix this.
+ if hasattr(service, "stash") and not issubclass(service_cls, ActionService):
+ stash: ObjectStash = service.stash
+ self.stashes[stash.object_type] = stash
+
@classmethod
def get_service_classes(
cls,
@@ -109,13 +121,7 @@ def get_service_classes(
def _construct_services(cls, server: "Server") -> dict[str, AbstractService]:
service_dict = {}
for field_name, service_cls in cls.get_service_classes().items():
- svc_kwargs: dict[str, Any] = {}
- if issubclass(service_cls.store_type, ActionStore):
- svc_kwargs["store"] = server.action_store
- else:
- svc_kwargs["store"] = server.document_store
-
- service = service_cls(**svc_kwargs)
+ service = service_cls(store=server.db) # type: ignore
service_dict[field_name] = service
return service_dict
@@ -133,3 +139,6 @@ def _get_service_from_path(self, path: str) -> AbstractService:
return self.service_path_map[service_name.lower()]
except KeyError:
raise ValueError(f"Service {path} not found.")
+
+ def __iter__(self) -> typing.Iterator[AbstractService]:
+ return iter(self.services)
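
Because __post_init__ now indexes every service stash by its object_type, a stash can be looked up generically instead of through a service-specific attribute. A sketch of the lookup path added to Server above (the object type and a running server instance are illustrative):

    stash = server.services.stashes[BlobStorageEntry]           # ObjectStash[BlobStorageEntry]
    same_stash = server.get_stash(BlobStorageEntry).unwrap()    # as_result-wrapped variant
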
diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py
index 80f15a6d5ba..e1982953a32 100644
--- a/packages/syft/src/syft/server/uvicorn.py
+++ b/packages/syft/src/syft/server/uvicorn.py
@@ -1,6 +1,8 @@
# stdlib
from collections.abc import Callable
from contextlib import asynccontextmanager
+import json
+import logging
import multiprocessing
import multiprocessing.synchronize
import os
@@ -25,6 +27,7 @@
from ..abstract_server import ServerSideType
from ..client.client import API_PATH
from ..deployment_type import DeploymentType
+from ..store.db.db import DBConfig
from ..util.autoreload import enable_autoreload
from ..util.constants import DEFAULT_TIMEOUT
from ..util.telemetry import TRACING_ENABLED
@@ -46,6 +49,9 @@
WAIT_TIME_SECONDS = 20
+logger = logging.getLogger("uvicorn")
+
+
class AppSettings(BaseSettings):
name: str
server_type: ServerType = ServerType.DATASITE
@@ -61,6 +67,8 @@ class AppSettings(BaseSettings):
n_consumers: int = 0
association_request_auto_approval: bool = False
background_tasks: bool = False
+ db_config: DBConfig | None = None
+ db_url: str | None = None
model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")
@@ -91,6 +99,10 @@ def app_factory() -> FastAPI:
worker_class = worker_classes[settings.server_type]
kwargs = settings.model_dump()
+
+ logger.info(
+ f"Starting server with settings: {kwargs} and worker class: {worker_class}"
+ )
if settings.dev_mode:
print(
f"WARN: private key is based on server name: {settings.name} in dev_mode. "
@@ -175,6 +187,8 @@ def run_uvicorn(
env_prefix = AppSettings.model_config.get("env_prefix", "")
for key, value in kwargs.items():
key_with_prefix = f"{env_prefix}{key.upper()}"
+ if isinstance(value, dict):
+ value = json.dumps(value)
os.environ[key_with_prefix] = str(value)
# The `serve_server` function calls `run_uvicorn` in a separate process using `multiprocessing.Process`.
@@ -220,6 +234,7 @@ def serve_server(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
debug: bool = False,
+ db_url: str | None = None,
) -> tuple[Callable, Callable]:
starting_uvicorn_event = multiprocessing.Event()
@@ -249,6 +264,7 @@ def serve_server(
"debug": debug,
"starting_uvicorn_event": starting_uvicorn_event,
"deployment_type": deployment_type,
+ "db_url": db_url,
},
)
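
The json.dumps branch in run_uvicorn exists so that dict-valued settings such as db_config survive the trip through environment variables; AppSettings, with its SYFT_ prefix, reads them back in the worker process. A rough sketch of that round trip, with illustrative values:

    # in the parent process (run_uvicorn): dict-valued kwargs are JSON-encoded before export
    os.environ["SYFT_DB_URL"] = "postgresql://syft:syft@localhost:5432/syft"

    # in the worker process (app_factory): AppSettings picks the SYFT_-prefixed vars back up
    settings = AppSettings(name="test-datasite")
    assert settings.db_url == "postgresql://syft:syft@localhost:5432/syft"
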
diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py
index 57a69d2a4eb..3e10cc7d5fa 100644
--- a/packages/syft/src/syft/server/worker_settings.py
+++ b/packages/syft/src/syft/server/worker_settings.py
@@ -1,6 +1,9 @@
# future
from __future__ import annotations
+# stdlib
+from collections.abc import Callable
+
# third party
from typing_extensions import Self
@@ -13,16 +16,21 @@
from ..server.credentials import SyftSigningKey
from ..service.queue.base_queue import QueueConfig
from ..store.blob_storage import BlobStorageConfig
+from ..store.db.db import DBConfig
from ..store.document_store import StoreConfig
+from ..types.syft_migration import migrate
from ..types.syft_object import SYFT_OBJECT_VERSION_1
+from ..types.syft_object import SYFT_OBJECT_VERSION_2
from ..types.syft_object import SyftObject
+from ..types.transforms import TransformContext
+from ..types.transforms import drop
from ..types.uid import UID
@serializable()
class WorkerSettings(SyftObject):
__canonical_name__ = "WorkerSettings"
- __version__ = SYFT_OBJECT_VERSION_1
+ __version__ = SYFT_OBJECT_VERSION_2
id: UID
name: str
@@ -30,28 +38,58 @@ class WorkerSettings(SyftObject):
server_side_type: ServerSideType
deployment_type: DeploymentType = DeploymentType.REMOTE
signing_key: SyftSigningKey
- document_store_config: StoreConfig
- action_store_config: StoreConfig
+ db_config: DBConfig
blob_store_config: BlobStorageConfig | None = None
queue_config: QueueConfig | None = None
log_level: int | None = None
@classmethod
def from_server(cls, server: AbstractServer) -> Self:
- if server.server_side_type:
- server_side_type: str = server.server_side_type.value
- else:
- server_side_type = ServerSideType.HIGH_SIDE
+ server_side_type = server.server_side_type or ServerSideType.HIGH_SIDE
return cls(
id=server.id,
name=server.name,
server_type=server.server_type,
signing_key=server.signing_key,
- document_store_config=server.document_store_config,
- action_store_config=server.action_store_config,
+ db_config=server.db_config,
server_side_type=server_side_type,
blob_store_config=server.blob_store_config,
queue_config=server.queue_config,
log_level=server.log_level,
deployment_type=server.deployment_type,
)
+
+
+@serializable()
+class WorkerSettingsV1(SyftObject):
+ __canonical_name__ = "WorkerSettings"
+ __version__ = SYFT_OBJECT_VERSION_1
+
+ id: UID
+ name: str
+ server_type: ServerType
+ server_side_type: ServerSideType
+ deployment_type: DeploymentType = DeploymentType.REMOTE
+ signing_key: SyftSigningKey
+ document_store_config: StoreConfig
+ action_store_config: StoreConfig
+ blob_store_config: BlobStorageConfig | None = None
+ queue_config: QueueConfig | None = None
+ log_level: int | None = None
+
+
+def set_db_config(context: TransformContext) -> TransformContext:
+ if context.output:
+ context.output["db_config"] = (
+ context.server.db_config if context.server is not None else DBConfig()
+ )
+ return context
+
+
+@migrate(WorkerSettingsV1, WorkerSettings)
+def migrate_workersettings_v1_to_v2() -> list[Callable]:
+ return [
+ drop("document_store_config"),
+ drop("action_store_config"),
+ set_db_config,
+ ]
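
The version bump plus the @migrate transform means a stored V1 settings object can be upgraded in place: both store configs are dropped and db_config is filled from the server, or a default DBConfig when no server is attached to the context. A sketch, assuming the generic migrate_to helper available on SyftObject:

    # settings_v1: a WorkerSettingsV1 loaded from an existing snapshot
    settings_v2 = settings_v1.migrate_to(WorkerSettings.__version__)
    # -> WorkerSettings with db_config set and the *_store_config fields gone
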
diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py
index d7e566b733a..275767665c8 100644
--- a/packages/syft/src/syft/service/action/action_object.py
+++ b/packages/syft/src/syft/service/action/action_object.py
@@ -254,8 +254,8 @@ class ActionObjectPointer:
"created_date", # syft
"updated_date", # syft
"deleted_date", # syft
- "to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs
"__attr_searchable__", # syft
+ "__attr_unique__", # syft
"__canonical_name__", # syft
"__version__", # syft
"__args__", # pydantic
diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py
index 03992eeab07..ab6f9b7ce9a 100644
--- a/packages/syft/src/syft/service/action/action_permissions.py
+++ b/packages/syft/src/syft/service/action/action_permissions.py
@@ -17,12 +17,29 @@ class ActionPermission(Enum):
ALL_WRITE = 32
EXECUTE = 64
ALL_EXECUTE = 128
+ ALL_OWNER = 256
+
+ @property
+ def as_compound(self) -> "ActionPermission":
+ if self in COMPOUND_ACTION_PERMISSION:
+ return self
+ elif self == ActionPermission.READ:
+ return ActionPermission.ALL_READ
+ elif self == ActionPermission.WRITE:
+ return ActionPermission.ALL_WRITE
+ elif self == ActionPermission.EXECUTE:
+ return ActionPermission.ALL_EXECUTE
+ elif self == ActionPermission.OWNER:
+ return ActionPermission.ALL_OWNER
+ else:
+ raise Exception(f"Invalid compound permission {self}")
COMPOUND_ACTION_PERMISSION = {
ActionPermission.ALL_READ,
ActionPermission.ALL_WRITE,
ActionPermission.ALL_EXECUTE,
+ ActionPermission.ALL_OWNER,
}
@@ -64,6 +81,10 @@ def permission_string(self) -> str:
return f"{self.credentials.verify}_{self.permission.name}"
return f"{self.permission.name}"
+ @property
+ def compound_permission_string(self) -> str:
+ return self.permission.as_compound.name
+
def _coll_repr_(self) -> dict[str, Any]:
return {
"uid": str(self.uid),
@@ -122,3 +143,7 @@ def _coll_repr_(self) -> dict[str, Any]:
"uid": str(self.uid),
"server_uid": str(self.server_uid),
}
+
+ @property
+ def permission_string(self) -> str:
+ return str(self.server_uid)
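
as_compound maps each per-user permission onto its ALL_* counterpart, which is what compound_permission_string records for blanket grants. For illustration (obj_uid and user_key are placeholders):

    perm = ActionObjectREAD(uid=obj_uid, credentials=user_key)
    perm.permission_string            # "<user verify key>_READ"
    perm.compound_permission_string   # "ALL_READ"
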
diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py
index 94935631f01..bd871a18164 100644
--- a/packages/syft/src/syft/service/action/action_service.py
+++ b/packages/syft/src/syft/service/action/action_service.py
@@ -9,6 +9,7 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.datetime import DateTime
@@ -43,8 +44,8 @@
from .action_permissions import ActionObjectPermission
from .action_permissions import ActionObjectREAD
from .action_permissions import ActionPermission
-from .action_store import ActionStore
-from .action_store import KeyValueActionStore
+from .action_permissions import StoragePermission
+from .action_store import ActionObjectStash
from .action_types import action_type_for_type
from .numpy import NumpyArrayObject
from .pandas import PandasDataFrameObject # noqa: F401
@@ -55,10 +56,10 @@
@serializable(canonical_name="ActionService", version=1)
class ActionService(AbstractService):
- store_type = ActionStore
+ stash: ActionObjectStash
- def __init__(self, store: KeyValueActionStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
+ self.stash = ActionObjectStash(store)
@service_method(path="action.np_array", name="np_array")
def np_array(self, context: AuthedServiceContext, data: Any) -> Any:
@@ -178,7 +179,7 @@ def _set(
or has_result_read_permission
)
- self.store.set(
+ self.stash.set_or_update(
uid=action_object.id,
credentials=context.credentials,
syft_object=action_object,
@@ -237,7 +238,7 @@ def resolve_links(
) -> ActionObject:
"""Get an object from the action store"""
# If user has permission to get the object / object exists
- result = self.store.get(uid=uid, credentials=context.credentials).unwrap()
+ result = self.stash.get(uid=uid, credentials=context.credentials).unwrap()
# If it's not a leaf
if result.is_link:
@@ -271,7 +272,7 @@ def _get(
resolve_nested: bool = True,
) -> ActionObject | TwinObject:
"""Get an object from the action store"""
- obj = self.store.get(
+ obj = self.stash.get(
uid=uid, credentials=context.credentials, has_permission=has_permission
).unwrap()
@@ -314,7 +315,7 @@ def get_pointer(
self, context: AuthedServiceContext, uid: UID
) -> ActionObjectPointer:
"""Get a pointer from the action store"""
- obj = self.store.get_pointer(
+ obj = self.stash.get_pointer(
uid=uid, credentials=context.credentials, server_uid=context.server.id
).unwrap()
@@ -328,7 +329,7 @@ def get_pointer(
@service_method(path="action.get_mock", name="get_mock", roles=GUEST_ROLE_LEVEL)
def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject:
"""Get a pointer from the action store"""
- return self.store.get_mock(uid=uid).unwrap()
+ return self.stash.get_mock(credentials=context.credentials, uid=uid).unwrap()
@service_method(
path="action.has_storage_permission",
@@ -336,10 +337,12 @@ def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject:
roles=GUEST_ROLE_LEVEL,
)
def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> bool:
- return self.store.has_storage_permission(uid)
+ return self.stash.has_storage_permission(
+ StoragePermission(uid=uid, server_uid=context.server.id)
+ )
def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool:
- return self.store.has_permissions(
+ return self.stash.has_permissions(
[ActionObjectREAD(uid=uid, credentials=context.credentials)]
)
@@ -558,7 +561,7 @@ def blob_permission(
if len(output_readers) > 0:
store_permissions = [store_permission(x) for x in output_readers]
- self.store.add_permissions(store_permissions)
+ self.stash.add_permissions(store_permissions)
if result_blob_id is not None:
blob_permissions = [blob_permission(x) for x in output_readers]
@@ -880,12 +883,12 @@ def has_read_permission_for_action_result(
ActionObjectREAD(uid=_id, credentials=context.credentials)
for _id in action_obj_ids
]
- return self.store.has_permissions(permissions)
+ return self.stash.has_permissions(permissions)
@service_method(path="action.exists", name="exists", roles=GUEST_ROLE_LEVEL)
def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool:
"""Checks if the given object id exists in the Action Store"""
- return self.store.exists(obj_id)
+ return self.stash.exists(context.credentials, obj_id)
@service_method(
path="action.delete",
@@ -896,7 +899,7 @@ def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool:
def delete(
self, context: AuthedServiceContext, uid: UID, soft_delete: bool = False
) -> SyftSuccess:
- obj = self.store.get(uid=uid, credentials=context.credentials).unwrap()
+ obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap()
return_msg = []
@@ -952,7 +955,7 @@ def _delete_from_action_store(
soft_delete: bool = False,
) -> SyftSuccess:
if soft_delete:
- obj = self.store.get(uid=uid, credentials=context.credentials).unwrap()
+ obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap()
if isinstance(obj, TwinObject):
self._soft_delete_action_obj(
@@ -964,7 +967,7 @@ def _delete_from_action_store(
if isinstance(obj, ActionObject):
self._soft_delete_action_obj(context=context, action_obj=obj).unwrap()
else:
- self.store.delete(credentials=context.credentials, uid=uid).unwrap()
+ self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap()
return SyftSuccess(message=f"Action object with uid '{uid}' deleted.")
diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py
index 0228165edb9..e6597d19d25 100644
--- a/packages/syft/src/syft/service/action/action_store.py
+++ b/packages/syft/src/syft/service/action/action_store.py
@@ -1,128 +1,53 @@
# future
from __future__ import annotations
-# stdlib
-import threading
-
# relative
from ...serde.serializable import serializable
-from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
-from ...store.dict_document_store import DictStoreConfig
-from ...store.document_store import BasePartitionSettings
-from ...store.document_store import DocumentStore
-from ...store.document_store import StoreConfig
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
-from ...store.document_store_errors import ObjectCRUDPermissionException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.syft_object import SyftObject
from ...types.twin_object import TwinObject
-from ...types.uid import LineageID
from ...types.uid import UID
+from .action_object import ActionObject
from .action_object import is_action_data_empty
from .action_permissions import ActionObjectEXECUTE
-from .action_permissions import ActionObjectOWNER
from .action_permissions import ActionObjectPermission
from .action_permissions import ActionObjectREAD
from .action_permissions import ActionObjectWRITE
-from .action_permissions import ActionPermission
from .action_permissions import StoragePermission
-lock = threading.RLock()
-
-
-class ActionStore:
- pass
-
-
-@serializable(canonical_name="KeyValueActionStore", version=1)
-class KeyValueActionStore(ActionStore):
- """Generic Key-Value Action store.
-
- Parameters:
- store_config: StoreConfig
- Backend specific configuration, including connection configuration, database name, or client class type.
- root_verify_key: Optional[SyftVerifyKey]
- Signature verification key, used for checking access permissions.
- """
-
- def __init__(
- self,
- server_uid: UID,
- store_config: StoreConfig,
- root_verify_key: SyftVerifyKey | None = None,
- document_store: DocumentStore | None = None,
- ) -> None:
- self.server_uid = server_uid
- self.store_config = store_config
- self.settings = BasePartitionSettings(name="Action")
- self.data = self.store_config.backing_store(
- "data", self.settings, self.store_config
- )
- self.permissions = self.store_config.backing_store(
- "permissions", self.settings, self.store_config, ddtype=set
- )
- self.storage_permissions = self.store_config.backing_store(
- "storage_permissions", self.settings, self.store_config, ddtype=set
- )
- if root_verify_key is None:
- root_verify_key = SyftSigningKey.generate().verify_key
- self.root_verify_key = root_verify_key
-
- self.__user_stash = None
- if document_store is not None:
- # relative
- from ...service.user.user_stash import UserStash
-
- self.__user_stash = UserStash(store=document_store)
+@serializable(canonical_name="ActionObjectSQLStore", version=1)
+class ActionObjectStash(ObjectStash[ActionObject]):
+ # We are storing ActionObject, Action, TwinObject
+ allow_any_type = True
@as_result(NotFoundException, SyftException)
def get(
self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False
- ) -> SyftObject:
+ ) -> ActionObject:
uid = uid.id # We only need the UID from LineageID or UID
-
- # if you get something you need READ permission
- read_permission = ActionObjectREAD(uid=uid, credentials=credentials)
-
- if not has_permission and not self.has_permission(read_permission):
- raise SyftException(public_message=f"Permission: {read_permission} denied")
-
- # TODO: Remove try/except?
- try:
- if isinstance(uid, LineageID):
- syft_object = self.data[uid.id]
- elif isinstance(uid, UID):
- syft_object = self.data[uid]
- else:
- raise SyftException(
- public_message=f"Unrecognized UID type: {type(uid)}"
- )
- return syft_object
- except Exception as e:
- raise NotFoundException.from_exception(
- e, public_message=f"Object {uid} not found"
- )
+ # TODO remove and use get_by_uid instead
+ return self.get_by_uid(
+ credentials=credentials,
+ uid=uid,
+ has_permission=has_permission,
+ ).unwrap()
@as_result(NotFoundException, SyftException)
- def get_mock(self, uid: UID) -> SyftObject:
+ def get_mock(self, credentials: SyftVerifyKey, uid: UID) -> SyftObject:
uid = uid.id # We only need the UID from LineageID or UID
- try:
- syft_object = self.data[uid]
-
- if isinstance(syft_object, TwinObject) and not is_action_data_empty(
- syft_object.mock
- ):
- return syft_object.mock
- raise NotFoundException(public_message=f"No mock found for object {uid}")
- except Exception as e:
- raise NotFoundException.from_exception(
- e, public_message=f"Object {uid} not found"
- )
+ obj = self.get_by_uid(
+ credentials=credentials, uid=uid, has_permission=True
+ ).unwrap()
+ if isinstance(obj, TwinObject) and not is_action_data_empty(obj.mock):
+ return obj.mock
+ raise NotFoundException(public_message=f"No mock found for object {uid}")
@as_result(NotFoundException, SyftException)
def get_pointer(
@@ -133,34 +58,27 @@ def get_pointer(
) -> SyftObject:
uid = uid.id # We only need the UID from LineageID or UID
- try:
- if uid not in self.data:
- raise SyftException(public_message="Permission denied")
-
- obj = self.data[uid]
- read_permission = ActionObjectREAD(uid=uid, credentials=credentials)
-
- # if you have permission you can have private data
- if self.has_permission(read_permission):
- if isinstance(obj, TwinObject):
- return obj.private.syft_point_to(server_uid)
- return obj.syft_point_to(server_uid)
+ obj = self.get_by_uid(
+ credentials=credentials, uid=uid, has_permission=True
+ ).unwrap()
+ has_permissions = self.has_permission(
+ ActionObjectREAD(uid=uid, credentials=credentials)
+ )
- # if its a twin with a mock anyone can have this
+ if has_permissions:
if isinstance(obj, TwinObject):
- return obj.mock.syft_point_to(server_uid)
+ return obj.private.syft_point_to(server_uid)
+ return obj.syft_point_to(server_uid) # type: ignore
- # finally worst case you get ActionDataEmpty so you can still trace
- return obj.as_empty().syft_point_to(server_uid)
- # TODO: Check if this can be removed
- except Exception as e:
- raise SyftException(public_message=str(e))
+ # if its a twin with a mock anyone can have this
+ if isinstance(obj, TwinObject):
+ return obj.mock.syft_point_to(server_uid)
- def exists(self, uid: UID) -> bool:
- return uid.id in self.data # We only need the UID from LineageID or UID
+ # finally worst case you get ActionDataEmpty so you can still trace
+ return obj.as_empty().syft_point_to(server_uid) # type: ignore
@as_result(SyftException, StashException)
- def set(
+ def set_or_update( # type: ignore
self,
uid: UID,
credentials: SyftVerifyKey,
@@ -170,284 +88,47 @@ def set(
) -> UID:
uid = uid.id # We only need the UID from LineageID or UID
- # if you set something you need WRITE permission
- write_permission = ActionObjectWRITE(uid=uid, credentials=credentials)
- can_write = self.has_permission(write_permission)
-
- if not self.exists(uid=uid):
- # attempt to claim it for writing
+ if self.exists(credentials=credentials, uid=uid):
+ permissions: list[ActionObjectPermission] = []
if has_result_read_permission:
- ownership_result = self.take_ownership(uid=uid, credentials=credentials)
- can_write = True if ownership_result.is_ok() else False
+ permissions.append(ActionObjectREAD(uid=uid, credentials=credentials))
else:
- # root takes owneship, but you can still write
- ownership_result = self.take_ownership(
- uid=uid, credentials=self.root_verify_key
+ permissions.extend(
+ [
+ ActionObjectWRITE(uid=uid, credentials=credentials),
+ ActionObjectEXECUTE(uid=uid, credentials=credentials),
+ ]
)
- can_write = True if ownership_result.is_ok() else False
-
- if not can_write:
- raise SyftException(public_message=f"Permission: {write_permission} denied")
-
- self.data[uid] = syft_object
- if uid not in self.permissions:
- # create default permissions
- self.permissions[uid] = set()
- if has_result_read_permission:
- self.add_permission(ActionObjectREAD(uid=uid, credentials=credentials))
- else:
- self.add_permissions(
- [
- ActionObjectWRITE(uid=uid, credentials=credentials),
- ActionObjectEXECUTE(uid=uid, credentials=credentials),
- ]
- )
-
- if uid not in self.storage_permissions:
- # create default storage permissions
- self.storage_permissions[uid] = set()
- if add_storage_permission:
- self.add_storage_permission(
- StoragePermission(uid=uid, server_uid=self.server_uid)
- )
-
- return uid
-
- @as_result(SyftException)
- def take_ownership(self, uid: UID, credentials: SyftVerifyKey) -> bool:
- uid = uid.id # We only need the UID from LineageID or UID
-
- # first person using this UID can claim ownership
- if uid in self.permissions or uid in self.data:
- raise SyftException(public_message=f"Object {uid} already owned")
-
- self.add_permissions(
- [
- ActionObjectOWNER(uid=uid, credentials=credentials),
- ActionObjectWRITE(uid=uid, credentials=credentials),
- ActionObjectREAD(uid=uid, credentials=credentials),
- ActionObjectEXECUTE(uid=uid, credentials=credentials),
- ]
- )
-
- return True
-
- @as_result(StashException)
- def delete(self, uid: UID, credentials: SyftVerifyKey) -> UID:
- uid = uid.id # We only need the UID from LineageID or UID
-
- # if you delete something you need OWNER permission
- # is it bad to evict a key and have someone else reuse it?
- # perhaps we should keep permissions but no data?
- owner_permission = ActionObjectOWNER(uid=uid, credentials=credentials)
-
- if not self.has_permission(owner_permission):
- raise StashException(
- public_message=f"Permission: {owner_permission} denied"
- )
-
- if uid in self.data:
- del self.data[uid]
- if uid in self.permissions:
- del self.permissions[uid]
-
- return uid
-
- def has_permission(self, permission: ActionObjectPermission) -> bool:
- if not isinstance(permission.permission, ActionPermission):
- # If we reached this point, it's a malformed object error, let it bubble up
- raise TypeError(f"ObjectPermission type: {permission.permission} not valid")
-
- if (
- permission.credentials is not None
- and self.root_verify_key.verify == permission.credentials.verify
- ):
- return True
-
- if self.__user_stash is not None:
- # relative
- from ...service.user.user_roles import ServiceRole
-
- res = self.__user_stash.get_by_verify_key(
- credentials=permission.credentials,
- verify_key=permission.credentials,
- )
-
- if (
- res.is_ok()
- and (user := res.ok()) is not None
- and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN)
- ):
- return True
-
- if (
- permission.uid in self.permissions
- and permission.permission_string in self.permissions[permission.uid]
- ):
- return True
-
- # 🟡 TODO 14: add ALL_READ, ALL_EXECUTE etc
- if permission.permission == ActionPermission.OWNER:
- pass
- elif permission.permission == ActionPermission.READ:
- pass
- elif permission.permission == ActionPermission.WRITE:
- pass
- elif permission.permission == ActionPermission.EXECUTE:
- pass
-
- return False
-
- def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
- return all(self.has_permission(p) for p in permissions)
-
- def add_permission(self, permission: ActionObjectPermission) -> None:
- permissions = self.permissions[permission.uid]
- permissions.add(permission.permission_string)
- self.permissions[permission.uid] = permissions
-
- def remove_permission(self, permission: ActionObjectPermission) -> None:
- permissions = self.permissions[permission.uid]
- permissions.remove(permission.permission_string)
- self.permissions[permission.uid] = permissions
-
- def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
- for permission in permissions:
- self.add_permission(permission)
-
- @as_result(ObjectCRUDPermissionException)
- def _get_permissions_for_uid(self, uid: UID) -> set[str]:
- if uid in self.permissions:
- return self.permissions[uid]
- raise ObjectCRUDPermissionException(
- public_message=f"No permissions found for uid: {uid}"
- )
-
- @as_result(SyftException)
- def get_all_permissions(self) -> dict[UID, set[str]]:
- return dict(self.permissions.items())
-
- def add_storage_permission(self, permission: StoragePermission) -> None:
- permissions = self.storage_permissions[permission.uid]
- permissions.add(permission.server_uid)
- self.storage_permissions[permission.uid] = permissions
-
- def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
- for permission in permissions:
- self.add_storage_permission(permission)
-
- def remove_storage_permission(self, permission: StoragePermission) -> None:
- permissions = self.storage_permissions[permission.uid]
- permissions.remove(permission.server_uid)
- self.storage_permissions[permission.uid] = permissions
-
- def has_storage_permission(self, permission: StoragePermission | UID) -> bool:
- if isinstance(permission, UID):
- permission = StoragePermission(uid=permission, server_uid=self.server_uid)
-
- if permission.uid in self.storage_permissions:
- return permission.server_uid in self.storage_permissions[permission.uid]
-
- return False
-
- @as_result(ObjectCRUDPermissionException)
- def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]:
- if uid in self.storage_permissions:
- return self.storage_permissions[uid]
- raise ObjectCRUDPermissionException(f"No storage permissions found for {uid}")
-
- @as_result(SyftException)
- def get_all_storage_permissions(self) -> dict[UID, set[UID]]:
- return dict(self.storage_permissions.items())
-
- def _all(
- self,
- credentials: SyftVerifyKey,
- has_permission: bool | None = False,
- ) -> list[SyftObject]:
- # this checks permissions
- res = [self.get(uid, credentials, has_permission) for uid in self.data.keys()]
- return [x.ok() for x in res if x.is_ok()] # type: ignore
-
- @as_result(ObjectCRUDPermissionException)
- def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey) -> bool:
- has_root_permission = credentials == self.root_verify_key
-
- if not has_root_permission:
- raise ObjectCRUDPermissionException(
- public_message="You don't have permissions to migrate data."
- )
-
- for key, value in self.data.items():
- try:
- if value.__canonical_name__ != to_klass.__canonical_name__:
- continue
- migrated_value = value.migrate_to(to_klass.__version__)
- except Exception as e:
- raise SyftException.from_exception(
- e,
- public_message=f"Failed to migrate data to {to_klass} for qk: {key}",
+ storage_permission = []
+ if add_storage_permission:
+ storage_permission.append(
+ StoragePermission(uid=uid, server_uid=self.server_uid)
)
- self.set(
- uid=key,
+ self.update(
credentials=credentials,
- syft_object=migrated_value,
+ obj=syft_object,
).unwrap()
+ self.add_permissions(permissions).unwrap()
+ self.add_storage_permissions(storage_permission).unwrap()
+ return uid
- return True
-
-
-@serializable(canonical_name="DictActionStore", version=1)
-class DictActionStore(KeyValueActionStore):
- """Dictionary-Based Key-Value Action store.
-
- Parameters:
- store_config: StoreConfig
- Backend specific configuration, including client class type.
- root_verify_key: Optional[SyftVerifyKey]
- Signature verification key, used for checking access permissions.
- """
-
- def __init__(
- self,
- server_uid: UID,
- store_config: StoreConfig | None = None,
- root_verify_key: SyftVerifyKey | None = None,
- document_store: DocumentStore | None = None,
- ) -> None:
- store_config = store_config if store_config is not None else DictStoreConfig()
- super().__init__(
- server_uid=server_uid,
- store_config=store_config,
- root_verify_key=root_verify_key,
- document_store=document_store,
+ owner_credentials = (
+ credentials if has_result_read_permission else self.root_verify_key
)
+        # if not has_result_read_permission,
+        # root takes ownership, but you can still write and execute
+ super().set(
+ credentials=owner_credentials,
+ obj=syft_object,
+ add_permissions=[
+ ActionObjectWRITE(uid=uid, credentials=credentials),
+ ActionObjectEXECUTE(uid=uid, credentials=credentials),
+ ],
+ add_storage_permission=add_storage_permission,
+ ).unwrap()
+ return uid
-@serializable(canonical_name="SQLiteActionStore", version=1)
-class SQLiteActionStore(KeyValueActionStore):
- """SQLite-Based Key-Value Action store.
-
- Parameters:
- store_config: StoreConfig
- SQLite specific configuration, including connection settings or client class type.
- root_verify_key: Optional[SyftVerifyKey]
- Signature verification key, used for checking access permissions.
- """
-
- pass
-
-
-@serializable(canonical_name="MongoActionStore", version=1)
-class MongoActionStore(KeyValueActionStore):
- """Mongo-Based Action store.
-
- Parameters:
- store_config: StoreConfig
- Mongo specific configuration.
- root_verify_key: Optional[SyftVerifyKey]
- Signature verification key, used for checking access permissions.
- """
-
- pass
+ def set(self, *args, **kwargs): # type: ignore
+ raise Exception("Use `ActionObjectStash.set_or_update` instead.")
diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py
index 44567e8d8ef..1cebbd1e0fc 100644
--- a/packages/syft/src/syft/service/api/api.py
+++ b/packages/syft/src/syft/service/api/api.py
@@ -575,7 +575,7 @@ def exec_code(
api_service = context.server.get_service("apiservice")
api_service.stash.upsert(
- context.server.services.user.admin_verify_key(), self
+ context.server.services.user.root_verify_key, self
).unwrap()
print = original_print # type: ignore
@@ -650,7 +650,7 @@ def code_string(context: TransformContext) -> TransformContext:
)
context.server = cast(AbstractServer, context.server)
- admin_key = context.server.services.user.admin_verify_key()
+ admin_key = context.server.services.user.root_verify_key
# If endpoint exists **AND** (has visible access **OR** the user is admin)
if endpoint_type is not None and (
diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py
index a8c443a6271..8df26f0ac15 100644
--- a/packages/syft/src/syft/service/api/api_service.py
+++ b/packages/syft/src/syft/service/api/api_service.py
@@ -10,7 +10,7 @@
from ...serde.serializable import serializable
from ...service.action.action_endpoint import CustomEndpointActionObject
from ...service.action.action_object import ActionObject
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
@@ -37,11 +37,9 @@
@serializable(canonical_name="APIService", version=1)
class APIService(AbstractService):
- store: DocumentStore
stash: TwinAPIEndpointStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = TwinAPIEndpointStash(store=store)
@service_method(
@@ -263,7 +261,7 @@ def api_endpoints(
context: AuthedServiceContext,
) -> list[TwinAPIEndpointView]:
"""Retrieves a list of available API endpoints view available to the user."""
- admin_key = context.server.services.user.admin_verify_key()
+ admin_key = context.server.services.user.root_verify_key
all_api_endpoints = self.stash.get_all(admin_key).unwrap()
api_endpoint_view = [
@@ -587,7 +585,7 @@ def execute_server_side_endpoint_mock_by_id(
def get_endpoint_by_uid(
self, context: AuthedServiceContext, uid: UID
) -> TwinAPIEndpoint:
- admin_key = context.server.services.user.admin_verify_key()
+ admin_key = context.server.services.user.root_verify_key
return self.stash.get_by_uid(admin_key, uid).unwrap()
@as_result(StashException)
diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py
index 3a610c23cef..0c0c6f73020 100644
--- a/packages/syft/src/syft/service/api/api_stash.py
+++ b/packages/syft/src/syft/service/api/api_stash.py
@@ -1,11 +1,7 @@
-# stdlib
-
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
@@ -14,24 +10,21 @@
MISSING_PATH_STRING = "Endpoint path: {path} does not exist."
-@serializable(canonical_name="TwinAPIEndpointStash", version=1)
-class TwinAPIEndpointStash(NewBaseUIDStoreStash):
- object_type = TwinAPIEndpoint
- settings: PartitionSettings = PartitionSettings(
- name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="TwinAPIEndpointSQLStash", version=1)
+class TwinAPIEndpointStash(ObjectStash[TwinAPIEndpoint]):
@as_result(StashException, NotFoundException)
def get_by_path(self, credentials: SyftVerifyKey, path: str) -> TwinAPIEndpoint:
- endpoints = self.get_all(credentials=credentials).unwrap()
- for endpoint in endpoints:
- if endpoint.path == path:
- return endpoint
+        # TODO standardize by returning None if the endpoint doesn't exist.
+ res = self.get_one(
+ credentials=credentials,
+ filters={"path": path},
+ )
- raise NotFoundException(public_message=MISSING_PATH_STRING.format(path=path))
+ if res.is_err():
+ raise NotFoundException(
+ public_message=MISSING_PATH_STRING.format(path=path)
+ )
+ return res.unwrap()
@as_result(StashException)
def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool:
@@ -49,11 +42,9 @@ def upsert(
has_permission: bool = False,
) -> TwinAPIEndpoint:
"""Upsert an endpoint."""
- path_exists = self.path_exists(
- credentials=credentials, path=endpoint.path
- ).unwrap()
+ exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap()
- if path_exists:
+ if exists:
super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap()
return (
diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py
index debb838c4b1..93110f72d74 100644
--- a/packages/syft/src/syft/service/attestation/attestation_service.py
+++ b/packages/syft/src/syft/service/attestation/attestation_service.py
@@ -6,7 +6,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.errors import SyftException
from ...types.result import as_result
from ...util.util import str_to_bool
@@ -24,8 +24,8 @@
class AttestationService(AbstractService):
"""This service is responsible for getting all sorts of attestations for any client."""
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
+ pass
@as_result(SyftException)
def perform_request(
diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py
index 25d6042bd09..76abe869ae6 100644
--- a/packages/syft/src/syft/service/blob_storage/remote_profile.py
+++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py
@@ -1,10 +1,10 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
+from ..service import AbstractService
@serializable()
@@ -24,12 +24,14 @@ class AzureRemoteProfile(RemoteProfile):
container_name: str
-@serializable(canonical_name="RemoteProfileStash", version=1)
-class RemoteProfileStash(NewBaseUIDStoreStash):
- object_type = RemoteProfile
- settings: PartitionSettings = PartitionSettings(
- name=RemoteProfile.__canonical_name__, object_type=RemoteProfile
- )
+@serializable(canonical_name="RemoteProfileSQLStash", version=1)
+class RemoteProfileStash(ObjectStash[RemoteProfile]):
+ pass
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+
+@serializable(canonical_name="RemoteProfileService", version=1)
+class RemoteProfileService(AbstractService):
+ stash: RemoteProfileStash
+
+ def __init__(self, store: DBManager) -> None:
+ self.stash = RemoteProfileStash(store=store)
diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py
index c4bc955e29d..055e4d946e4 100644
--- a/packages/syft/src/syft/service/blob_storage/service.py
+++ b/packages/syft/src/syft/service/blob_storage/service.py
@@ -11,8 +11,7 @@
from ...store.blob_storage import BlobRetrieval
from ...store.blob_storage.on_disk import OnDiskBlobDeposit
from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit
-from ...store.document_store import DocumentStore
-from ...store.document_store import UIDPartitionKey
+from ...store.db.db import DBManager
from ...types.blob_storage import AzureSecureFilePathLocation
from ...types.blob_storage import BlobFileType
from ...types.blob_storage import BlobStorageEntry
@@ -38,12 +37,10 @@
@serializable(canonical_name="BlobStorageService", version=1)
class BlobStorageService(AbstractService):
- store: DocumentStore
stash: BlobStorageStash
remote_profile_stash: RemoteProfileStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = BlobStorageStash(store=store)
self.remote_profile_stash = RemoteProfileStash(store=store)
@@ -139,8 +136,8 @@ def mount_azure(
def get_files_from_bucket(
self, context: AuthedServiceContext, bucket_name: str
) -> list:
- bse_list = self.stash.find_all(
- context.credentials, bucket_name=bucket_name
+ bse_list = self.stash.get_all(
+ context.credentials, filters={"bucket_name": bucket_name}
).unwrap()
blob_files = []
@@ -323,8 +320,8 @@ def delete(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess:
public_message=f"Failed to delete blob file with id '{uid}'. Error: {e}"
)
- self.stash.delete(
- context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True
+ self.stash.delete_by_uid(
+ context.credentials, uid, has_permission=True
).unwrap()
except Exception as e:
raise SyftException(
diff --git a/packages/syft/src/syft/service/blob_storage/stash.py b/packages/syft/src/syft/service/blob_storage/stash.py
index 0ab4d7a8aa5..9cb002b7eb9 100644
--- a/packages/syft/src/syft/service/blob_storage/stash.py
+++ b/packages/syft/src/syft/service/blob_storage/stash.py
@@ -1,17 +1,9 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
+from ...store.db.stash import ObjectStash
from ...types.blob_storage import BlobStorageEntry
-@serializable(canonical_name="BlobStorageStash", version=1)
-class BlobStorageStash(NewBaseUIDStoreStash):
- object_type = BlobStorageEntry
- settings: PartitionSettings = PartitionSettings(
- name=BlobStorageEntry.__canonical_name__, object_type=BlobStorageEntry
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="BlobStorageSQLStash", version=1)
+class BlobStorageStash(ObjectStash[BlobStorageEntry]):
+ pass
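
Several stashes in this patch collapse to a bare ObjectStash[T] subclass with an empty body. The sketch below shows how a generic stash can recover its row type from the subscripted base class; this is an assumption about how ObjectStash might work, built only from standard typing helpers and stand-in names (MiniObjectStash, BlobEntry).

    # Minimal sketch of a generic stash whose subclasses only declare the row type.
    from typing import Generic, TypeVar, get_args, get_origin

    T = TypeVar("T")

    class MiniObjectStash(Generic[T]):
        @classmethod
        def object_type(cls) -> type:
            # walk the original bases to recover the concrete type argument
            for base in getattr(cls, "__orig_bases__", ()):
                if get_origin(base) is MiniObjectStash:
                    return get_args(base)[0]
            raise TypeError(f"{cls.__name__} has no type argument")

    class BlobEntry:  # stand-in for BlobStorageEntry
        pass

    class BlobStash(MiniObjectStash[BlobEntry]):
        pass

    print(BlobStash.object_type())  # <class '__main__.BlobEntry'>
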
diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py
index 1ffb70ebb6f..d6c1a56e801 100644
--- a/packages/syft/src/syft/service/code/status_service.py
+++ b/packages/syft/src/syft/service/code/status_service.py
@@ -4,14 +4,9 @@
# relative
from ...serde.serializable import serializable
-from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
-from ...store.document_store_errors import StashException
-from ...types.result import as_result
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftSuccess
@@ -23,35 +18,19 @@
from .user_code import UserCodeStatusCollection
-@serializable(canonical_name="StatusStash", version=1)
-class StatusStash(NewBaseUIDStoreStash):
- object_type = UserCodeStatusCollection
+@serializable(canonical_name="StatusSQLStash", version=1)
+class StatusStash(ObjectStash[UserCodeStatusCollection]):
settings: PartitionSettings = PartitionSettings(
name=UserCodeStatusCollection.__canonical_name__,
object_type=UserCodeStatusCollection,
)
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store)
- self.store = store
- self.settings = self.settings
- self._object_type = self.object_type
-
- @as_result(StashException)
- def get_by_uid(
- self, credentials: SyftVerifyKey, uid: UID
- ) -> UserCodeStatusCollection:
- qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
-
@serializable(canonical_name="UserCodeStatusService", version=1)
class UserCodeStatusService(AbstractService):
- store: DocumentStore
stash: StatusStash
- def __init__(self, store: DocumentStore):
- self.store = store
+ def __init__(self, store: DBManager):
self.stash = StatusStash(store=store)
@service_method(path="code_status.create", name="create", roles=ADMIN_ROLE_LEVEL)
diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py
index 0921ea9e704..81e586d5c02 100644
--- a/packages/syft/src/syft/service/code/user_code.py
+++ b/packages/syft/src/syft/service/code/user_code.py
@@ -46,7 +46,6 @@
from ...serde.signature import signature_remove_context
from ...serde.signature import signature_remove_self
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import PartitionKey
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.dicttuple import DictTuple
@@ -107,11 +106,6 @@
# relative
from ...service.sync.diff_state import AttrDiff
-UserVerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)
-CodeHashPartitionKey = PartitionKey(key="code_hash", type_=str)
-ServiceFuncNamePartitionKey = PartitionKey(key="service_func_name", type_=str)
-SubmitTimePartitionKey = PartitionKey(key="submit_time", type_=DateTime)
-
PyCodeObject = Any
diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py
index 32863990a6a..79490223f45 100644
--- a/packages/syft/src/syft/service/code/user_code_service.py
+++ b/packages/syft/src/syft/service/code/user_code_service.py
@@ -5,7 +5,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
@@ -60,11 +60,9 @@ class IsExecutionAllowedEnum(str, Enum):
@serializable(canonical_name="UserCodeService", version=1)
class UserCodeService(AbstractService):
- store: DocumentStore
stash: UserCodeStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = UserCodeStash(store=store)
@service_method(
@@ -77,11 +75,12 @@ def submit(
self, context: AuthedServiceContext, code: SubmitUserCode
) -> SyftSuccess:
"""Add User Code"""
- user_code = self._submit(context, code, exists_ok=False)
+ user_code = self._submit(context, code, exists_ok=False).unwrap()
return SyftSuccess(
message="User Code Submitted", require_api_update=True, value=user_code
)
+ @as_result(SyftException)
def _submit(
self,
context: AuthedServiceContext,
@@ -107,17 +106,17 @@ def _submit(
context.credentials,
code_hash=get_code_hash(submit_code.code, context.credentials),
).unwrap()
-
- if not exists_ok:
+ # no exception raised, so code with this hash already exists
+ if exists_ok:
+ return existing_code
+ else:
raise SyftException(
- public_message="The code to be submitted already exists"
+ public_message="UserCode with this code already exists"
)
- return existing_code
except NotFoundException:
pass
code = submit_code.to(UserCode, context=context)
-
result = self._post_user_code_transform_ops(context, code)
if result.is_err():
@@ -281,17 +280,11 @@ def _get_or_submit_user_code(
- If the code is a SubmitUserCode and the code hash does not exist, submit the code
"""
if isinstance(code, UserCode):
- # Get existing UserCode
- try:
- return self.stash.get_by_uid(context.credentials, code.id).unwrap()
- except NotFoundException as exc:
- raise NotFoundException.from_exception(
- exc, public_message=f"UserCode {code.id} not found on this server"
- )
+ return self.stash.get_by_uid(context.credentials, code.id).unwrap()
else: # code: SubmitUserCode
# Submit new UserCode, or get existing UserCode with the same code hash
# TODO: Why is this tagged as unreachable?
- return self._submit(context, code, exists_ok=True) # type: ignore[unreachable]
+ return self._submit(context, code, exists_ok=True).unwrap() # type: ignore[unreachable]
@service_method(
path="code.request_code_execution",
diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py
index 308de4d28bf..232342bd8d5 100644
--- a/packages/syft/src/syft/service/code/user_code_stash.py
+++ b/packages/syft/src/syft/service/code/user_code_stash.py
@@ -1,47 +1,32 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
-from .user_code import CodeHashPartitionKey
-from .user_code import ServiceFuncNamePartitionKey
-from .user_code import SubmitTimePartitionKey
from .user_code import UserCode
-from .user_code import UserVerifyKeyPartitionKey
-@serializable(canonical_name="UserCodeStash", version=1)
-class UserCodeStash(NewBaseUIDStoreStash):
- object_type = UserCode
+@serializable(canonical_name="UserCodeSQLStash", version=1)
+class UserCodeStash(ObjectStash[UserCode]):
settings: PartitionSettings = PartitionSettings(
name=UserCode.__canonical_name__, object_type=UserCode
)
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
- @as_result(StashException, NotFoundException)
- def get_all_by_user_verify_key(
- self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey
- ) -> list[UserCode]:
- qks = QueryKeys(qks=[UserVerifyKeyPartitionKey.with_obj(user_verify_key)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
-
@as_result(StashException, NotFoundException)
def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode:
- qks = QueryKeys(qks=[CodeHashPartitionKey.with_obj(code_hash)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"code_hash": code_hash},
+ ).unwrap()
@as_result(StashException)
def get_by_service_func_name(
self, credentials: SyftVerifyKey, service_func_name: str
) -> list[UserCode]:
- qks = QueryKeys(qks=[ServiceFuncNamePartitionKey.with_obj(service_func_name)])
- return self.query_all(
- credentials=credentials, qks=qks, order_by=SubmitTimePartitionKey
+ return self.get_all(
+ credentials=credentials,
+ filters={"service_func_name": service_func_name},
).unwrap()
diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py
index a3d06bdb4fa..ff2967f169b 100644
--- a/packages/syft/src/syft/service/code_history/code_history_service.py
+++ b/packages/syft/src/syft/service/code_history/code_history_service.py
@@ -3,7 +3,7 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...types.uid import UID
from ..code.user_code import SubmitUserCode
@@ -24,11 +24,9 @@
@serializable(canonical_name="CodeHistoryService", version=1)
class CodeHistoryService(AbstractService):
- store: DocumentStore
stash: CodeHistoryStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = CodeHistoryStash(store=store)
@service_method(
@@ -46,7 +44,7 @@ def submit_version(
if isinstance(code, SubmitUserCode):
code = context.server.services.user_code._submit(
context=context, submit_code=code
- )
+ ).unwrap()
try:
code_history = self.stash.get_by_service_func_name_and_verify_key(
@@ -189,11 +187,13 @@ def get_by_func_name_and_user_email(
) -> list[CodeHistory]:
user_verify_key = context.server.services.user.user_verify_key(user_email)
- kwargs = {
+ filters = {
"id": user_id,
"email": user_email,
"verify_key": user_verify_key,
"service_func_name": service_func_name,
}
- return self.stash.find_all(credentials=context.credentials, **kwargs).unwrap()
+ return self.stash.get_all(
+ credentials=context.credentials, filters=filters
+ ).unwrap()
diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py
index d03358feb83..69dfd272717 100644
--- a/packages/syft/src/syft/service/code_history/code_history_stash.py
+++ b/packages/syft/src/syft/service/code_history/code_history_stash.py
@@ -1,56 +1,43 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import StashException
from ...types.result import as_result
from .code_history import CodeHistory
-NamePartitionKey = PartitionKey(key="service_func_name", type_=str)
-VerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)
-
-
-@serializable(canonical_name="CodeHistoryStash", version=1)
-class CodeHistoryStash(NewBaseUIDStoreStash):
- object_type = CodeHistory
- settings: PartitionSettings = PartitionSettings(
- name=CodeHistory.__canonical_name__, object_type=CodeHistory
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="CodeHistoryStashSQL", version=1)
+class CodeHistoryStash(ObjectStash[CodeHistory]):
@as_result(StashException)
def get_by_service_func_name_and_verify_key(
self,
credentials: SyftVerifyKey,
service_func_name: str,
user_verify_key: SyftVerifyKey,
- ) -> list[CodeHistory]:
- qks = QueryKeys(
- qks=[
- NamePartitionKey.with_obj(service_func_name),
- VerifyKeyPartitionKey.with_obj(user_verify_key),
- ]
- )
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ ) -> CodeHistory:
+ return self.get_one(
+ credentials=credentials,
+ filters={
+ "user_verify_key": user_verify_key,
+ "service_func_name": service_func_name,
+ },
+ ).unwrap()
@as_result(StashException)
def get_by_service_func_name(
self, credentials: SyftVerifyKey, service_func_name: str
) -> list[CodeHistory]:
- qks = QueryKeys(qks=[NamePartitionKey.with_obj(service_func_name)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"service_func_name": service_func_name},
+ ).unwrap()
@as_result(StashException)
def get_by_verify_key(
self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey
) -> list[CodeHistory]:
- if isinstance(user_verify_key, str):
- user_verify_key = SyftVerifyKey.from_string(user_verify_key)
- qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(user_verify_key)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"user_verify_key": user_verify_key},
+ ).unwrap()
diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py
index c54e8bc67ea..e7e482a4337 100644
--- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py
+++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py
@@ -3,10 +3,8 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import StashException
from ...types.result import as_result
from ..context import AuthedServiceContext
@@ -14,44 +12,35 @@
from ..service import AbstractService
from ..service import SERVICE_TO_TYPES
from ..service import TYPE_TO_SERVICE
-from .data_subject_member import ChildPartitionKey
from .data_subject_member import DataSubjectMemberRelationship
-from .data_subject_member import ParentPartitionKey
-@serializable(canonical_name="DataSubjectMemberStash", version=1)
-class DataSubjectMemberStash(NewBaseUIDStoreStash):
- object_type = DataSubjectMemberRelationship
- settings: PartitionSettings = PartitionSettings(
- name=DataSubjectMemberRelationship.__canonical_name__,
- object_type=DataSubjectMemberRelationship,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="DataSubjectMemberSQLStash", version=1)
+class DataSubjectMemberStash(ObjectStash[DataSubjectMemberRelationship]):
@as_result(StashException)
def get_all_for_parent(
self, credentials: SyftVerifyKey, name: str
) -> list[DataSubjectMemberRelationship]:
- qks = QueryKeys(qks=[ParentPartitionKey.with_obj(name)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"parent": name},
+ ).unwrap()
@as_result(StashException)
def get_all_for_child(
self, credentials: SyftVerifyKey, name: str
) -> list[DataSubjectMemberRelationship]:
- qks = QueryKeys(qks=[ChildPartitionKey.with_obj(name)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"child": name},
+ ).unwrap()
@serializable(canonical_name="DataSubjectMemberService", version=1)
class DataSubjectMemberService(AbstractService):
- store: DocumentStore
stash: DataSubjectMemberStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = DataSubjectMemberStash(store=store)
def add(
diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py
index b8d5e6b8528..ecde100edf5 100644
--- a/packages/syft/src/syft/service/data_subject/data_subject_service.py
+++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py
@@ -5,10 +5,8 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import StashException
from ...types.result import as_result
from ..context import AuthedServiceContext
@@ -19,43 +17,23 @@
from ..service import service_method
from .data_subject import DataSubject
from .data_subject import DataSubjectCreate
-from .data_subject import NamePartitionKey
-@serializable(canonical_name="DataSubjectStash", version=1)
-class DataSubjectStash(NewBaseUIDStoreStash):
- object_type = DataSubject
- settings: PartitionSettings = PartitionSettings(
- name=DataSubject.__canonical_name__, object_type=DataSubject
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="DataSubjectSQLStash", version=1)
+class DataSubjectStash(ObjectStash[DataSubject]):
@as_result(StashException)
def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject:
- qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)])
- return self.query_one(credentials, qks=qks).unwrap()
-
- @as_result(StashException)
- def update(
- self,
- credentials: SyftVerifyKey,
- data_subject: DataSubject,
- has_permission: bool = False,
- ) -> DataSubject:
- res = self.check_type(data_subject, DataSubject).unwrap()
- # we dont use and_then logic here as it is hard because of the order of the arguments
- return super().update(credentials=credentials, obj=res).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"name": name},
+ ).unwrap()
@serializable(canonical_name="DataSubjectService", version=1)
class DataSubjectService(AbstractService):
- store: DocumentStore
stash: DataSubjectStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = DataSubjectStash(store=store)
@service_method(path="data_subject.add", name="add_data_subject")
diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py
index 6c4a91e06c2..fbc5318be57 100644
--- a/packages/syft/src/syft/service/dataset/dataset.py
+++ b/packages/syft/src/syft/service/dataset/dataset.py
@@ -7,6 +7,7 @@
from typing import Any
# third party
+from IPython.display import display
import markdown
import pandas as pd
from pydantic import ConfigDict
@@ -291,8 +292,8 @@ def _private_data(self) -> Any:
def data(self) -> Any:
try:
return self._private_data().unwrap()
- except SyftException as e:
- print(e)
+ except SyftException:
+ display(SyftError(message="You have no access to the private data"))
return None
@@ -533,6 +534,7 @@ def _repr_html_(self) -> Any:
{self.assets._repr_html_()}
"""
+ @property
def action_ids(self) -> list[UID]:
return [asset.action_id for asset in self.asset_list if asset.action_id]
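
action_ids becomes a property in the hunk above, so it is read as attribute access rather than called. A tiny stand-in sketch (Asset and MiniDataset are hypothetical, not the real syft classes):

    from dataclasses import dataclass, field

    @dataclass
    class Asset:
        action_id: int | None = None

    @dataclass
    class MiniDataset:
        asset_list: list[Asset] = field(default_factory=list)

        @property
        def action_ids(self) -> list[int]:
            # falsy action ids are filtered out, mirroring the diff above
            return [a.action_id for a in self.asset_list if a.action_id]

    ds = MiniDataset([Asset(1), Asset(None), Asset(2)])
    print(ds.action_ids)  # [1, 2] -- no parentheses after the property
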
diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py
index 43cbfacb117..a25a49ee60d 100644
--- a/packages/syft/src/syft/service/dataset/dataset_service.py
+++ b/packages/syft/src/syft/service/dataset/dataset_service.py
@@ -5,7 +5,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.dicttuple import DictTuple
from ...types.uid import UID
from ..action.action_permissions import ActionObjectPermission
@@ -25,7 +25,6 @@
from .dataset import CreateDataset
from .dataset import Dataset
from .dataset import DatasetPageView
-from .dataset import DatasetUpdate
from .dataset_stash import DatasetStash
logger = logging.getLogger(__name__)
@@ -69,11 +68,9 @@ def _paginate_dataset_collection(
@serializable(canonical_name="DatasetService", version=1)
class DatasetService(AbstractService):
- store: DocumentStore
stash: DatasetStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = DatasetStash(store=store)
@service_method(
@@ -117,13 +114,11 @@ def get_all(
page_index: int | None = 0,
) -> DatasetPageView | DictTuple[str, Dataset]:
"""Get a Dataset"""
- datasets = self.stash.get_all(context.credentials).unwrap()
+ datasets = self.stash.get_all_active(context.credentials).unwrap()
for dataset in datasets:
if context.server is not None:
dataset.server_uid = context.server.id
- if dataset.to_be_deleted:
- datasets.remove(dataset)
return _paginate_dataset_collection(
datasets=datasets, page_size=page_size, page_index=page_index
@@ -234,11 +229,9 @@ def delete(
return_msg.append(f"Asset with id '{asset.id}' successfully deleted.")
# soft delete the dataset object from the store
- dataset_update = DatasetUpdate(
- id=uid, name=f"_deleted_{dataset.name}_{uid}", to_be_deleted=True
- )
- self.stash.update(context.credentials, dataset_update).unwrap()
-
+ dataset.name = f"_deleted_{dataset.name}_{uid}"
+ dataset.to_be_deleted = True
+ self.stash.update(context.credentials, dataset).unwrap()
return_msg.append(f"Dataset with id '{uid}' successfully deleted.")
return SyftSuccess(message="\n".join(return_msg))
diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index 19fc33c5906..aee2a280372 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -1,64 +1,56 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
from ...types.uid import UID
from .dataset import Dataset
-from .dataset import DatasetUpdate
-NamePartitionKey = PartitionKey(key="name", type_=str)
-ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID])
-
-
-@serializable(canonical_name="DatasetStash", version=1)
-class DatasetStash(NewBaseUIDStoreStash):
- object_type = Dataset
- settings: PartitionSettings = PartitionSettings(
- name=Dataset.__canonical_name__, object_type=Dataset
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="DatasetStashSQL", version=1)
+class DatasetStash(ObjectStash[Dataset]):
@as_result(StashException, NotFoundException)
def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset:
- qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(credentials=credentials, filters={"name": name}).unwrap()
@as_result(StashException)
def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]:
- qks = QueryKeys(qks=[ActionIDsPartitionKey.with_obj(uid)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
-
- @as_result(StashException)
- def get_all(
- self,
- credentials: SyftVerifyKey,
- order_by: PartitionKey | None = None,
- has_permission: bool = False,
- ) -> list:
- result = super().get_all(credentials, order_by, has_permission).unwrap()
- filtered_datasets = [dataset for dataset in result if not dataset.to_be_deleted]
- return filtered_datasets
+ return self.get_all_active(
+ credentials=credentials,
+ filters={"action_ids__contains": uid},
+ ).unwrap()
- # FIX: This shouldn't be the update method, it just marks the dataset for deletion
@as_result(StashException)
- def update(
+ def get_all_active(
self,
credentials: SyftVerifyKey,
- obj: DatasetUpdate,
has_permission: bool = False,
- ) -> Dataset:
- _obj = self.check_type(obj, DatasetUpdate).unwrap()
- # FIX: This method needs a revamp
- qk = self.partition.store_query_key(obj)
- return self.partition.update(
- credentials=credentials, qk=qk, obj=_obj, has_permission=has_permission
- ).unwrap()
+ order_by: str | None = None,
+ sort_order: str | None = None,
+ limit: int | None = None,
+ offset: int | None = None,
+ filters: dict | None = None,
+ ) -> list[Dataset]:
+ # TODO standardize soft delete and move to ObjectStash.get_all
+ default_filters = {"to_be_deleted": False}
+ filters = filters or {}
+ filters.update(default_filters)
+
+ if offset is None:
+ offset = 0
+
+ return (
+ super()
+ .get_all(
+ credentials=credentials,
+ filters=filters,
+ has_permission=has_permission,
+ order_by=order_by,
+ sort_order=sort_order,
+ limit=limit,
+ offset=offset,
+ )
+ .unwrap()
+ )
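
get_all_active above merges a default to_be_deleted=False filter into whatever the caller passes, with the default applied last. A short sketch of that merge behaviour using plain dicts and lists:

    def get_all_active(rows: list[dict], filters: dict | None = None) -> list[dict]:
        default_filters = {"to_be_deleted": False}
        filters = dict(filters or {})
        # update() runs last, so the default always wins over caller input
        filters.update(default_filters)
        return [r for r in rows if all(r.get(k) == v for k, v in filters.items())]

    rows = [
        {"name": "a", "to_be_deleted": False},
        {"name": "_deleted_b", "to_be_deleted": True},
    ]
    print(get_all_active(rows))                           # only the active dataset
    print(get_all_active(rows, {"to_be_deleted": True}))  # default still wins
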
diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py
index 2f88c60e123..064c8806f91 100644
--- a/packages/syft/src/syft/service/enclave/enclave_service.py
+++ b/packages/syft/src/syft/service/enclave/enclave_service.py
@@ -2,13 +2,11 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ..service import AbstractService
@serializable(canonical_name="EnclaveService", version=1)
class EnclaveService(AbstractService):
- store: DocumentStore
-
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
+ pass
diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py
index fbbdfd7d856..5cff02ecb74 100644
--- a/packages/syft/src/syft/service/job/job_service.py
+++ b/packages/syft/src/syft/service/job/job_service.py
@@ -6,7 +6,7 @@
# relative
from ...serde.serializable import serializable
from ...server.worker_settings import WorkerSettings
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.errors import SyftException
from ...types.uid import UID
from ..action.action_object import ActionObject
@@ -40,11 +40,9 @@ def wait_until(predicate: Callable[[], bool], timeout: int = 10) -> SyftSuccess:
@serializable(canonical_name="JobService", version=1)
class JobService(AbstractService):
- store: DocumentStore
stash: JobStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = JobStash(store=store)
@service_method(
@@ -132,7 +130,7 @@ def restart(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess:
context.credentials, queue_item
).unwrap()
- context.server.job_stash.set(context.credentials, job).unwrap()
+ self.stash.set(context.credentials, job).unwrap()
context.server.services.log.restart(context, job.log_id)
return SyftSuccess(message="Great Success!")
@@ -218,7 +216,7 @@ def add_read_permission_job_for_code_owner(
job.id, ActionPermission.READ, user_code.user_verify_key
)
# TODO: make add_permission wrappable
- return self.stash.add_permission(permission=permission)
+ return self.stash.add_permission(permission=permission).unwrap()
@service_method(
path="job.add_read_permission_log_for_code_owner",
@@ -232,7 +230,7 @@ def add_read_permission_log_for_code_owner(
ActionObjectPermission(
log_id, ActionPermission.READ, user_code.user_verify_key
)
- )
+ ).unwrap()
@service_method(
path="job.create_job_for_user_code_id",
diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py
index fc83b675503..358834470fb 100644
--- a/packages/syft/src/syft/service/job/job_stash.py
+++ b/packages/syft/src/syft/service/job/job_stash.py
@@ -1,5 +1,4 @@
# stdlib
-from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -20,26 +19,19 @@
from ...server.credentials import SyftVerifyKey
from ...service.context import AuthedServiceContext
from ...service.worker.worker_pool import SyftWorker
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
-from ...store.document_store_errors import TooManyItemsFoundException
from ...types.datetime import DateTime
from ...types.datetime import format_timedelta
from ...types.errors import SyftException
from ...types.result import Err
from ...types.result import as_result
-from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
-from ...types.transforms import make_set_default
from ...types.uid import UID
from ...util.markdown import as_markdown_code
from ...util.util import prompt_warning_message
@@ -125,6 +117,7 @@ class Job(SyncableSyftObject):
"user_code_id",
"result_id",
]
+
__repr_attrs__ = [
"id",
"result",
@@ -133,6 +126,7 @@ class Job(SyncableSyftObject):
"creation_time",
"user_code_name",
]
+
__exclude_sync_diff_attrs__ = ["action", "server_uid"]
__table_coll_widths__ = [
"min-content",
@@ -740,16 +734,12 @@ def from_job(
return info
-@serializable(canonical_name="JobStash", version=1)
-class JobStash(NewBaseUIDStoreStash):
- object_type = Job
+@serializable(canonical_name="JobStashSQL", version=1)
+class JobStash(ObjectStash[Job]):
settings: PartitionSettings = PartitionSettings(
name=Job.__canonical_name__, object_type=Job
)
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
@as_result(StashException)
def set_result(
self,
@@ -764,98 +754,40 @@ def set_result(
and item.result.syft_blob_storage_entry_id is not None
):
item.result._clear_cache()
- return (
- super()
- .update(credentials, item, add_permissions)
- .unwrap(public_message="Failed to update")
- )
-
- @as_result(StashException)
- def get_by_result_id(
- self,
- credentials: SyftVerifyKey,
- result_id: UID,
- ) -> Job:
- qks = QueryKeys(
- qks=[PartitionKey(key="result_id", type_=UID).with_obj(result_id)]
- )
- res = self.query_all(credentials=credentials, qks=qks).unwrap()
-
- if len(res) == 0:
- raise NotFoundException()
- elif len(res) > 1:
- raise TooManyItemsFoundException()
- else:
- return res[0]
-
- @as_result(StashException)
- def get_by_parent_id(self, credentials: SyftVerifyKey, uid: UID) -> list[Job]:
- qks = QueryKeys(
- qks=[PartitionKey(key="parent_job_id", type_=UID).with_obj(uid)]
+ return self.update(credentials, item, add_permissions).unwrap(
+ public_message="Failed to update"
)
- return self.query_all(credentials=credentials, qks=qks).unwrap()
-
- @as_result(StashException)
- def delete_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> bool: # type: ignore[override]
- qk = UIDPartitionKey.with_obj(uid)
- return super().delete(credentials=credentials, qk=qk).unwrap()
- @as_result(StashException)
def get_active(self, credentials: SyftVerifyKey) -> list[Job]:
- qks = QueryKeys(
- qks=[
- PartitionKey(key="status", type_=JobStatus).with_obj(
- JobStatus.PROCESSING
- )
- ]
- )
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"status": JobStatus.CREATED},
+ ).unwrap()
- @as_result(StashException)
def get_by_worker(self, credentials: SyftVerifyKey, worker_id: str) -> list[Job]:
- qks = QueryKeys(
- qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)]
- )
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"job_worker_id": worker_id},
+ ).unwrap()
@as_result(StashException)
def get_by_user_code_id(
self, credentials: SyftVerifyKey, user_code_id: UID
) -> list[Job]:
- qks = QueryKeys(
- qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)]
- )
- return self.query_all(credentials=credentials, qks=qks).unwrap()
-
-
-@serializable()
-class JobV1(SyncableSyftObject):
- __canonical_name__ = "JobItem"
- __version__ = SYFT_OBJECT_VERSION_1
-
- id: UID
- server_uid: UID
- result: Any | None = None
- resolved: bool = False
- status: JobStatus = JobStatus.CREATED
- log_id: UID | None = None
- parent_job_id: UID | None = None
- n_iters: int | None = 0
- current_iter: int | None = None
- creation_time: str | None = Field(
- default_factory=lambda: str(datetime.now(tz=timezone.utc))
- )
- action: Action | None = None
- job_pid: int | None = None
- job_worker_id: UID | None = None
- updated_at: DateTime | None = None
- user_code_id: UID | None = None
- requested_by: UID | None = None
- job_type: JobType = JobType.JOB
+ return self.get_all(
+ credentials=credentials,
+ filters={"user_code_id": user_code_id},
+ ).unwrap()
+ @as_result(StashException)
+ def get_by_parent_id(self, credentials: SyftVerifyKey, uid: UID) -> list[Job]:
+ return self.get_all(
+ credentials=credentials,
+ filters={"parent_job_id": uid},
+ ).unwrap()
-@migrate(JobV1, Job)
-def migrate_job_update_v1_current() -> list[Callable]:
- return [
- make_set_default("endpoint", None),
- ]
+ @as_result(StashException)
+ def get_by_result_id(self, credentials: SyftVerifyKey, uid: UID) -> Job:
+ return self.get_one(
+ credentials=credentials, filters={"result_id": uid}
+ ).unwrap()
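
get_by_result_id drops the explicit zero/one/many checks in favour of get_one. The sketch below shows one way a get_one helper could enforce the same uniqueness rules; whether the real ObjectStash.get_one behaves exactly like this is an assumption.

    class NotFound(Exception): ...
    class TooMany(Exception): ...

    def get_one(rows: list[dict], filters: dict) -> dict:
        matches = [r for r in rows if all(r.get(k) == v for k, v in filters.items())]
        if not matches:
            raise NotFound(filters)
        if len(matches) > 1:
            raise TooMany(filters)
        return matches[0]

    jobs = [{"id": 1, "result_id": "r-1"}, {"id": 2, "result_id": "r-2"}]
    print(get_one(jobs, {"result_id": "r-2"}))  # {'id': 2, 'result_id': 'r-2'}
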
diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py
index d3529b0906f..d4b96a0deed 100644
--- a/packages/syft/src/syft/service/log/log_service.py
+++ b/packages/syft/src/syft/service/log/log_service.py
@@ -1,6 +1,6 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.uid import UID
from ..action.action_permissions import StoragePermission
from ..context import AuthedServiceContext
@@ -16,11 +16,9 @@
@serializable(canonical_name="LogService", version=1)
class LogService(AbstractService):
- store: DocumentStore
stash: LogStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = LogStash(store=store)
@service_method(path="log.add", name="add", roles=DATA_SCIENTIST_ROLE_LEVEL)
diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py
index c4072bfcfa5..ef50b081c24 100644
--- a/packages/syft/src/syft/service/log/log_stash.py
+++ b/packages/syft/src/syft/service/log/log_stash.py
@@ -1,17 +1,9 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
+from ...store.db.stash import ObjectStash
from .log import SyftLog
@serializable(canonical_name="LogStash", version=1)
-class LogStash(NewBaseUIDStoreStash):
- object_type = SyftLog
- settings: PartitionSettings = PartitionSettings(
- name=SyftLog.__canonical_name__, object_type=SyftLog
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+class LogStash(ObjectStash[SyftLog]):
+ pass
diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py
index 70453d9b084..b7b450b037b 100644
--- a/packages/syft/src/syft/service/metadata/metadata_service.py
+++ b/packages/syft/src/syft/service/metadata/metadata_service.py
@@ -2,7 +2,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ..context import AuthedServiceContext
from ..service import AbstractService
from ..service import service_method
@@ -12,8 +12,8 @@
@serializable(canonical_name="MetadataService", version=1)
class MetadataService(AbstractService):
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
+ pass
@service_method(
path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL
diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py
index 09848789559..b1d461e4e80 100644
--- a/packages/syft/src/syft/service/migration/migration_service.py
+++ b/packages/syft/src/syft/service/migration/migration_service.py
@@ -1,25 +1,25 @@
# stdlib
from collections import defaultdict
-from typing import cast
# syft absolute
import syft
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...store.document_store import StorePartition
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...types.blob_storage import BlobStorageEntry
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.syft_object import SyftObject
from ...types.syft_object_registry import SyftObjectRegistry
+from ...types.twin_object import TwinObject
from ..action.action_object import Action
from ..action.action_object import ActionObject
from ..action.action_permissions import ActionObjectPermission
from ..action.action_permissions import StoragePermission
-from ..action.action_store import KeyValueActionStore
+from ..action.action_store import ActionObjectStash
from ..context import AuthedServiceContext
from ..response import SyftSuccess
from ..service import AbstractService
@@ -34,11 +34,9 @@
@serializable(canonical_name="MigrationService", version=1)
class MigrationService(AbstractService):
- store: DocumentStore
stash: SyftMigrationStateStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = SyftMigrationStateStash(store=store)
@service_method(path="migration", name="get_version")
@@ -75,9 +73,7 @@ def register_migration_state(
obj = SyftObjectMigrationState(
current_version=current_version, canonical_name=canonical_name
)
- return self.stash.set(
- migration_state=obj, credentials=context.credentials
- ).unwrap()
+ return self.stash.set(obj=obj, credentials=context.credentials).unwrap()
@as_result(SyftException, NotFoundException)
def _find_klasses_pending_for_migration(
@@ -120,83 +116,37 @@ def get_all_store_metadata(
return self._get_all_store_metadata(
context,
document_store_object_types=document_store_object_types,
- include_action_store=include_action_store,
).unwrap()
- @as_result(SyftException)
- def _get_partition_from_type(
- self,
- context: AuthedServiceContext,
- object_type: type[SyftObject],
- ) -> KeyValueActionStore | StorePartition:
- object_partition: KeyValueActionStore | StorePartition | None = None
- if issubclass(object_type, ActionObject):
- object_partition = cast(KeyValueActionStore, context.server.action_store)
- else:
- canonical_name = object_type.__canonical_name__ # type: ignore[unreachable]
- object_partition = self.store.partitions.get(canonical_name)
-
- if object_partition is None:
- raise SyftException(
- public_message=f"Object partition not found for {object_type}"
- ) # type: ignore
-
- return object_partition
-
- @as_result(SyftException)
- def _get_store_metadata(
- self,
- context: AuthedServiceContext,
- object_type: type[SyftObject],
- ) -> StoreMetadata:
- object_partition = self._get_partition_from_type(context, object_type).unwrap()
- permissions = dict(object_partition.get_all_permissions().unwrap())
- storage_permissions = dict(
- object_partition.get_all_storage_permissions().unwrap()
- )
- return StoreMetadata(
- object_type=object_type,
- permissions=permissions,
- storage_permissions=storage_permissions,
- )
-
@as_result(SyftException)
def _get_all_store_metadata(
self,
context: AuthedServiceContext,
document_store_object_types: list[type[SyftObject]] | None = None,
- include_action_store: bool = True,
) -> dict[type[SyftObject], StoreMetadata]:
- if document_store_object_types is None:
- document_store_object_types = self.store.get_partition_object_types()
-
+ # metadata = permissions + storage permissions
+ stashes = context.server.services.stashes
store_metadata = {}
- for klass in document_store_object_types:
- store_metadata[klass] = self._get_store_metadata(context, klass).unwrap()
- if include_action_store:
- store_metadata[ActionObject] = self._get_store_metadata(
- context, ActionObject
- ).unwrap()
- return store_metadata
+ for klass, stash in stashes.items():
+ if (
+ document_store_object_types is not None
+ and klass not in document_store_object_types
+ ):
+ continue
+ store_metadata[klass] = StoreMetadata(
+ object_type=klass,
+ permissions=stash.get_all_permissions().unwrap(),
+ storage_permissions=stash.get_all_storage_permissions().unwrap(),
+ )
- @service_method(
- path="migration.update_store_metadata",
- name="update_store_metadata",
- roles=ADMIN_ROLE_LEVEL,
- )
- def update_store_metadata(
- self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata]
- ) -> None:
- return self._update_store_metadata(context, store_metadata).unwrap()
+ return store_metadata
@as_result(SyftException)
def _update_store_metadata_for_klass(
self, context: AuthedServiceContext, metadata: StoreMetadata
) -> None:
- object_partition = self._get_partition_from_type(
- context, metadata.object_type
- ).unwrap()
+ stash = self._search_stash_for_klass(context, metadata.object_type).unwrap()
permissions = [
ActionObjectPermission.from_permission_string(uid, perm_str)
for uid, perm_strs in metadata.permissions.items()
@@ -209,8 +159,8 @@ def _update_store_metadata_for_klass(
for server_uid in server_uids
]
- object_partition.add_permissions(permissions)
- object_partition.add_storage_permissions(storage_permissions)
+ stash.add_permissions(permissions, ignore_missing=True).unwrap()
+ stash.add_storage_permissions(storage_permissions, ignore_missing=True).unwrap()
@as_result(SyftException)
def _update_store_metadata(
@@ -220,21 +170,6 @@ def _update_store_metadata(
for metadata in store_metadata.values():
self._update_store_metadata_for_klass(context, metadata).unwrap()
- @service_method(
- path="migration.get_migration_objects",
- name="get_migration_objects",
- roles=ADMIN_ROLE_LEVEL,
- )
- def get_migration_objects(
- self,
- context: AuthedServiceContext,
- document_store_object_types: list[type[SyftObject]] | None = None,
- get_all: bool = False,
- ) -> dict:
- return self._get_migration_objects(
- context, document_store_object_types, get_all
- ).unwrap()
-
@as_result(SyftException)
def _get_migration_objects(
self,
@@ -243,7 +178,7 @@ def _get_migration_objects(
get_all: bool = False,
) -> dict[type[SyftObject], list[SyftObject]]:
if document_store_object_types is None:
- document_store_object_types = self.store.get_partition_object_types()
+ document_store_object_types = list(context.server.services.stashes.keys())
if get_all:
klasses_to_migrate = document_store_object_types
@@ -255,14 +190,12 @@ def _get_migration_objects(
result = defaultdict(list)
for klass in klasses_to_migrate:
- canonical_name = klass.__canonical_name__
- object_partition = self.store.partitions.get(canonical_name)
- if object_partition is None:
+ stash_or_err = self._search_stash_for_klass(context, klass)
+ if stash_or_err.is_err():
continue
- objects = object_partition.all(
- context.credentials, has_permission=True
- ).unwrap()
- for object in objects:
+ stash = stash_or_err.unwrap()
+
+ for object in stash._data:
actual_klass = type(object)
use_klass = (
klass
@@ -274,24 +207,33 @@ def _get_migration_objects(
return dict(result)
@as_result(SyftException)
- def _search_partition_for_object(
- self, context: AuthedServiceContext, obj: SyftObject
- ) -> StorePartition:
- klass = type(obj)
+ def _search_stash_for_klass(
+ self, context: AuthedServiceContext, klass: type[SyftObject]
+ ) -> ObjectStash:
+ if issubclass(klass, ActionObject | TwinObject | Action):
+ return context.server.services.action.stash
+
+ stashes: dict[str, ObjectStash] = { # type: ignore
+ t.__canonical_name__: stash
+ for t, stash in context.server.services.stashes.items()
+ }
+
mro = klass.__mro__
class_index = 0
- object_partition = None
+ object_stash = None
while len(mro) > class_index:
- canonical_name = mro[class_index].__canonical_name__
- object_partition = self.store.partitions.get(canonical_name)
- if object_partition is not None:
+ try:
+ canonical_name = mro[class_index].__canonical_name__
+ except AttributeError:
+ # Classes without a canonical name don't have a stash
+ break
+ object_stash = stashes.get(canonical_name)
+ if object_stash is not None:
break
class_index += 1
- if object_partition is None:
- raise SyftException(
- public_message=f"Object partition not found for {klass}"
- )
- return object_partition
+ if object_stash is None:
+ raise SyftException(public_message=f"Object stash not found for {klass}")
+ return object_stash
@service_method(
path="migration.create_migrated_objects",
@@ -316,12 +258,11 @@ def _create_migrated_objects(
ignore_existing: bool = True,
) -> SyftSuccess:
for migrated_object in migrated_objects:
- object_partition = self._search_partition_for_object(
- context, migrated_object
+ stash = self._search_stash_for_klass(
+ context, type(migrated_object)
).unwrap()
- # upsert the object
- result = object_partition.set(
+ result = stash.set(
context.credentials,
obj=migrated_object,
)
@@ -340,34 +281,20 @@ def _create_migrated_objects(
result.unwrap() # this will raise the exception inside the wrapper
return SyftSuccess(message="Created migrate objects!")
- @service_method(
- path="migration.update_migrated_objects",
- name="update_migrated_objects",
- roles=ADMIN_ROLE_LEVEL,
- )
- def update_migrated_objects(
- self, context: AuthedServiceContext, migrated_objects: list[SyftObject]
- ) -> None:
- self._update_migrated_objects(context, migrated_objects).unwrap()
-
@as_result(SyftException)
def _update_migrated_objects(
self, context: AuthedServiceContext, migrated_objects: list[SyftObject]
) -> SyftSuccess:
for migrated_object in migrated_objects:
- object_partition = self._search_partition_for_object(
- context, migrated_object
+ stash = self._search_stash_for_klass(
+ context, type(migrated_object)
).unwrap()
- qk = object_partition.settings.store_key.with_obj(migrated_object.id)
- object_partition._update(
+ stash.update(
context.credentials,
- qk=qk,
obj=migrated_object,
- has_permission=True,
- overwrite=True,
- allow_missing_keys=True,
).unwrap()
+
return SyftSuccess(message="Updated migration objects!")
@as_result(SyftException)
@@ -405,33 +332,18 @@ def migrate_data(
context: AuthedServiceContext,
document_store_object_types: list[type[SyftObject]] | None = None,
) -> SyftSuccess:
- # Track all object type that need migration for document store
-
- # get all objects, keyed by type (because we might want to have different rules for different types)
- # Q: will this be tricky with the protocol????
- # A: For now we will assume that the client will have the same version
-
- # Then, locally we write stuff that says
- # for klass, objects in migration_dict.items():
- # for object in objects:
- # if isinstance(object, X):
- # do something custom
- # else:
- # migrated_value = object.migrate_to(klass.__version__, context)
- #
- # migrated_values = [SyftObject]
- # client.migration.write_migrated_values(migrated_values)
-
migration_objects = self._get_migration_objects(
context, document_store_object_types
).unwrap()
migrated_objects = self._migrate_objects(context, migration_objects).unwrap()
self._update_migrated_objects(context, migrated_objects).unwrap()
+
migration_actionobjects = self._get_migration_actionobjects(context).unwrap()
migrated_actionobjects = self._migrate_objects(
context, migration_actionobjects
).unwrap()
self._update_migrated_actionobjects(context, migrated_actionobjects).unwrap()
+
return SyftSuccess(message="Data upgraded to the latest version")
@service_method(
@@ -449,7 +361,7 @@ def _get_migration_actionobjects(
self, context: AuthedServiceContext, get_all: bool = False
) -> dict[type[SyftObject], list[SyftObject]]:
# Track all object types from action store
- action_object_types = [Action, ActionObject]
+ action_object_types = [Action, ActionObject, TwinObject]
action_object_types.extend(ActionObject.__subclasses__())
klass_by_canonical_name: dict[str, type[SyftObject]] = {
klass.__canonical_name__: klass for klass in action_object_types
@@ -459,10 +371,8 @@ def _get_migration_actionobjects(
context=context, object_types=action_object_types
).unwrap()
result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list)
- action_store = context.server.action_store
- action_store_objects = action_store._all(
- context.credentials, has_permission=True
- )
+ action_stash = context.server.services.action.stash
+ action_store_objects = action_stash.get_all(context.credentials).unwrap()
for obj in action_store_objects:
if get_all or type(obj) in action_object_pending_migration:
@@ -470,26 +380,16 @@ def _get_migration_actionobjects(
result_dict[klass].append(obj) # type: ignore
return dict(result_dict)
- @service_method(
- path="migration.update_migrated_actionobjects",
- name="update_migrated_actionobjects",
- roles=ADMIN_ROLE_LEVEL,
- )
- def update_migrated_actionobjects(
- self, context: AuthedServiceContext, objects: list[SyftObject]
- ) -> SyftSuccess:
- self._update_migrated_actionobjects(context, objects).unwrap()
- return SyftSuccess(message="succesfully migrated actionobjects")
-
@as_result(SyftException)
def _update_migrated_actionobjects(
self, context: AuthedServiceContext, objects: list[SyftObject]
) -> str:
- # Track all object types from action store
- action_store = context.server.action_store
+ action_store: ActionObjectStash = context.server.services.action.stash
for obj in objects:
- action_store.set(
- uid=obj.id, credentials=context.credentials, syft_object=obj
+ action_store.set_or_update(
+ uid=obj.id,
+ credentials=context.credentials,
+ syft_object=obj,
).unwrap()
return "success"
diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py
index b7fa2bb7cbd..22363d867f2 100644
--- a/packages/syft/src/syft/service/migration/object_migration_state.py
+++ b/packages/syft/src/syft/service/migration/object_migration_state.py
@@ -15,10 +15,8 @@
from ...serde.serialize import _serialize
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseStash
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import CreateBlobStorageEntry
@@ -34,7 +32,6 @@
from ...types.transforms import make_set_default
from ...types.uid import UID
from ...util.util import prompt_warning_message
-from ..action.action_permissions import ActionObjectPermission
from ..response import SyftSuccess
from ..worker.utils import DEFAULT_WORKER_POOL_NAME
from ..worker.worker_image import SyftWorkerImage
@@ -70,45 +67,16 @@ def supported_versions(self) -> list:
KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str)
-@serializable(canonical_name="SyftMigrationStateStash", version=1)
-class SyftMigrationStateStash(NewBaseStash):
- object_type = SyftObjectMigrationState
- settings: PartitionSettings = PartitionSettings(
- name=SyftObjectMigrationState.__canonical_name__,
- object_type=SyftObjectMigrationState,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
- @as_result(SyftException)
- def set( # type: ignore [override]
- self,
- credentials: SyftVerifyKey,
- migration_state: SyftObjectMigrationState,
- add_permissions: list[ActionObjectPermission] | None = None,
- add_storage_permission: bool = True,
- ignore_duplicates: bool = False,
- ) -> SyftObjectMigrationState:
- obj = self.check_type(migration_state, self.object_type).unwrap()
- return (
- super()
- .set(
- credentials=credentials,
- obj=obj,
- add_permissions=add_permissions,
- add_storage_permission=add_storage_permission,
- ignore_duplicates=ignore_duplicates,
- )
- .unwrap()
- )
-
+@serializable(canonical_name="SyftMigrationStateSQLStash", version=1)
+class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]):
@as_result(SyftException, NotFoundException)
def get_by_name(
self, canonical_name: str, credentials: SyftVerifyKey
) -> SyftObjectMigrationState:
- qks = KlassNamePartitionKey.with_obj(canonical_name)
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"canonical_name": canonical_name},
+ ).unwrap()
@serializable()
@@ -276,6 +244,14 @@ def get_items_by_canonical_name(self, canonical_name: str) -> list[SyftObject]:
return v
return []
+ def get_metadata_by_canonical_name(self, canonical_name: str) -> StoreMetadata:
+ for k, v in self.metadata.items():
+ if k.__canonical_name__ == canonical_name:
+ return v
+ return StoreMetadata(
+ object_type=SyftObject, permissions={}, storage_permissions={}
+ )
+
def copy_without_workerpools(self) -> "MigrationData":
items_to_exclude = [
WorkerPool.__canonical_name__,
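Note: the SyftMigrationStateStash rewrite above is the template most stashes in this patch follow: subclass ObjectStash[T], drop the DocumentStore/PartitionSettings boilerplate and custom __init__, and query with a filters dict keyed by searchable attributes instead of QueryKeys. A minimal sketch of the pattern (ExampleStash is a made-up name):

    # illustrative sketch, not part of the patch
    @serializable(canonical_name="ExampleSQLStash", version=1)
    class ExampleStash(ObjectStash[SyftObjectMigrationState]):
        @as_result(StashException, NotFoundException)
        def get_by_name(
            self, canonical_name: str, credentials: SyftVerifyKey
        ) -> SyftObjectMigrationState:
            # filters maps attribute names to exact-match values
            return self.get_one(
                credentials=credentials,
                filters={"canonical_name": canonical_name},
            ).unwrap()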
diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py
index 428501fb92d..aa2264dbe10 100644
--- a/packages/syft/src/syft/service/network/network_service.py
+++ b/packages/syft/src/syft/service/network/network_service.py
@@ -15,11 +15,8 @@
from ...server.credentials import SyftVerifyKey
from ...server.worker_settings import WorkerSettings
from ...service.settings.settings import ServerSettings
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
@@ -36,7 +33,6 @@
from ...util.util import prompt_warning_message
from ...util.util import str_to_bool
from ..context import AuthedServiceContext
-from ..data_subject.data_subject import NamePartitionKey
from ..metadata.server_metadata import ServerMetadata
from ..request.request import Request
from ..request.request import RequestStatus
@@ -61,10 +57,6 @@
logger = logging.getLogger(__name__)
-VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey)
-ServerTypePartitionKey = PartitionKey(key="server_type", type_=ServerType)
-OrderByNamePartitionKey = PartitionKey(key="name", type_=str)
-
REVERSE_TUNNEL_ENABLED = "REVERSE_TUNNEL_ENABLED"
@@ -79,40 +71,20 @@ class ServerPeerAssociationStatus(Enum):
PEER_NOT_FOUND = "PEER_NOT_FOUND"
-@serializable(canonical_name="NetworkStash", version=1)
-class NetworkStash(NewBaseUIDStoreStash):
- object_type = ServerPeer
- settings: PartitionSettings = PartitionSettings(
- name=ServerPeer.__canonical_name__, object_type=ServerPeer
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="NetworkSQLStash", version=1)
+class NetworkStash(ObjectStash[ServerPeer]):
@as_result(StashException, NotFoundException)
def get_by_name(self, credentials: SyftVerifyKey, name: str) -> ServerPeer:
- qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)])
try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
- except NotFoundException as exc:
+ return self.get_one(
+ credentials=credentials,
+ filters={"name": name},
+ ).unwrap()
+ except NotFoundException as e:
raise NotFoundException.from_exception(
- exc, public_message=f"ServerPeer with {name} not found"
+ e, public_message=f"ServerPeer with {name} not found"
)
- @as_result(StashException)
- def update(
- self,
- credentials: SyftVerifyKey,
- peer_update: ServerPeerUpdate,
- has_permission: bool = False,
- ) -> ServerPeer:
- self.check_type(peer_update, ServerPeerUpdate).unwrap()
- return (
- super()
- .update(credentials, peer_update, has_permission=has_permission)
- .unwrap()
- )
-
@as_result(StashException)
def create_or_update_peer(
self, credentials: SyftVerifyKey, peer: ServerPeer
@@ -148,28 +120,26 @@ def create_or_update_peer(
def get_by_verify_key(
self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> ServerPeer:
- qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)])
- return self.query_one(credentials, qks).unwrap(
- private_message=f"ServerPeer with {verify_key} not found"
- )
+ return self.get_one(
+ credentials=credentials,
+ filters={"verify_key": verify_key},
+ ).unwrap()
@as_result(StashException)
def get_by_server_type(
self, credentials: SyftVerifyKey, server_type: ServerType
) -> list[ServerPeer]:
- qks = QueryKeys(qks=[ServerTypePartitionKey.with_obj(server_type)])
- return self.query_all(
- credentials=credentials, qks=qks, order_by=OrderByNamePartitionKey
+ return self.get_all(
+ credentials=credentials,
+ filters={"server_type": server_type},
).unwrap()
@serializable(canonical_name="NetworkService", version=1)
class NetworkService(AbstractService):
- store: DocumentStore
stash: NetworkStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = NetworkStash(store=store)
if reverse_tunnel_enabled():
self.rtunnel_service = ReverseTunnelService()
@@ -407,7 +377,8 @@ def get_all_peers(self, context: AuthedServiceContext) -> list[ServerPeer]:
"""Get all Peers"""
return self.stash.get_all(
credentials=context.server.verify_key,
- order_by=OrderByNamePartitionKey,
+ order_by="name",
+ sort_order="asc",
).unwrap()
@service_method(
@@ -447,7 +418,7 @@ def update_peer(
peer = self.stash.update(
credentials=context.server.verify_key,
- peer_update=peer_update,
+ obj=peer_update,
).unwrap()
self.set_reverse_tunnel_config(context=context, remote_server_peer=peer)
@@ -494,7 +465,6 @@ def set_reverse_tunnel_config(
def delete_peer_by_id(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess:
"""Delete Server Peer"""
peer_to_delete = self.stash.get_by_uid(context.credentials, uid).unwrap()
- peer_to_delete = cast(ServerPeer, peer_to_delete)
server_side_type = cast(ServerType, context.server.server_type)
if server_side_type.value == ServerType.GATEWAY.value:
@@ -608,7 +578,7 @@ def add_route(
)
self.stash.update(
credentials=context.server.verify_key,
- peer_update=peer_update,
+ obj=peer_update,
).unwrap()
return SyftSuccess(
@@ -736,7 +706,7 @@ def delete_route(
id=remote_server_peer.id, server_routes=remote_server_peer.server_routes
)
self.stash.update(
- credentials=context.server.verify_key, peer_update=peer_update
+ credentials=context.server.verify_key, obj=peer_update
).unwrap()
return SyftSuccess(message=return_message)
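Note: services now take a DBManager and no longer keep a separate store attribute, and the generic ObjectStash.update() receives the object as obj= rather than a stash-specific keyword; ordering moves from PartitionKey objects to plain column names. A minimal sketch of the new wiring (rename_peer is a made-up method used only for illustration):

    # illustrative sketch, not part of the patch
    class ExampleNetworkService(AbstractService):
        stash: NetworkStash

        def __init__(self, store: DBManager) -> None:
            self.stash = NetworkStash(store=store)

        def rename_peer(
            self, context: AuthedServiceContext, peer_update: ServerPeerUpdate
        ) -> ServerPeer:
            # update() takes obj=...; get_all() sorts via order_by/sort_order strings
            return self.stash.update(
                credentials=context.server.verify_key, obj=peer_update
            ).unwrap()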
diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py
index 5cd7a5f2136..f6de35f75fd 100644
--- a/packages/syft/src/syft/service/network/routes.py
+++ b/packages/syft/src/syft/service/network/routes.py
@@ -130,8 +130,7 @@ def server(self) -> AbstractServer | None:
server_type=self.worker_settings.server_type,
server_side_type=self.worker_settings.server_side_type,
signing_key=self.worker_settings.signing_key,
- document_store_config=self.worker_settings.document_store_config,
- action_store_config=self.worker_settings.action_store_config,
+ db_config=self.worker_settings.db_config,
processes=1,
)
return server
diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py
index 280e836f17b..655d40b9b3e 100644
--- a/packages/syft/src/syft/service/network/utils.py
+++ b/packages/syft/src/syft/service/network/utils.py
@@ -88,7 +88,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> None:
result = network_stash.update(
credentials=context.server.verify_key,
- peer_update=peer_update,
+ obj=peer_update,
has_permission=True,
)
diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py
index 2ebc0908a88..c53da5fabef 100644
--- a/packages/syft/src/syft/service/notification/email_templates.py
+++ b/packages/syft/src/syft/service/notification/email_templates.py
@@ -133,7 +133,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) ->
@staticmethod
def email_body(notification: "Notification", context: AuthedServiceContext) -> str:
user_service = context.server.services.user
- admin_verify_key = user_service.admin_verify_key()
+ admin_verify_key = user_service.root_verify_key
user = user_service.stash.get_by_verify_key(
credentials=admin_verify_key, verify_key=notification.to_user_verify_key
).unwrap()
@@ -224,7 +224,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) ->
@staticmethod
def email_body(notification: "Notification", context: AuthedServiceContext) -> str:
user_service = context.server.services.user
- admin_verify_key = user_service.admin_verify_key()
+ admin_verify_key = user_service.root_verify_key
admin = user_service.get_by_verify_key(admin_verify_key).unwrap()
admin_name = admin.name
diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py
index 15b2b9725d8..7d6cb83e2f1 100644
--- a/packages/syft/src/syft/service/notification/notification_service.py
+++ b/packages/syft/src/syft/service/notification/notification_service.py
@@ -2,7 +2,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
from ...types.result import as_result
@@ -28,11 +28,9 @@
@serializable(canonical_name="NotificationService", version=1)
class NotificationService(AbstractService):
- store: DocumentStore
stash: NotificationStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = NotificationStash(store=store)
@service_method(path="notifications.send", name="send")
diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py
index fd41ad0dda6..029cf1b325a 100644
--- a/packages/syft/src/syft/service/notification/notification_stash.py
+++ b/packages/syft/src/syft/service/notification/notification_stash.py
@@ -1,80 +1,49 @@
-# stdlib
-
-# third party
-
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
-from ...types.datetime import DateTime
from ...types.result import as_result
from ...types.uid import UID
from .notifications import Notification
from .notifications import NotificationStatus
-FromUserVerifyKeyPartitionKey = PartitionKey(
- key="from_user_verify_key", type_=SyftVerifyKey
-)
-ToUserVerifyKeyPartitionKey = PartitionKey(
- key="to_user_verify_key", type_=SyftVerifyKey
-)
-StatusPartitionKey = PartitionKey(key="status", type_=NotificationStatus)
-
-OrderByCreatedAtTimeStampPartitionKey = PartitionKey(key="created_at", type_=DateTime)
-
-LinkedObjectPartitionKey = PartitionKey(key="linked_obj", type_=LinkedObject)
-
-
-@serializable(canonical_name="NotificationStash", version=1)
-class NotificationStash(NewBaseUIDStoreStash):
- object_type = Notification
- settings: PartitionSettings = PartitionSettings(
- name=Notification.__canonical_name__,
- object_type=Notification,
- )
+@serializable(canonical_name="NotificationSQLStash", version=1)
+class NotificationStash(ObjectStash[Notification]):
@as_result(StashException)
def get_all_inbox_for_verify_key(
self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> list[Notification]:
- qks = QueryKeys(
- qks=[
- ToUserVerifyKeyPartitionKey.with_obj(verify_key),
- ]
- )
- return self.get_all_for_verify_key(
- credentials=credentials, verify_key=verify_key, qks=qks
+ if not isinstance(verify_key, SyftVerifyKey | str):
+ raise AttributeError("verify_key must be of type SyftVerifyKey or str")
+ return self.get_all(
+ credentials,
+ filters={"to_user_verify_key": verify_key},
).unwrap()
@as_result(StashException)
def get_all_sent_for_verify_key(
self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> list[Notification]:
- qks = QueryKeys(
- qks=[
- FromUserVerifyKeyPartitionKey.with_obj(verify_key),
- ]
- )
- return self.get_all_for_verify_key(
- credentials, verify_key=verify_key, qks=qks
+ if not isinstance(verify_key, SyftVerifyKey | str):
+ raise AttributeError("verify_key must be of type SyftVerifyKey or str")
+ return self.get_all(
+ credentials,
+ filters={"from_user_verify_key": verify_key},
).unwrap()
@as_result(StashException)
def get_all_for_verify_key(
- self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, qks: QueryKeys
+ self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> list[Notification]:
- if isinstance(verify_key, str):
- verify_key = SyftVerifyKey.from_string(verify_key)
- return self.query_all(
+ if not isinstance(verify_key, SyftVerifyKey | str):
+ raise AttributeError("verify_key must be of type SyftVerifyKey or str")
+ return self.get_all(
credentials,
- qks=qks,
- order_by=OrderByCreatedAtTimeStampPartitionKey,
+ filters={"from_user_verify_key": verify_key},
).unwrap()
@as_result(StashException)
@@ -84,16 +53,14 @@ def get_all_by_verify_key_for_status(
verify_key: SyftVerifyKey,
status: NotificationStatus,
) -> list[Notification]:
- qks = QueryKeys(
- qks=[
- ToUserVerifyKeyPartitionKey.with_obj(verify_key),
- StatusPartitionKey.with_obj(status),
- ]
- )
- return self.query_all(
+ if not isinstance(verify_key, SyftVerifyKey | str):
+ raise AttributeError("verify_key must be of type SyftVerifyKey or str")
+ return self.get_all(
credentials,
- qks=qks,
- order_by=OrderByCreatedAtTimeStampPartitionKey,
+ filters={
+ "to_user_verify_key": str(verify_key),
+ "status": status.name,
+ },
).unwrap()
@as_result(StashException, NotFoundException)
@@ -102,14 +69,12 @@ def get_notification_for_linked_obj(
credentials: SyftVerifyKey,
linked_obj: LinkedObject,
) -> Notification:
- qks = QueryKeys(
- qks=[
- LinkedObjectPartitionKey.with_obj(linked_obj),
- ]
- )
- return self.query_one(credentials=credentials, qks=qks).unwrap(
- public_message=f"Notifications for Linked Object {linked_obj} not found"
- )
+ return self.get_one(
+ credentials,
+ filters={
+ "linked_obj.id": linked_obj.id,
+ },
+ ).unwrap()
@as_result(StashException, NotFoundException)
def update_notification_status(
@@ -123,6 +88,8 @@ def update_notification_status(
def delete_all_for_verify_key(
self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> bool:
+ if not isinstance(verify_key, SyftVerifyKey | str):
+ raise AttributeError("verify_key must be of type SyftVerifyKey or str")
notifications = self.get_all_inbox_for_verify_key(
credentials,
verify_key=verify_key,
diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py
index dffad4e414e..3fbddf6eb98 100644
--- a/packages/syft/src/syft/service/notification/notifications.py
+++ b/packages/syft/src/syft/service/notification/notifications.py
@@ -8,7 +8,9 @@
from ...server.credentials import SyftVerifyKey
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
+from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import add_credentials_for_key
@@ -48,7 +50,7 @@ class ReplyNotification(SyftObject):
@serializable()
-class Notification(SyftObject):
+class NotificationV1(SyftObject):
__canonical_name__ = "Notification"
__version__ = SYFT_OBJECT_VERSION_1
@@ -71,6 +73,32 @@ class Notification(SyftObject):
__repr_attrs__ = ["subject", "status", "created_at", "linked_obj"]
__table_sort_attr__ = "Created at"
+
+@serializable()
+class Notification(SyftObject):
+ __canonical_name__ = "Notification"
+ __version__ = SYFT_OBJECT_VERSION_2
+
+ subject: str
+ server_uid: UID
+ from_user_verify_key: SyftVerifyKey
+ to_user_verify_key: SyftVerifyKey
+ created_at: DateTime
+ status: NotificationStatus = NotificationStatus.UNREAD
+ linked_obj: LinkedObject | None = None
+ notifier_types: list[NOTIFIERS] = []
+ email_template: type[EmailTemplate] | None = None
+ replies: list[ReplyNotification] = []
+
+ __attr_searchable__ = [
+ "from_user_verify_key",
+ "to_user_verify_key",
+ "status",
+ ]
+ __repr_attrs__ = ["subject", "status", "created_at", "linked_obj"]
+ __table_sort_attr__ = "Created at"
+ __order_by__ = ("created_at", "asc")
+
def _repr_html_(self) -> str:
return f"""
@@ -145,3 +173,8 @@ def createnotification_to_notification() -> list[Callable]:
add_credentials_for_key("from_user_verify_key"),
add_server_uid_for_key("server_uid"),
]
+
+
+@migrate(NotificationV1, Notification)
+def migrate_notification_v1_to_v2() -> list[Callable]:
+ return [] # skip migration, no changes in the class
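Note: Notification is bumped to SYFT_OBJECT_VERSION_2 even though, per the patch's own comment, the stored fields are unchanged (only class-level metadata such as __attr_searchable__ and __order_by__ differ), so a migration still has to be registered for stored V1 objects; returning an empty list of transforms declares it as a no-op. Sketch of the pattern:

    # illustrative sketch, not part of the patch
    @migrate(NotificationV1, Notification)
    def migrate_notification_v1_to_v2() -> list[Callable]:
        # identical fields in both versions, so there is nothing to transform
        return []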
diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py
index dbde4905972..31cb439a614 100644
--- a/packages/syft/src/syft/service/notifier/notifier_service.py
+++ b/packages/syft/src/syft/service/notifier/notifier_service.py
@@ -8,7 +8,7 @@
# relative
from ...abstract_server import AbstractServer
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
@@ -35,11 +35,9 @@ class RateLimitException(SyftException):
@serializable(canonical_name="NotifierService", version=1)
class NotifierService(AbstractService):
- store: DocumentStore
- stash: NotifierStash # Which stash should we use?
+ stash: NotifierStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = NotifierStash(store=store)
@as_result(StashException)
@@ -258,7 +256,7 @@ def init_notifier(
"""
try:
# Create a new NotifierStash since its a static method.
- notifier_stash = NotifierStash(store=server.document_store)
+ notifier_stash = NotifierStash(store=server.db)
should_update = False
# Get the notifier
@@ -325,7 +323,7 @@ def set_email_rate_limit(
def dispatch_notification(
self, context: AuthedServiceContext, notification: Notification
) -> SyftSuccess:
- admin_key = context.server.services.user.admin_verify_key()
+ admin_key = context.server.services.user.root_verify_key
# Silently fail on notification not delivered
try:
diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py
index 8dbe8e31e8f..861f3a89975 100644
--- a/packages/syft/src/syft/service/notifier/notifier_stash.py
+++ b/packages/syft/src/syft/service/notifier/notifier_stash.py
@@ -5,54 +5,29 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseStash
-from ...store.document_store import PartitionKey
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
-from ...types.uid import UID
-from ..action.action_permissions import ActionObjectPermission
+from ...util.telemetry import instrument
from .notifier import NotifierSettings
-NamePartitionKey = PartitionKey(key="name", type_=str)
-ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID])
-
-@serializable(canonical_name="NotifierStash", version=1)
-class NotifierStash(NewBaseStash):
- object_type = NotifierSettings
+@instrument
+@serializable(canonical_name="NotifierSQLStash", version=1)
+class NotifierStash(ObjectStash[NotifierSettings]):
settings: PartitionSettings = PartitionSettings(
name=NotifierSettings.__canonical_name__, object_type=NotifierSettings
)
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
- def admin_verify_key(self) -> SyftVerifyKey:
- return self.partition.root_verify_key
-
- # TODO: should this method behave like a singleton?
@as_result(StashException, NotFoundException)
def get(self, credentials: SyftVerifyKey) -> NotifierSettings:
"""Get Settings"""
- settings: list[NotifierSettings] = self.get_all(credentials).unwrap()
- if len(settings) == 0:
- raise NotFoundException
- return settings[0]
-
- @as_result(StashException)
- def set(
- self,
- credentials: SyftVerifyKey,
- settings: NotifierSettings,
- add_permissions: list[ActionObjectPermission] | None = None,
- add_storage_permission: bool = True,
- ignore_duplicates: bool = False,
- ) -> NotifierSettings:
- result = self.check_type(settings, self.object_type).unwrap()
- # we dont use and_then logic here as it is hard because of the order of the arguments
- return (
- super().set(credentials=credentials, obj=result).unwrap()
- ) # TODO check if result isInstance(Ok)
+ # return the most recently created settings entry
+ result = self.get_all(credentials, limit=1, sort_order="desc").unwrap()
+ if len(result) > 0:
+ return result[0]
+ raise NotFoundException(
+ public_message="No settings found for the current user."
+ )
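Note: NotifierStash.get now treats the settings table as append-only history and returns only the newest row, assuming get_all applies the object's default ordering when just sort_order is given. Sketch of the singleton-by-latest-row read:

    # illustrative sketch, not part of the patch
    rows = notifier_stash.get_all(credentials, limit=1, sort_order="desc").unwrap()
    if not rows:
        raise NotFoundException(public_message="No settings found for the current user.")
    current = rows[0]  # newest entry wins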
diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py
index 422788d22f5..5d26ff2cb3e 100644
--- a/packages/syft/src/syft/service/output/output_service.py
+++ b/packages/syft/src/syft/service/output/output_service.py
@@ -7,14 +7,10 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
-from ...store.document_store_errors import TooManyItemsFoundException
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.result import as_result
@@ -183,65 +179,41 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]:
return res
-@serializable(canonical_name="OutputStash", version=1)
-class OutputStash(NewBaseUIDStoreStash):
- object_type = ExecutionOutput
- settings: PartitionSettings = PartitionSettings(
- name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store)
- self.store = store
- self.settings = self.settings
- self._object_type = self.object_type
-
+@serializable(canonical_name="OutputStashSQL", version=1)
+class OutputStash(ObjectStash[ExecutionOutput]):
@as_result(StashException)
def get_by_user_code_id(
self, credentials: SyftVerifyKey, user_code_id: UID
) -> list[ExecutionOutput]:
- qks = QueryKeys(
- qks=[UserCodeIdPartitionKey.with_obj(user_code_id)],
- )
- return self.query_all(
- credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey
+ return self.get_all(
+ credentials=credentials,
+ filters={"user_code_id": user_code_id},
).unwrap()
@as_result(StashException)
def get_by_job_id(
- self, credentials: SyftVerifyKey, user_code_id: UID
- ) -> ExecutionOutput:
- qks = QueryKeys(
- qks=[JobIdPartitionKey.with_obj(user_code_id)],
- )
- res = self.query_all(
- credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey
+ self, credentials: SyftVerifyKey, job_id: UID
+ ) -> ExecutionOutput | None:
+ return self.get_one(
+ credentials=credentials,
+ filters={"job_id": job_id},
).unwrap()
- if len(res) == 0:
- raise NotFoundException()
- elif len(res) > 1:
- raise TooManyItemsFoundException()
- return res[0]
@as_result(StashException)
def get_by_output_policy_id(
self, credentials: SyftVerifyKey, output_policy_id: UID
) -> list[ExecutionOutput]:
- qks = QueryKeys(
- qks=[OutputPolicyIdPartitionKey.with_obj(output_policy_id)],
- )
- return self.query_all(
- credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey
+ return self.get_all(
+ credentials=credentials,
+ filters={"output_policy_id": output_policy_id},
).unwrap()
@serializable(canonical_name="OutputService", version=1)
class OutputService(AbstractService):
- store: DocumentStore
stash: OutputStash
- def __init__(self, store: DocumentStore):
- self.store = store
+ def __init__(self, store: DBManager):
self.stash = OutputStash(store=store)
@service_method(
@@ -310,7 +282,7 @@ def has_output_read_permissions(
ActionObjectREAD(uid=_id.id, credentials=user_verify_key)
for _id in result_ids
]
- if context.server.services.action.store.has_permissions(permissions):
+ if context.server.services.action.stash.has_permissions(permissions):
return True
return False
@@ -321,11 +293,11 @@ def has_output_read_permissions(
roles=ADMIN_ROLE_LEVEL,
)
def get_by_job_id(
- self, context: AuthedServiceContext, user_code_id: UID
+ self, context: AuthedServiceContext, job_id: UID
) -> ExecutionOutput:
return self.stash.get_by_job_id(
credentials=context.server.verify_key, # type: ignore
- user_code_id=user_code_id,
+ job_id=job_id,
).unwrap()
@service_method(
diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py
index 3b1f33c0a08..1e33755418e 100644
--- a/packages/syft/src/syft/service/policy/policy.py
+++ b/packages/syft/src/syft/service/policy/policy.py
@@ -174,7 +174,6 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str
from ..action.action_object import ActionObject
# fetches the all the current api's connected
- api_list = APIRegistry.get_all_api()
output_kwargs = {}
for k, v in kwargs.items():
uid = v
@@ -190,7 +189,7 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str
raise Exception(f"Input {k} must have a UID not {type(v)}")
_obj_exists = False
- for api in api_list:
+ for identity, api in APIRegistry.__api_registry__.items():
try:
if api.services.action.exists(uid):
server_identity = ServerIdentity.from_api(api)
@@ -205,6 +204,9 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str
# To handle the cases , where there an old api objects in
# in APIRegistry
continue
+ except Exception as e:
+ print(f"Error in partition_by_server with identity {identity}", e)
+ raise e
if not _obj_exists:
raise Exception(f"Input data {k}:{uid} does not belong to any Datasite")
@@ -335,7 +337,7 @@ class UserOwned(PolicyRule):
def is_owned(
self, context: AuthedServiceContext, action_object: ActionObject
) -> bool:
- action_store = context.server.services.action.store
+ action_store = context.server.services.action.stash
return action_store.has_permission(
ActionObjectPermission(
action_object.id, ActionPermission.OWNER, context.credentials
diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py
index fe32c226dbf..1e5f430d109 100644
--- a/packages/syft/src/syft/service/policy/policy_service.py
+++ b/packages/syft/src/syft/service/policy/policy_service.py
@@ -2,7 +2,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftSuccess
@@ -16,11 +16,9 @@
@serializable(canonical_name="PolicyService", version=1)
class PolicyService(AbstractService):
- store: DocumentStore
stash: UserPolicyStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = UserPolicyStash(store=store)
@service_method(path="policy.get_all", name="get_all")
diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py
index 38ab7f54c06..9e3a103280b 100644
--- a/packages/syft/src/syft/service/policy/user_policy_stash.py
+++ b/packages/syft/src/syft/service/policy/user_policy_stash.py
@@ -3,30 +3,20 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
-from .policy import PolicyUserVerifyKeyPartitionKey
from .policy import UserPolicy
-@serializable(canonical_name="UserPolicyStash", version=1)
-class UserPolicyStash(NewBaseUIDStoreStash):
- object_type = UserPolicy
- settings: PartitionSettings = PartitionSettings(
- name=UserPolicy.__canonical_name__, object_type=UserPolicy
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="UserPolicySQLStash", version=1)
+class UserPolicyStash(ObjectStash[UserPolicy]):
@as_result(StashException, NotFoundException)
def get_all_by_user_verify_key(
self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey
) -> list[UserPolicy]:
- qks = QueryKeys(qks=[PolicyUserVerifyKeyPartitionKey.with_obj(user_verify_key)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"user_verify_key": user_verify_key},
+ ).unwrap()
diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py
index 0b64a9d8870..5ed21f5007d 100644
--- a/packages/syft/src/syft/service/project/project.py
+++ b/packages/syft/src/syft/service/project/project.py
@@ -1247,12 +1247,12 @@ def send(self, return_all_projects: bool = False) -> Project | list[Project]:
projects_map = self._create_projects(self.clients)
# bootstrap project with pending events on leader server's project
- self._bootstrap_events(projects_map[leader])
+ self._bootstrap_events(projects_map[leader.id]) # type: ignore
if return_all_projects:
return list(projects_map.values())
- return projects_map[leader]
+ return projects_map[leader.id] # type: ignore
def _pre_submit_checks(self, clients: list[SyftClient]) -> bool:
try:
@@ -1277,12 +1277,12 @@ def _exchange_routes(self, leader: SyftClient, followers: list[SyftClient]) -> N
self.leader_server_route = connection_to_route(leader.connection)
- def _create_projects(self, clients: list[SyftClient]) -> dict[SyftClient, Project]:
- projects: dict[SyftClient, Project] = {}
+ def _create_projects(self, clients: list[SyftClient]) -> dict[UID, Project]:
+ projects: dict[UID, Project] = {}
for client in clients:
result = client.api.services.project.create_project(project=self).value
- projects[client] = result
+ projects[client.id] = result # type: ignore
return projects
diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py
index 3a2543e38a1..2df2fd42e1c 100644
--- a/packages/syft/src/syft/service/project/project_service.py
+++ b/packages/syft/src/syft/service/project/project_service.py
@@ -4,7 +4,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
@@ -32,11 +32,9 @@
@serializable(canonical_name="ProjectService", version=1)
class ProjectService(AbstractService):
- store: DocumentStore
stash: ProjectStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = ProjectStash(store=store)
@as_result(SyftException)
diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py
index bf81bd5b9b1..13dab37bdea 100644
--- a/packages/syft/src/syft/service/project/project_stash.py
+++ b/packages/syft/src/syft/service/project/project_stash.py
@@ -5,49 +5,27 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
-from ...types.uid import UID
-from ..request.request import Request
from .project import Project
-# TODO: Move to a partitions file?
-VerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)
-NamePartitionKey = PartitionKey(key="name", type_=str)
-
-@serializable(canonical_name="ProjectStash", version=1)
-class ProjectStash(NewBaseUIDStoreStash):
- object_type = Project
- settings: PartitionSettings = PartitionSettings(
- name=Project.__canonical_name__, object_type=Project
- )
-
- # TODO: Shouldn't this be a list of projects?
+@serializable(canonical_name="ProjectSQLStash", version=1)
+class ProjectStash(ObjectStash[Project]):
@as_result(StashException)
def get_all_for_verify_key(
- self, credentials: SyftVerifyKey, verify_key: VerifyKeyPartitionKey
- ) -> list[Request]:
- if isinstance(verify_key, str):
- verify_key = SyftVerifyKey.from_string(verify_key)
- qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)])
- return self.query_all(
+ self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
+ ) -> list[Project]:
+ return self.get_all(
credentials=credentials,
- qks=qks,
+ filters={"user_verify_key": verify_key},
).unwrap()
- @as_result(StashException, NotFoundException)
- def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> Project:
- qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
-
@as_result(StashException, NotFoundException)
def get_by_name(self, credentials: SyftVerifyKey, project_name: str) -> Project:
- qks = QueryKeys(qks=[NamePartitionKey.with_obj(project_name)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"name": project_name},
+ ).unwrap()
diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py
index ae170cad95b..b2807389cf1 100644
--- a/packages/syft/src/syft/service/queue/queue.py
+++ b/packages/syft/src/syft/service/queue/queue.py
@@ -6,6 +6,7 @@
from threading import Thread
import time
from typing import Any
+from typing import cast
# third party
import psutil
@@ -178,8 +179,7 @@ def handle_message_multiprocessing(
id=worker_settings.id,
name=worker_settings.name,
signing_key=worker_settings.signing_key,
- document_store_config=worker_settings.document_store_config,
- action_store_config=worker_settings.action_store_config,
+ db_config=worker_settings.db_config,
blob_storage_config=worker_settings.blob_store_config,
server_side_type=worker_settings.server_side_type,
queue_config=queue_config,
@@ -256,7 +256,7 @@ def handle_message_multiprocessing(
public_message=f"Job {queue_item.job_id} not found!"
)
- job_item.server_uid = worker.id
+ job_item.server_uid = worker.id # type: ignore[assignment]
job_item.result = result
job_item.resolved = True
job_item.status = job_status
@@ -282,7 +282,10 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
from ...server.server import Server
queue_item = deserialize(message, from_bytes=True)
+ queue_item = cast(QueueItem, queue_item)
worker_settings = queue_item.worker_settings
+ if worker_settings is None:
+ raise ValueError("Worker settings are missing in the queue item.")
queue_config = worker_settings.queue_config
queue_config.client_config.create_producer = False
@@ -292,9 +295,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
id=worker_settings.id,
name=worker_settings.name,
signing_key=worker_settings.signing_key,
- document_store_config=worker_settings.document_store_config,
- action_store_config=worker_settings.action_store_config,
- blob_storage_config=worker_settings.blob_store_config,
+ db_config=worker_settings.db_config,
server_side_type=worker_settings.server_side_type,
deployment_type=worker_settings.deployment_type,
queue_config=queue_config,
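Note: worker processes rebuilt from a queue message now receive a single db_config instead of separate document_store_config/action_store_config values. Sketch of the changed keyword set, with unrelated Server arguments elided:

    # illustrative sketch, not part of the patch
    worker = Server(
        id=worker_settings.id,
        name=worker_settings.name,
        signing_key=worker_settings.signing_key,
        db_config=worker_settings.db_config,  # replaces the two old store configs
        server_side_type=worker_settings.server_side_type,
        queue_config=queue_config,
    )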
diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py
index c898893ee35..b98f344745d 100644
--- a/packages/syft/src/syft/service/queue/queue_service.py
+++ b/packages/syft/src/syft/service/queue/queue_service.py
@@ -2,30 +2,14 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...types.uid import UID
-from ..context import AuthedServiceContext
+from ...store.db.db import DBManager
from ..service import AbstractService
-from ..service import service_method
-from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
-from .queue_stash import QueueItem
from .queue_stash import QueueStash
@serializable(canonical_name="QueueService", version=1)
class QueueService(AbstractService):
- store: DocumentStore
stash: QueueStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = QueueStash(store=store)
-
- @service_method(
- path="queue.get_subjobs",
- name="get_subjobs",
- roles=DATA_SCIENTIST_ROLE_LEVEL,
- )
- def get_subjobs(self, context: AuthedServiceContext, uid: UID) -> list[QueueItem]:
- # FIX: There is no get_by_parent_id in QueueStash
- return self.stash.get_by_parent_id(context.credentials, uid=uid)
diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py
index 251f4a9fb63..aa5b872b226 100644
--- a/packages/syft/src/syft/service/queue/queue_stash.py
+++ b/packages/syft/src/syft/service/queue/queue_stash.py
@@ -1,4 +1,5 @@
# stdlib
+from collections.abc import Callable
from enum import Enum
from typing import Any
@@ -6,22 +7,24 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...server.worker_settings import WorkerSettings
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseStash
+from ...server.worker_settings import WorkerSettingsV1
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
from ...types.errors import SyftException
from ...types.result import as_result
+from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
+from ...types.transforms import TransformContext
from ...types.uid import UID
from ..action.action_permissions import ActionObjectPermission
+__all__ = ["QueueItem"]
+
@serializable(canonical_name="Status", version=1)
class Status(str, Enum):
@@ -37,7 +40,7 @@ class Status(str, Enum):
@serializable()
-class QueueItem(SyftObject):
+class QueueItemV1(SyftObject):
__canonical_name__ = "QueueItem"
__version__ = SYFT_OBJECT_VERSION_1
@@ -49,6 +52,29 @@ class QueueItem(SyftObject):
resolved: bool = False
status: Status = Status.CREATED
+ method: str
+ service: str
+ args: list
+ kwargs: dict[str, Any]
+ job_id: UID | None = None
+ worker_settings: WorkerSettingsV1 | None = None
+ has_execute_permissions: bool = False
+ worker_pool: LinkedObject
+
+
+@serializable()
+class QueueItem(SyftObject):
+ __canonical_name__ = "QueueItem"
+ __version__ = SYFT_OBJECT_VERSION_2
+
+ __attr_searchable__ = ["status", "worker_pool_id"]
+
+ id: UID
+ server_uid: UID
+ result: Any | None = None
+ resolved: bool = False
+ status: Status = Status.CREATED
+
method: str
service: str
args: list
@@ -64,6 +90,10 @@ def __repr__(self) -> str:
def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
return f"
: {self.status}"
+ @property
+ def worker_pool_id(self) -> UID:
+ return self.worker_pool.object_uid
+
@property
def is_action(self) -> bool:
return self.service_path == "Action" and self.method_name == "execute"
@@ -78,7 +108,7 @@ def action(self) -> Any:
@serializable()
class ActionQueueItem(QueueItem):
__canonical_name__ = "ActionQueueItem"
- __version__ = SYFT_OBJECT_VERSION_1
+ __version__ = SYFT_OBJECT_VERSION_2
method: str = "execute"
service: str = "actionservice"
@@ -87,22 +117,32 @@ class ActionQueueItem(QueueItem):
@serializable()
class APIEndpointQueueItem(QueueItem):
__canonical_name__ = "APIEndpointQueueItem"
- __version__ = SYFT_OBJECT_VERSION_1
+ __version__ = SYFT_OBJECT_VERSION_2
method: str
service: str = "apiservice"
-@serializable(canonical_name="QueueStash", version=1)
-class QueueStash(NewBaseStash):
- object_type = QueueItem
- settings: PartitionSettings = PartitionSettings(
- name=QueueItem.__canonical_name__, object_type=QueueItem
- )
+@serializable()
+class ActionQueueItemV1(QueueItemV1):
+ __canonical_name__ = "ActionQueueItem"
+ __version__ = SYFT_OBJECT_VERSION_1
+
+ method: str = "execute"
+ service: str = "actionservice"
+
+
+@serializable()
+class APIEndpointQueueItemV1(QueueItemV1):
+ __canonical_name__ = "APIEndpointQueueItem"
+ __version__ = SYFT_OBJECT_VERSION_1
+
+ method: str
+ service: str = "apiservice"
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="QueueSQLStash", version=1)
+class QueueStash(ObjectStash[QueueItem]):
# FIX: Check the return value for None. set_result is used extensively
@as_result(StashException)
def set_result(
@@ -133,11 +173,6 @@ def set_placeholder(
return super().set(credentials, item, add_permissions).unwrap()
return item
- @as_result(StashException)
- def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem:
- qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
-
@as_result(StashException)
def pop(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem | None:
try:
@@ -156,23 +191,51 @@ def pop_on_complete(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem:
self.delete_by_uid(credentials=credentials, uid=uid)
return queue_item
- @as_result(StashException)
- def delete_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> UID:
- qk = UIDPartitionKey.with_obj(uid)
- super().delete(credentials=credentials, qk=qk).unwrap()
- return uid
-
@as_result(StashException)
def get_by_status(
self, credentials: SyftVerifyKey, status: Status
) -> list[QueueItem]:
- qks = QueryKeys(qks=StatusPartitionKey.with_obj(status))
-
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"status": status},
+ ).unwrap()
@as_result(StashException)
def _get_by_worker_pool(
self, credentials: SyftVerifyKey, worker_pool: LinkedObject
) -> list[QueueItem]:
- qks = QueryKeys(qks=_WorkerPoolPartitionKey.with_obj(worker_pool))
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ worker_pool_id = worker_pool.object_uid
+
+ return self.get_all(
+ credentials=credentials,
+ filters={"worker_pool_id": worker_pool_id},
+ ).unwrap()
+
+
+def upgrade_worker_settings_for_queue(context: TransformContext) -> TransformContext:
+ if context.output and context.output["worker_settings"] is not None:
+ worker_settings_old: WorkerSettingsV1 | None = context.output["worker_settings"]
+ if worker_settings_old is None:
+ return context
+
+ worker_settings = worker_settings_old.migrate_to(
+ WorkerSettings.__version__, context=context.to_server_context()
+ )
+ context.output["worker_settings"] = worker_settings
+
+ return context
+
+
+@migrate(QueueItemV1, QueueItem)
+def migrate_queue_item_from_v1_to_v2() -> list[Callable]:
+ return [upgrade_worker_settings_for_queue]
+
+
+@migrate(ActionQueueItemV1, ActionQueueItem)
+def migrate_action_queue_item_v1_to_v2() -> list[Callable]:
+ return [upgrade_worker_settings_for_queue]
+
+
+@migrate(APIEndpointQueueItemV1, APIEndpointQueueItem)
+def migrate_api_endpoint_queue_item_v1_to_v2() -> list[Callable]:
+ return [upgrade_worker_settings_for_queue]
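Note: unlike the no-op Notification migration, the QueueItem migrations carry a transform because the nested worker_settings value also changed versions; each transform takes a TransformContext, rewrites context.output in place, and returns it. Sketch of the shape, with an illustrative function name:

    # illustrative sketch, not part of the patch
    def upgrade_nested_worker_settings(context: TransformContext) -> TransformContext:
        old = context.output.get("worker_settings") if context.output else None
        if old is not None:
            # migrate the embedded object to the current WorkerSettings version
            context.output["worker_settings"] = old.migrate_to(
                WorkerSettings.__version__, context=context.to_server_context()
            )
        return context

    @migrate(QueueItemV1, QueueItem)
    def migrate_queue_item_v1_to_v2() -> list[Callable]:
        return [upgrade_nested_worker_settings]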
diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py
index f6993d6b032..b2f6d4b9a6e 100644
--- a/packages/syft/src/syft/service/queue/zmq_consumer.py
+++ b/packages/syft/src/syft/service/queue/zmq_consumer.py
@@ -288,7 +288,7 @@ def _set_worker_job(self, job_id: UID | None) -> None:
ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING
)
res = self.worker_stash.update_consumer_state(
- credentials=self.worker_stash.partition.root_verify_key,
+ credentials=self.worker_stash.root_verify_key,
worker_uid=self.syft_worker_id,
consumer_state=consumer_state,
)
diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py
index 5cb5056f8a2..197f87ab283 100644
--- a/packages/syft/src/syft/service/queue/zmq_producer.py
+++ b/packages/syft/src/syft/service/queue/zmq_producer.py
@@ -159,7 +159,7 @@ def read_items(self) -> None:
# Items to be queued
items_to_queue = self.queue_stash.get_by_status(
- self.queue_stash.partition.root_verify_key,
+ self.queue_stash.root_verify_key,
status=Status.CREATED,
).unwrap()
@@ -167,7 +167,7 @@ def read_items(self) -> None:
# Queue Items that are in the processing state
items_processing = self.queue_stash.get_by_status(
- self.queue_stash.partition.root_verify_key,
+ self.queue_stash.root_verify_key,
status=Status.PROCESSING,
).unwrap()
@@ -281,14 +281,14 @@ def update_consumer_state_for_worker(
try:
try:
self.worker_stash.get_by_uid(
- credentials=self.worker_stash.partition.root_verify_key,
+ credentials=self.worker_stash.root_verify_key,
uid=syft_worker_id,
).unwrap()
except Exception:
return None
self.worker_stash.update_consumer_state(
- credentials=self.worker_stash.partition.root_verify_key,
+ credentials=self.worker_stash.root_verify_key,
worker_uid=syft_worker_id,
consumer_state=consumer_state,
).unwrap()
diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py
index 1a492ea1d53..674c66a019a 100644
--- a/packages/syft/src/syft/service/request/request.py
+++ b/packages/syft/src/syft/service/request/request.py
@@ -40,8 +40,8 @@
from ...util.notebook_ui.icons import Icon
from ...util.util import prompt_warning_message
from ..action.action_object import ActionObject
-from ..action.action_store import ActionObjectPermission
-from ..action.action_store import ActionPermission
+from ..action.action_permissions import ActionObjectPermission
+from ..action.action_permissions import ActionPermission
from ..code.user_code import UserCode
from ..code.user_code import UserCodeStatus
from ..code.user_code import UserCodeStatusCollection
@@ -112,7 +112,7 @@ class ActionStoreChange(Change):
@as_result(SyftException)
def _run(self, context: ChangeContext, apply: bool) -> SyftSuccess:
- action_store = context.server.services.action.store
+ action_store = context.server.services.action.stash
# can we ever have a lineage ID in the store?
obj_uid = self.linked_obj.object_uid
@@ -362,6 +362,7 @@ class Request(SyncableSyftObject):
__attr_searchable__ = [
"requesting_user_verify_key",
"approving_user_verify_key",
+ "code_id",
]
__attr_unique__ = ["request_hash"]
__repr_attrs__ = [
diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py
index 6343007f6ca..ed94c185689 100644
--- a/packages/syft/src/syft/service/request/request_service.py
+++ b/packages/syft/src/syft/service/request/request_service.py
@@ -4,7 +4,7 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.linked_obj import LinkedObject
from ...types.errors import SyftException
from ...types.result import as_result
@@ -37,11 +37,9 @@
@serializable(canonical_name="RequestService", version=1)
class RequestService(AbstractService):
- store: DocumentStore
stash: RequestStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = RequestStash(store=store)
@service_method(path="request.submit", name="submit", roles=GUEST_ROLE_LEVEL)
@@ -59,7 +57,7 @@ def submit(
request,
).unwrap()
- root_verify_key = context.server.services.user.admin_verify_key()
+ root_verify_key = context.server.services.user.root_verify_key
if send_message:
message_subject = f"Result to request {str(request.id)[:4]}...{str(request.id)[-3:]}\
diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py
index 19bb2f5720e..a28fd5842e1 100644
--- a/packages/syft/src/syft/service/request/request_stash.py
+++ b/packages/syft/src/syft/service/request/request_stash.py
@@ -1,59 +1,31 @@
-# stdlib
-
-# third party
-
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...types.datetime import DateTime
+from ...store.db.stash import ObjectStash
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.uid import UID
from .request import Request
-RequestingUserVerifyKeyPartitionKey = PartitionKey(
- key="requesting_user_verify_key", type_=SyftVerifyKey
-)
-
-OrderByRequestTimeStampPartitionKey = PartitionKey(key="request_time", type_=DateTime)
-
-
-@serializable(canonical_name="RequestStash", version=1)
-class RequestStash(NewBaseUIDStoreStash):
- object_type = Request
- settings: PartitionSettings = PartitionSettings(
- name=Request.__canonical_name__, object_type=Request
- )
+@serializable(canonical_name="RequestStashSQL", version=1)
+class RequestStash(ObjectStash[Request]):
@as_result(SyftException)
def get_all_for_verify_key(
self,
credentials: SyftVerifyKey,
verify_key: SyftVerifyKey,
) -> list[Request]:
- if isinstance(verify_key, str):
- verify_key = SyftVerifyKey.from_string(verify_key)
- qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)])
- return self.query_all(
+ return self.get_all(
credentials=credentials,
- qks=qks,
- order_by=OrderByRequestTimeStampPartitionKey,
+ filters={"requesting_user_verify_key": verify_key},
).unwrap()
@as_result(SyftException)
def get_by_usercode_id(
self, credentials: SyftVerifyKey, user_code_id: UID
) -> list[Request]:
- all_requests = self.get_all(credentials=credentials).unwrap()
- res = []
- for r in all_requests:
- try:
- if r.code_id == user_code_id:
- res.append(r)
- except SyftException:
- pass
- return res
+ return self.get_all(
+ credentials=credentials,
+ filters={"code_id": user_code_id},
+ ).unwrap()
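Note: get_by_usercode_id drops the client-side scan in favor of a filtered query; this works because "code_id" was added to Request.__attr_searchable__ earlier in this patch, and filters presumably only apply to searchable attributes. Sketch of the before/after:

    # illustrative sketch, not part of the patch
    # before: fetch everything, filter in Python
    matching = [r for r in stash.get_all(credentials).unwrap() if r.code_id == user_code_id]

    # after: push the predicate down to the store
    matching = stash.get_all(
        credentials=credentials,
        filters={"code_id": user_code_id},
    ).unwrap()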
diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py
index 784eca2e340..49749711853 100644
--- a/packages/syft/src/syft/service/service.py
+++ b/packages/syft/src/syft/service/service.py
@@ -37,6 +37,7 @@
from ..serde.signature import signature_remove_context
from ..serde.signature import signature_remove_self
from ..server.credentials import SyftVerifyKey
+from ..store.db.stash import ObjectStash
from ..store.document_store import DocumentStore
from ..store.linked_obj import LinkedObject
from ..types.errors import SyftException
@@ -71,6 +72,7 @@ class AbstractService:
server: AbstractServer
server_uid: UID
store_type: type = DocumentStore
+ stash: ObjectStash
@as_result(SyftException)
def resolve_link(
diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py
index 43fe685b7e5..10890350e2d 100644
--- a/packages/syft/src/syft/service/settings/settings_service.py
+++ b/packages/syft/src/syft/service/settings/settings_service.py
@@ -8,7 +8,7 @@
# relative
from ...abstract_server import ServerSideType
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.sqlite_document_store import SQLiteStoreConfig
@@ -48,11 +48,9 @@
@serializable(canonical_name="SettingsService", version=1)
class SettingsService(AbstractService):
- store: DocumentStore
stash: SettingsStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = SettingsStash(store=store)
@service_method(path="settings.get", name="get")
@@ -123,14 +121,16 @@ def update(
def _update(
self, context: AuthedServiceContext, settings: ServerSettingsUpdate
) -> ServerSettings:
- all_settings = self.stash.get_all(context.credentials).unwrap()
+ all_settings = self.stash.get_all(
+ context.credentials, limit=1, sort_order="desc"
+ ).unwrap()
if len(all_settings) > 0:
new_settings = all_settings[0].model_copy(
update=settings.to_dict(exclude_empty=True)
)
ServerSettings.model_validate(new_settings.to_dict())
update_result = self.stash.update(
- context.credentials, settings=new_settings
+ context.credentials, obj=new_settings
).unwrap()
# If notifications_enabled is present in the update, we need to update the notifier settings
@@ -173,10 +173,12 @@ def set_server_side_type_dangerous(
public_message=f"Not a valid server_side_type, please use one of the options from: {side_type_options}"
)
- current_settings = self.stash.get_all(context.credentials).unwrap()
+ current_settings = self.stash.get_all(
+ context.credentials, limit=1, sort_order="desc"
+ ).unwrap()
if len(current_settings) > 0:
new_settings = current_settings[0]
- new_settings.server_side_type = server_side_type
+ new_settings.server_side_type = ServerSideType(server_side_type)
updated_settings = self.stash.update(
context.credentials, new_settings
).unwrap()
diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py
index aa02847504a..c22c08045f3 100644
--- a/packages/syft/src/syft/service/settings/settings_stash.py
+++ b/packages/syft/src/syft/service/settings/settings_stash.py
@@ -1,36 +1,11 @@
# relative
from ...serde.serializable import serializable
-from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store_errors import StashException
-from ...types.result import as_result
-from ...types.uid import UID
+from ...store.db.stash import ObjectStash
+from ...util.telemetry import instrument
from .settings import ServerSettings
-NamePartitionKey = PartitionKey(key="name", type_=str)
-ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID])
-
-@serializable(canonical_name="SettingsStash", version=1)
-class SettingsStash(NewBaseUIDStoreStash):
- object_type = ServerSettings
- settings: PartitionSettings = PartitionSettings(
- name=ServerSettings.__canonical_name__, object_type=ServerSettings
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
- # Should we have this at all?
- @as_result(StashException)
- def update(
- self,
- credentials: SyftVerifyKey,
- settings: ServerSettings,
- has_permission: bool = False,
- ) -> ServerSettings:
- obj = self.check_type(settings, self.object_type).unwrap()
- return super().update(credentials=credentials, obj=obj).unwrap()
+@instrument
+@serializable(canonical_name="SettingsStashSQL", version=1)
+class SettingsStash(ObjectStash[ServerSettings]):
+ pass
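
With the generic `ObjectStash[T]`, most stashes shrink to a declaration and inherit their CRUD surface from the base class. A sketch of the pattern, using hypothetical `MyObject`/`MyStash` names:

```python
# Hypothetical stash declaration; MyObject / MyStash are illustrative names only.
from syft.serde.serializable import serializable
from syft.store.db.stash import ObjectStash
from syft.types.syft_object import SYFT_OBJECT_VERSION_1
from syft.types.syft_object import SyftObject
from syft.types.uid import UID


@serializable()
class MyObject(SyftObject):
    __canonical_name__ = "MyObject"
    __version__ = SYFT_OBJECT_VERSION_1

    id: UID
    name: str


@serializable(canonical_name="MyStashSQL", version=1)
class MyStash(ObjectStash[MyObject]):
    # get_all / get_one / get_by_uid / set / update / delete_by_uid come from ObjectStash
    pass
```
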
diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py
index 75959be55e5..ddafd86b1d3 100644
--- a/packages/syft/src/syft/service/sync/sync_service.py
+++ b/packages/syft/src/syft/service/sync/sync_service.py
@@ -6,7 +6,8 @@
# relative
from ...client.api import ServerIdentity
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store import NewBaseStash
from ...store.document_store_errors import NotFoundException
from ...store.linked_obj import LinkedObject
@@ -36,22 +37,20 @@
logger = logging.getLogger(__name__)
-def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> Any:
+def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash:
if isinstance(item, ActionObject):
service = context.server.services.action # type: ignore
- return service.store # type: ignore
+ return service.stash # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore
- return service.stash.partition
+ return service.stash
@instrument
@serializable(canonical_name="SyncService", version=1)
class SyncService(AbstractService):
- store: DocumentStore
stash: SyncStash
- def __init__(self, store: DocumentStore):
- self.store = store
+ def __init__(self, store: DBManager):
self.stash = SyncStash(store=store)
def add_actionobject_read_permissions(
@@ -60,14 +59,14 @@ def add_actionobject_read_permissions(
action_object: ActionObject,
new_permissions: list[ActionObjectPermission],
) -> None:
- store_to = context.server.services.action.store # type: ignore
+ action_stash = context.server.services.action.stash
for permission in new_permissions:
if permission.permission == ActionPermission.READ:
- store_to.add_permission(permission)
+ action_stash.add_permission(permission)
blob_id = action_object.syft_blob_storage_entry_id
if blob_id:
- store_to_blob = context.server.services.blob_sotrage.stash.partition # type: ignore
+ blob_stash = context.server.services.blob_storage.stash
for permission in new_permissions:
if permission.permission == ActionPermission.READ:
permission_blob = ActionObjectPermission(
@@ -75,7 +74,7 @@ def add_actionobject_read_permissions(
permission=permission.permission,
credentials=permission.credentials,
)
- store_to_blob.add_permission(permission_blob)
+ blob_stash.add_permission(permission_blob)
def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None:
if hasattr(x, "__dict__") and isinstance(x, SyftObject):
@@ -246,9 +245,9 @@ def get_permissions(
self,
context: AuthedServiceContext,
items: list[SyncableSyftObject],
- ) -> tuple[dict[UID, set[str]], dict[UID, set[str]]]:
- permissions = {}
- storage_permissions = {}
+ ) -> tuple[dict[UID, set[str]], dict[UID, set[UID]]]:
+ permissions: dict[UID, set[str]] = {}
+ storage_permissions: dict[UID, set[UID]] = {}
for item in items:
store = get_store(context, item)
@@ -369,7 +368,9 @@ def build_current_state(
storage_permissions = {}
try:
- previous_state = self.stash.get_latest(context=context).unwrap()
+ previous_state = self.stash.get_latest(
+ credentials=context.credentials
+ ).unwrap()
except NotFoundException:
previous_state = None
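
`get_store` now returns an `ObjectStash` for every syncable type, so permission bookkeeping goes through the stash API instead of a raw partition. A hedged sketch of the read-permission grant, reusing the `ActionObjectPermission` constructor shown elsewhere in this diff (`stash`, `obj`, and `user_verify_key` are placeholders):

```python
# Hypothetical sketch: granting READ on a synced object through its stash.
from syft.service.action.action_permissions import ActionObjectPermission
from syft.service.action.action_permissions import ActionPermission


def grant_read(stash, obj, user_verify_key) -> None:
    permission = ActionObjectPermission(
        uid=obj.id,
        permission=ActionPermission.READ,
        credentials=user_verify_key,
    )
    # ObjectStash exposes add_permission directly; no .partition indirection anymore
    stash.add_permission(permission)
```
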
diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py
index 8633f31f130..114ee209af1 100644
--- a/packages/syft/src/syft/service/sync/sync_stash.py
+++ b/packages/syft/src/syft/service/sync/sync_stash.py
@@ -1,69 +1,37 @@
-# stdlib
-import threading
-
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
+from ...server.credentials import SyftVerifyKey
+from ...store.db.db import DBManager
+from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import StashException
-from ...types.datetime import DateTime
from ...types.result import as_result
-from ..context import AuthedServiceContext
from .sync_state import SyncState
-OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime)
-
@serializable(canonical_name="SyncStash", version=1)
-class SyncStash(NewBaseUIDStoreStash):
- object_type = SyncState
+class SyncStash(ObjectStash[SyncState]):
settings: PartitionSettings = PartitionSettings(
name=SyncState.__canonical_name__,
object_type=SyncState,
)
- def __init__(self, store: DocumentStore):
+ def __init__(self, store: DBManager) -> None:
super().__init__(store)
- self.store = store
- self.settings = self.settings
- self._object_type = self.object_type
self.last_state: SyncState | None = None
@as_result(StashException)
- def get_latest(self, context: AuthedServiceContext) -> SyncState | None:
+ def get_latest(self, credentials: SyftVerifyKey) -> SyncState | None:
if self.last_state is not None:
return self.last_state
- all_states = self.get_all(
- credentials=context.server.verify_key, # type: ignore
- order_by=OrderByDatePartitionKey,
+
+ states = self.get_all(
+ credentials=credentials,
+ order_by="created_at",
+ sort_order="desc",
+ limit=1,
).unwrap()
- if len(all_states) > 0:
- self.last_state = all_states[-1]
- return all_states[-1]
+ if len(states) > 0:
+ return states[0]
return None
-
- def unwrap_set(self, context: AuthedServiceContext, item: SyncState) -> SyncState:
- return super().set(context, item).unwrap()
-
- @as_result(StashException)
- def set( # type: ignore
- self,
- context: AuthedServiceContext,
- item: SyncState,
- **kwargs,
- ) -> SyncState:
- self.last_state = item
-
- # use threading
- threading.Thread(
- target=self.unwrap_set,
- args=(
- context,
- item,
- ),
- kwargs=kwargs,
- ).start()
- return item
diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py
index dab3b6cccd1..a0d27a00786 100644
--- a/packages/syft/src/syft/service/user/user_service.py
+++ b/packages/syft/src/syft/service/user/user_service.py
@@ -11,7 +11,7 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
@@ -81,23 +81,21 @@ def _paginate(
@serializable(canonical_name="UserService", version=1)
class UserService(AbstractService):
- store: DocumentStore
stash: UserStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = UserStash(store=store)
@as_result(StashException)
def _add_user(self, credentials: SyftVerifyKey, user: User) -> User:
- action_object_permissions = ActionObjectPermission(
- uid=user.id, permission=ActionPermission.ALL_READ
- )
-
return self.stash.set(
credentials=credentials,
obj=user,
- add_permissions=[action_object_permissions],
+ add_permissions=[
+ ActionObjectPermission(
+ uid=user.id, permission=ActionPermission.ALL_READ
+ ),
+ ],
).unwrap()
def _check_if_email_exists(self, credentials: SyftVerifyKey, email: str) -> bool:
@@ -132,7 +130,7 @@ def forgot_password(
"If the email is valid, we sent a password "
+ "reset token to your email or a password request to the admin."
)
- root_key = self.admin_verify_key()
+ root_key = self.root_verify_key
root_context = AuthedServiceContext(server=context.server, credentials=root_key)
@@ -243,7 +241,7 @@ def reset_password(
self, context: UnauthedServiceContext, token: str, new_password: str
) -> SyftSuccess:
"""Resets a certain user password using a temporary token."""
- root_key = self.admin_verify_key()
+ root_key = self.root_verify_key
root_context = AuthedServiceContext(server=context.server, credentials=root_key)
try:
@@ -321,21 +319,36 @@ def view(self, context: AuthedServiceContext, uid: UID) -> UserView:
def get_all(
self,
context: AuthedServiceContext,
+ order_by: str | None = None,
+ sort_order: str | None = None,
page_size: int | None = 0,
page_index: int | None = 0,
) -> list[UserView]:
- if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]:
- users = self.stash.get_all(
- context.credentials, has_permission=True
- ).unwrap()
- else:
- users = self.stash.get_all(context.credentials).unwrap()
+ users = self.stash.get_all(
+ context.credentials,
+ order_by=order_by,
+ sort_order=sort_order,
+ ).unwrap()
users = [user.to(UserView) for user in users]
return _paginate(users, page_size, page_index)
+ @service_method(
+ path="user.get_index", name="get_index", roles=DATA_OWNER_ROLE_LEVEL
+ )
+ def get_index(
+ self,
+ context: AuthedServiceContext,
+ index: int,
+ ) -> UserView:
+ return (
+ self.stash.get_index(credentials=context.credentials, index=index)
+ .unwrap()
+ .to(UserView)
+ )
+
def signing_key_for_verify_key(self, verify_key: SyftVerifyKey) -> UserPrivateKey:
user = self.stash.get_by_verify_key(
- credentials=self.stash.admin_verify_key(), verify_key=verify_key
+ credentials=self.stash.root_verify_key, verify_key=verify_key
).unwrap()
return user.to(UserPrivateKey)
@@ -348,12 +361,11 @@ def get_role_for_credentials(
# they could be different
# TODO: This fn is cryptic -- when does each situation occur?
if isinstance(credentials, SyftVerifyKey):
- user = self.stash.get_by_verify_key(
- credentials=credentials, verify_key=credentials
- ).unwrap()
+ role = self.stash.get_role(credentials=credentials)
+ return role
elif isinstance(credentials, SyftSigningKey):
user = self.stash.get_by_signing_key(
- credentials=credentials,
+ credentials=credentials.verify_key,
signing_key=credentials, # type: ignore
).unwrap()
else:
@@ -378,7 +390,9 @@ def search(
if len(kwargs) == 0:
raise SyftException(public_message="Invalid search parameters")
- users = self.stash.find_all(credentials=context.credentials, **kwargs).unwrap()
+ users = self.stash.get_all(
+ credentials=context.credentials, filters=kwargs
+ ).unwrap()
users = [user.to(UserView) for user in users] if users is not None else []
return _paginate(users, page_size, page_index)
@@ -506,45 +520,54 @@ def update(
).unwrap()
if user.role == ServiceRole.ADMIN:
- settings_stash = SettingsStash(store=self.store)
- settings = settings_stash.get_all(context.credentials).unwrap()
+ settings_stash = SettingsStash(store=self.stash.db)
+ settings = settings_stash.get_all(
+ context.credentials, limit=1, sort_order="desc"
+ ).unwrap()
# TODO: Chance to refactor here in settings, as we're always doing get_att[0]
if len(settings) > 0:
settings_data = settings[0]
settings_data.admin_email = user.email
settings_stash.update(
- credentials=context.credentials, settings=settings_data
+ credentials=context.credentials, obj=settings_data
)
return user.to(UserView)
@service_method(path="user.delete", name="delete", roles=GUEST_ROLE_LEVEL)
def delete(self, context: AuthedServiceContext, uid: UID) -> UID:
- user = self.stash.get_by_uid(credentials=context.credentials, uid=uid).unwrap()
+ user_to_delete = self.stash.get_by_uid(
+ credentials=context.credentials, uid=uid
+ ).unwrap()
- if (
+ # Cannot delete root user
+ if user_to_delete.verify_key == self.root_verify_key:
+ raise UserPermissionError(
+ private_message=f"User {context.credentials} attempted to delete root user."
+ )
+
+ # - Admins can delete any user
+ # - Data Owners can delete Data Scientists and Guests
+ has_delete_permissions = (
context.role == ServiceRole.ADMIN
or context.role == ServiceRole.DATA_OWNER
- and user.role
- in [
- ServiceRole.GUEST,
- ServiceRole.DATA_SCIENTIST,
- ]
- ):
- pass
- else:
+ and user_to_delete.role in [ServiceRole.GUEST, ServiceRole.DATA_SCIENTIST]
+ )
+
+ if not has_delete_permissions:
raise UserPermissionError(
- f"User {context.credentials} ({context.role}) tried to delete user {uid} ({user.role})"
+ private_message=(
+ f"User {context.credentials} ({context.role}) tried to delete user "
+ f"{uid} ({user_to_delete.role})"
+ )
)
# TODO: Remove notifications for the deleted user
- self.stash.delete_by_uid(
- credentials=context.credentials, uid=uid, has_permission=True
+ return self.stash.delete_by_uid(
+ credentials=context.credentials, uid=uid
).unwrap()
- return uid
-
def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess:
"""Verify user
TODO: We might want to use a SyftObject instead
@@ -554,7 +577,7 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess:
raise SyftException(public_message="Invalid login credentials")
user = self.stash.get_by_email(
- credentials=self.admin_verify_key(), email=context.login_credentials.email
+ credentials=self.root_verify_key, email=context.login_credentials.email
).unwrap()
if check_pwd(context.login_credentials.password, user.hashed_password):
@@ -573,9 +596,9 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess:
return SyftSuccess(message="Login successful.", value=user.to(UserPrivateKey))
- def admin_verify_key(self) -> SyftVerifyKey:
- # TODO: Remove passthrough method?
- return self.stash.admin_verify_key()
+ @property
+ def root_verify_key(self) -> SyftVerifyKey:
+ return self.stash.root_verify_key
def register(
self, context: ServerServiceContext, new_user: UserCreate
@@ -616,7 +639,7 @@ def register(
success_message = f"User '{user.name}' successfully registered!"
# Notification Step
- root_key = self.admin_verify_key()
+ root_key = self.root_verify_key
root_context = AuthedServiceContext(server=context.server, credentials=root_key)
link = None
@@ -643,7 +666,7 @@ def register(
@as_result(StashException)
def user_verify_key(self, email: str) -> SyftVerifyKey:
# we are bypassing permissions here, so dont use to return a result directly to the user
- credentials = self.admin_verify_key()
+ credentials = self.root_verify_key
user = self.stash.get_by_email(credentials=credentials, email=email).unwrap()
if user.verify_key is None:
raise UserError(f"User {email} has no verify key")
@@ -652,7 +675,7 @@ def user_verify_key(self, email: str) -> SyftVerifyKey:
@as_result(StashException)
def get_by_verify_key(self, verify_key: SyftVerifyKey) -> UserView:
# we are bypassing permissions here, so dont use to return a result directly to the user
- credentials = self.admin_verify_key()
+ credentials = self.root_verify_key
user = self.stash.get_by_verify_key(
credentials=credentials, verify_key=verify_key
).unwrap()
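
The admin key is now a property delegating to the stash, and free-form user search funnels through the generic `filters=` argument of `get_all`. A hedged sketch of the search path with placeholder `stash`/`credentials` objects:

```python
# Hypothetical sketch of field-based user search with the new ObjectStash API.
def search_users(stash, credentials, email: str | None = None, name: str | None = None):
    filters = {k: v for k, v in {"email": email, "name": name}.items() if v is not None}
    if not filters:
        raise ValueError("at least one search field is required")
    # get_all(filters=...) replaces the old find_all/QueryKeys machinery
    return stash.get_all(credentials=credentials, filters=filters).unwrap()
```
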
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 3272f9a946c..92fb87d37b3 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -2,90 +2,63 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
-from ...types.uid import UID
from .user import User
from .user_roles import ServiceRole
-# 🟡 TODO 27: it would be nice if these could be defined closer to the User
-EmailPartitionKey = PartitionKey(key="email", type_=str)
-PasswordResetTokenPartitionKey = PartitionKey(key="reset_token", type_=str)
-RolePartitionKey = PartitionKey(key="role", type_=ServiceRole)
-SigningKeyPartitionKey = PartitionKey(key="signing_key", type_=SyftSigningKey)
-VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey)
-
-
-@serializable(canonical_name="UserStash", version=1)
-class UserStash(NewBaseUIDStoreStash):
- object_type = User
- settings: PartitionSettings = PartitionSettings(
- name=User.__canonical_name__,
- object_type=User,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
- def admin_verify_key(self) -> SyftVerifyKey:
- return self.partition.root_verify_key
+@serializable(canonical_name="UserStashSQL", version=1)
+class UserStash(ObjectStash[User]):
@as_result(StashException, NotFoundException)
def admin_user(self) -> User:
# TODO: This returns only one user, the first user with the role ADMIN
- admin_credentials = self.admin_verify_key()
+ admin_credentials = self.root_verify_key
return self.get_by_role(
credentials=admin_credentials, role=ServiceRole.ADMIN
).unwrap()
- @as_result(StashException, NotFoundException)
- def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> User:
- qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
- try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
- except NotFoundException as exc:
- raise NotFoundException.from_exception(
- exc, public_message=f"User {uid} not found"
- )
-
@as_result(StashException, NotFoundException)
def get_by_reset_token(self, credentials: SyftVerifyKey, token: str) -> User:
- qks = QueryKeys(qks=[PasswordResetTokenPartitionKey.with_obj(token)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"reset_token": token},
+ ).unwrap()
@as_result(StashException, NotFoundException)
def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User:
- qks = QueryKeys(qks=[EmailPartitionKey.with_obj(email)])
+ return self.get_one(
+ credentials=credentials,
+ filters={"email": email},
+ ).unwrap()
+ @as_result(StashException)
+ def email_exists(self, email: str) -> bool:
try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
- except NotFoundException as exc:
- raise NotFoundException.from_exception(
- exc, public_message=f"User {email} not found"
- )
+ self.get_by_email(credentials=self.root_verify_key, email=email).unwrap()
+ return True
+ except NotFoundException:
+ return False
@as_result(StashException)
- def email_exists(self, email: str) -> bool:
+ def verify_key_exists(self, verify_key: SyftVerifyKey) -> bool:
try:
- self.get_by_email(credentials=self.admin_verify_key(), email=email).unwrap()
+ self.get_by_verify_key(
+ credentials=self.root_verify_key, verify_key=verify_key
+ ).unwrap()
return True
except NotFoundException:
return False
@as_result(StashException, NotFoundException)
def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User:
- # TODO: Is this method correct? Should'nt it return a list of all member with a particular role?
- qks = QueryKeys(qks=[RolePartitionKey.with_obj(role)])
-
try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"role": role},
+ ).unwrap()
except NotFoundException as exc:
private_msg = f"User with role {role} not found"
raise NotFoundException.from_exception(exc, private_message=private_msg)
@@ -94,28 +67,25 @@ def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User:
def get_by_signing_key(
self, credentials: SyftVerifyKey, signing_key: SyftSigningKey | str
) -> User:
- if isinstance(signing_key, str):
- signing_key = SyftSigningKey.from_string(signing_key)
-
- qks = QueryKeys(qks=[SigningKeyPartitionKey.with_obj(signing_key)])
-
try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"signing_key": signing_key},
+ ).unwrap()
except NotFoundException as exc:
private_msg = f"User with signing key {signing_key} not found"
raise NotFoundException.from_exception(exc, private_message=private_msg)
@as_result(StashException, NotFoundException)
def get_by_verify_key(
- self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey | str
+ self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
) -> User:
- if isinstance(verify_key, str):
- verify_key = SyftVerifyKey.from_string(verify_key)
-
- qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)])
-
try:
- return self.query_one(credentials=credentials, qks=qks).unwrap()
+ return self.get_one(
+ credentials=credentials,
+ filters={"verify_key": verify_key},
+ ).unwrap()
+
except NotFoundException as exc:
private_msg = f"User with verify key {verify_key} not found"
raise NotFoundException.from_exception(exc, private_message=private_msg)
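
Lookups that used to assemble `QueryKeys` collapse to `get_one` with a filter dict, and not-found cases can be re-raised with a clearer message. A sketch of that shape; `get_by_name` is a hypothetical helper meant to live on an `ObjectStash` subclass:

```python
# Hypothetical lookup helper in the style of the UserStash methods above.
from syft.store.document_store_errors import NotFoundException
from syft.store.document_store_errors import StashException
from syft.types.result import as_result


class ExampleLookupMixin:
    @as_result(StashException, NotFoundException)
    def get_by_name(self, credentials, name: str):
        try:
            return self.get_one(
                credentials=credentials,
                filters={"name": name},
            ).unwrap()
        except NotFoundException as exc:
            raise NotFoundException.from_exception(
                exc, private_message=f"Object named {name} not found"
            )
```
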
diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py
index 7310decd85d..83a30bb670b 100644
--- a/packages/syft/src/syft/service/worker/image_registry_service.py
+++ b/packages/syft/src/syft/service/worker/image_registry_service.py
@@ -2,7 +2,7 @@
# relative
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.errors import SyftException
from ...types.uid import UID
from ..context import AuthedServiceContext
@@ -20,11 +20,9 @@
@serializable(canonical_name="SyftImageRegistryService", version=1)
class SyftImageRegistryService(AbstractService):
- store: DocumentStore
stash: SyftImageRegistryStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = SyftImageRegistryStash(store=store)
@service_method(
diff --git a/packages/syft/src/syft/service/worker/image_registry_stash.py b/packages/syft/src/syft/service/worker/image_registry_stash.py
index 5c469da0825..cfb71b9848b 100644
--- a/packages/syft/src/syft/service/worker/image_registry_stash.py
+++ b/packages/syft/src/syft/service/worker/image_registry_stash.py
@@ -1,55 +1,34 @@
-# stdlib
-
-# third party
-
-# stdlib
-
# stdlib
from typing import Literal
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
from ...types.result import as_result
from .image_registry import SyftImageRegistry
-__all__ = ["SyftImageRegistryStash"]
-
-
-URLPartitionKey = PartitionKey(key="url", type_=str)
-
-
-@serializable(canonical_name="SyftImageRegistryStash", version=1)
-class SyftImageRegistryStash(NewBaseUIDStoreStash):
- object_type = SyftImageRegistry
- settings: PartitionSettings = PartitionSettings(
- name=SyftImageRegistry.__canonical_name__,
- object_type=SyftImageRegistry,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="SyftImageRegistrySQLStash", version=1)
+class SyftImageRegistryStash(ObjectStash[SyftImageRegistry]):
@as_result(SyftException, StashException, NotFoundException)
def get_by_url(
self,
credentials: SyftVerifyKey,
url: str,
- ) -> SyftImageRegistry | None:
- qks = QueryKeys(qks=[URLPartitionKey.with_obj(url)])
- return self.query_one(credentials=credentials, qks=qks).unwrap(
- public_message=f"Image Registry with url {url} not found"
- )
+ ) -> SyftImageRegistry:
+ return self.get_one(
+ credentials=credentials,
+ filters={"url": url},
+ ).unwrap()
@as_result(SyftException, StashException)
def delete_by_url(self, credentials: SyftVerifyKey, url: str) -> Literal[True]:
- qk = URLPartitionKey.with_obj(url)
- return super().delete(credentials=credentials, qk=qk).unwrap()
+ item = self.get_by_url(credentials=credentials, url=url).unwrap()
+ self.delete_by_uid(credentials=credentials, uid=item.id).unwrap()
+
+ # TODO standardize delete return type
+ return True
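
Field-keyed deletes are now composed from a lookup plus `delete_by_uid`, since the generic stash only deletes by primary key. The same composition for an arbitrary field, with a hypothetical helper name:

```python
# Hypothetical "delete by unique field" helper mirroring delete_by_url above.
def delete_by_field(stash, credentials, field: str, value) -> bool:
    item = stash.get_one(credentials=credentials, filters={field: value}).unwrap()
    stash.delete_by_uid(credentials=credentials, uid=item.id).unwrap()
    return True
```
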
diff --git a/packages/syft/src/syft/service/worker/worker_image.py b/packages/syft/src/syft/service/worker/worker_image.py
index e5c110a6e0e..99ca3ad6040 100644
--- a/packages/syft/src/syft/service/worker/worker_image.py
+++ b/packages/syft/src/syft/service/worker/worker_image.py
@@ -7,18 +7,20 @@
from ...server.credentials import SyftVerifyKey
from ...types.datetime import DateTime
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.uid import UID
from .image_identifier import SyftWorkerImageIdentifier
@serializable()
-class SyftWorkerImage(SyftObject):
+class SyftWorkerImageV1(SyftObject):
__canonical_name__ = "SyftWorkerImage"
__version__ = SYFT_OBJECT_VERSION_1
__attr_unique__ = ["config"]
__attr_searchable__ = ["config", "image_hash", "created_by"]
+
__repr_attrs__ = [
"image_identifier",
"image_hash",
@@ -35,6 +37,40 @@ class SyftWorkerImage(SyftObject):
image_hash: str | None = None
built_at: DateTime | None = None
+
+@serializable()
+class SyftWorkerImage(SyftObject):
+ __canonical_name__ = "SyftWorkerImage"
+ __version__ = SYFT_OBJECT_VERSION_2
+
+ __attr_unique__ = ["config_hash"]
+ __attr_searchable__ = [
+ "config",
+ "image_hash",
+ "created_by",
+ "config_hash",
+ ]
+
+ __repr_attrs__ = [
+ "image_identifier",
+ "image_hash",
+ "created_at",
+ "built_at",
+ "config",
+ ]
+
+ id: UID
+ config: WorkerConfig
+ created_by: SyftVerifyKey
+ created_at: DateTime = DateTime.now()
+ image_identifier: SyftWorkerImageIdentifier | None = None
+ image_hash: str | None = None
+ built_at: DateTime | None = None
+
+ @property
+ def config_hash(self) -> str:
+ return self.config.hash()
+
@property
def is_built(self) -> bool:
"""Returns True if the image has been built or is prebuilt."""
diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py
index 37622a36d71..a5f05f94dac 100644
--- a/packages/syft/src/syft/service/worker/worker_image_service.py
+++ b/packages/syft/src/syft/service/worker/worker_image_service.py
@@ -10,7 +10,7 @@
from ...custom_worker.config import WorkerConfig
from ...custom_worker.k8s import IN_KUBERNETES
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...types.datetime import DateTime
from ...types.dicttuple import DictTuple
from ...types.errors import SyftException
@@ -31,11 +31,9 @@
@serializable(canonical_name="SyftWorkerImageService", version=1)
class SyftWorkerImageService(AbstractService):
- store: DocumentStore
stash: SyftWorkerImageStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = SyftWorkerImageStash(store=store)
@service_method(
diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py
index 983dfe9a8d6..dc220905839 100644
--- a/packages/syft/src/syft/service/worker/worker_image_stash.py
+++ b/packages/syft/src/syft/service/worker/worker_image_stash.py
@@ -2,16 +2,16 @@
# third party
+# third party
+from sqlalchemy.orm import Session
+
# relative
from ...custom_worker.config import DockerWorkerConfig
from ...custom_worker.config import WorkerConfig
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
+from ...store.db.stash import with_session
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
@@ -20,32 +20,22 @@
from ..action.action_permissions import ActionPermission
from .worker_image import SyftWorkerImage
-WorkerConfigPK = PartitionKey(key="config", type_=WorkerConfig)
-
-
-@serializable(canonical_name="SyftWorkerImageStash", version=1)
-class SyftWorkerImageStash(NewBaseUIDStoreStash):
- object_type = SyftWorkerImage
- settings: PartitionSettings = PartitionSettings(
- name=SyftWorkerImage.__canonical_name__,
- object_type=SyftWorkerImage,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="SyftWorkerImageSQLStash", version=1)
+class SyftWorkerImageStash(ObjectStash[SyftWorkerImage]):
@as_result(SyftException, StashException, NotFoundException)
- def set( # type: ignore
+ @with_session
+ def set(
self,
credentials: SyftVerifyKey,
obj: SyftWorkerImage,
add_permissions: list[ActionObjectPermission] | None = None,
add_storage_permission: bool = True,
ignore_duplicates: bool = False,
+ session: Session = None,
) -> SyftWorkerImage:
- add_permissions = [] if add_permissions is None else add_permissions
-
# By default syft images have all read permission
+ add_permissions = [] if add_permissions is None else add_permissions
add_permissions.append(
ActionObjectPermission(uid=obj.id, permission=ActionPermission.ALL_READ)
)
@@ -67,6 +57,7 @@ def set( # type: ignore
add_permissions=add_permissions,
add_storage_permission=add_storage_permission,
ignore_duplicates=ignore_duplicates,
+ session=session,
)
.unwrap()
)
@@ -85,7 +76,11 @@ def worker_config_exists(
def get_by_worker_config(
self, credentials: SyftVerifyKey, config: WorkerConfig
) -> SyftWorkerImage:
- qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)])
- return self.query_one(credentials=credentials, qks=qks).unwrap(
+ # TODO cannot search on fields containing objects
+ all_images = self.get_all(credentials=credentials).unwrap()
+ for image in all_images:
+ if image.config == config:
+ return image
+ raise NotFoundException(
public_message=f"Worker Image with config {config} not found"
)
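
Because `config` holds a whole `WorkerConfig` object rather than a scalar, it cannot be matched by the JSON field filters, so the lookup falls back to an in-memory scan. A short sketch of the trade-off with placeholder names:

```python
# Hypothetical: scanning vs. filtering with the new stash API.
# Scalar fields can be pushed down into the query:
#   stash.get_one(credentials, filters={"image_hash": some_hash})
# Object-valued fields currently cannot, so equality is checked in Python:
def find_by_config(stash, credentials, config):
    for image in stash.get_all(credentials=credentials).unwrap():
        if image.config == config:
            return image
    return None
```
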
diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py
index 4ceced2bf26..55b103ba369 100644
--- a/packages/syft/src/syft/service/worker/worker_pool_service.py
+++ b/packages/syft/src/syft/service/worker/worker_pool_service.py
@@ -12,7 +12,7 @@
from ...custom_worker.k8s import IN_KUBERNETES
from ...custom_worker.runner_k8s import KubernetesRunner
from ...serde.serializable import serializable
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
@@ -52,11 +52,9 @@
@serializable(canonical_name="SyftWorkerPoolService", version=1)
class SyftWorkerPoolService(AbstractService):
- store: DocumentStore
stash: SyftWorkerPoolStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = SyftWorkerPoolStash(store=store)
self.image_stash = SyftWorkerImageStash(store=store)
diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py
index 94a4a0b8fab..81a4f4741d2 100644
--- a/packages/syft/src/syft/service/worker/worker_pool_stash.py
+++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py
@@ -2,14 +2,14 @@
# third party
+# third party
+from sqlalchemy.orm import Session
+
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.db.stash import ObjectStash
+from ...store.db.stash import with_session
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
@@ -18,29 +18,22 @@
from ..action.action_permissions import ActionPermission
from .worker_pool import WorkerPool
-PoolNamePartitionKey = PartitionKey(key="name", type_=str)
-PoolImageIDPartitionKey = PartitionKey(key="image_id", type_=UID)
-
-
-@serializable(canonical_name="SyftWorkerPoolStash", version=1)
-class SyftWorkerPoolStash(NewBaseUIDStoreStash):
- object_type = WorkerPool
- settings: PartitionSettings = PartitionSettings(
- name=WorkerPool.__canonical_name__,
- object_type=WorkerPool,
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
+@serializable(canonical_name="SyftWorkerPoolSQLStash", version=1)
+class SyftWorkerPoolStash(ObjectStash[WorkerPool]):
@as_result(StashException, NotFoundException)
def get_by_name(self, credentials: SyftVerifyKey, pool_name: str) -> WorkerPool:
- qks = QueryKeys(qks=[PoolNamePartitionKey.with_obj(pool_name)])
- return self.query_one(credentials=credentials, qks=qks).unwrap(
+ result = self.get_one(
+ credentials=credentials,
+ filters={"name": pool_name},
+ )
+
+ return result.unwrap(
public_message=f"WorkerPool with name {pool_name} not found"
)
@as_result(StashException)
+ @with_session
def set(
self,
credentials: SyftVerifyKey,
@@ -48,6 +41,7 @@ def set(
add_permissions: list[ActionObjectPermission] | None = None,
add_storage_permission: bool = True,
ignore_duplicates: bool = False,
+ session: Session = None,
) -> WorkerPool:
# By default all worker pools have all read permission
add_permissions = [] if add_permissions is None else add_permissions
@@ -62,6 +56,7 @@ def set(
add_permissions=add_permissions,
add_storage_permission=add_storage_permission,
ignore_duplicates=ignore_duplicates,
+ session=session,
)
.unwrap()
)
@@ -70,5 +65,7 @@ def set(
def get_by_image_uid(
self, credentials: SyftVerifyKey, image_uid: UID
) -> list[WorkerPool]:
- qks = QueryKeys(qks=[PoolImageIDPartitionKey.with_obj(image_uid)])
- return self.query_all(credentials=credentials, qks=qks).unwrap()
+ return self.get_all(
+ credentials=credentials,
+ filters={"image_id": image_uid},
+ ).unwrap()
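
Both the pool and image stashes now accept an explicit `session` so that `set` can join a caller-managed transaction; the `with_session` decorator in `store/db/stash.py` presumably opens a session when none is supplied. A hedged sketch of how such a decorator can be written, not the PR's actual implementation; it assumes the instance exposes the manager's `sessionmaker` as `self.db.sessionmaker`:

```python
# Hypothetical with_session-style decorator, NOT the PR's implementation.
from functools import wraps


def with_session(func):
    @wraps(func)
    def wrapper(self, *args, session=None, **kwargs):
        if session is not None:
            # The caller controls the transaction; just reuse its session.
            return func(self, *args, session=session, **kwargs)
        # Otherwise open a short-lived session and commit on success.
        with self.db.sessionmaker() as session, session.begin():
            return func(self, *args, session=session, **kwargs)

    return wrapper
```
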
diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py
index 625a88a46b4..300c0b6ed3d 100644
--- a/packages/syft/src/syft/service/worker/worker_service.py
+++ b/packages/syft/src/syft/service/worker/worker_service.py
@@ -13,7 +13,7 @@
from ...custom_worker.runner_k8s import KubernetesRunner
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
+from ...store.db.db import DBManager
from ...store.document_store import SyftSuccess
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
@@ -39,11 +39,9 @@
@serializable(canonical_name="WorkerService", version=1)
class WorkerService(AbstractService):
- store: DocumentStore
stash: WorkerStash
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = WorkerStash(store=store)
@service_method(
diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py
index b2b059ffec5..48a192ecd19 100644
--- a/packages/syft/src/syft/service/worker/worker_stash.py
+++ b/packages/syft/src/syft/service/worker/worker_stash.py
@@ -2,14 +2,15 @@
# third party
+# third party
+from sqlalchemy.orm import Session
+
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import NewBaseUIDStoreStash
+from ...store.db.stash import ObjectStash
+from ...store.db.stash import with_session
from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
@@ -22,17 +23,10 @@
WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str)
-@serializable(canonical_name="WorkerStash", version=1)
-class WorkerStash(NewBaseUIDStoreStash):
- object_type = SyftWorker
- settings: PartitionSettings = PartitionSettings(
- name=SyftWorker.__canonical_name__, object_type=SyftWorker
- )
-
- def __init__(self, store: DocumentStore) -> None:
- super().__init__(store=store)
-
+@serializable(canonical_name="WorkerSQLStash", version=1)
+class WorkerStash(ObjectStash[SyftWorker]):
@as_result(StashException)
+ @with_session
def set(
self,
credentials: SyftVerifyKey,
@@ -40,6 +34,7 @@ def set(
add_permissions: list[ActionObjectPermission] | None = None,
add_storage_permission: bool = True,
ignore_duplicates: bool = False,
+ session: Session = None,
) -> SyftWorker:
# By default all worker pools have all read permission
add_permissions = [] if add_permissions is None else add_permissions
@@ -54,17 +49,11 @@ def set(
add_permissions=add_permissions,
ignore_duplicates=ignore_duplicates,
add_storage_permission=add_storage_permission,
+ session=session,
)
.unwrap()
)
- @as_result(StashException, NotFoundException)
- def get_worker_by_name(
- self, credentials: SyftVerifyKey, worker_name: str
- ) -> SyftWorker:
- qks = QueryKeys(qks=[WorkerContainerNamePartitionKey.with_obj(worker_name)])
- return self.query_one(credentials=credentials, qks=qks).unwrap()
-
@as_result(StashException, NotFoundException)
def update_consumer_state(
self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState
diff --git a/packages/syft/src/syft/store/__init__.py b/packages/syft/src/syft/store/__init__.py
index 9260d13f956..e69de29bb2d 100644
--- a/packages/syft/src/syft/store/__init__.py
+++ b/packages/syft/src/syft/store/__init__.py
@@ -1,3 +0,0 @@
-# relative
-from .mongo_document_store import MongoDict
-from .mongo_document_store import MongoStoreConfig
diff --git a/packages/syft/tests/mongomock/py.typed b/packages/syft/src/syft/store/db/__init__.py
similarity index 100%
rename from packages/syft/tests/mongomock/py.typed
rename to packages/syft/src/syft/store/db/__init__.py
diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py
new file mode 100644
index 00000000000..cc82e5a3f4e
--- /dev/null
+++ b/packages/syft/src/syft/store/db/db.py
@@ -0,0 +1,81 @@
+# stdlib
+import logging
+from typing import Generic
+from typing import TypeVar
+from urllib.parse import urlparse
+
+# third party
+from pydantic import BaseModel
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+# relative
+from ...serde.serializable import serializable
+from ...server.credentials import SyftVerifyKey
+from ...types.uid import UID
+from .schema import PostgresBase
+from .schema import SQLiteBase
+
+logger = logging.getLogger(__name__)
+
+
+@serializable(canonical_name="DBConfig", version=1)
+class DBConfig(BaseModel):
+ @property
+ def connection_string(self) -> str:
+ raise NotImplementedError("Subclasses must implement this method.")
+
+ @classmethod
+ def from_connection_string(cls, conn_str: str) -> "DBConfig":
+ # relative
+ from .postgres import PostgresDBConfig
+ from .sqlite import SQLiteDBConfig
+
+ parsed = urlparse(conn_str)
+ if parsed.scheme == "postgresql":
+ return PostgresDBConfig(
+ host=parsed.hostname,
+ port=parsed.port,
+ user=parsed.username,
+ password=parsed.password,
+ database=parsed.path.lstrip("/"),
+ )
+ elif parsed.scheme == "sqlite":
+ return SQLiteDBConfig(path=parsed.path)
+ else:
+ raise ValueError(f"Unsupported database scheme {parsed.scheme}")
+
+
+ConfigT = TypeVar("ConfigT", bound=DBConfig)
+
+
+class DBManager(Generic[ConfigT]):
+ def __init__(
+ self,
+ config: ConfigT,
+ server_uid: UID,
+ root_verify_key: SyftVerifyKey,
+ ) -> None:
+ self.config = config
+ self.root_verify_key = root_verify_key
+ self.server_uid = server_uid
+ self.engine = create_engine(
+ config.connection_string,
+ # json_serializer=dumps,
+ # json_deserializer=loads,
+ )
+ logger.info(f"Connecting to {config.connection_string}")
+ self.sessionmaker = sessionmaker(bind=self.engine)
+ self.update_settings()
+ logger.info(f"Successfully connected to {config.connection_string}")
+
+ def update_settings(self) -> None:
+ pass
+
+ def init_tables(self, reset: bool = False) -> None:
+ Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase
+
+ with self.sessionmaker().begin() as _:
+ if reset:
+ Base.metadata.drop_all(bind=self.engine)
+ Base.metadata.create_all(self.engine)
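
Putting `DBConfig` and `DBManager` together: a connection string is parsed into the matching config subtype, the manager owns the engine and sessionmaker, and `init_tables` materialises the per-dialect schema. A usage sketch; module paths assume this PR's package layout and the values are illustrative:

```python
# Illustrative wiring of the new DB layer.
from pathlib import Path

from syft.server.credentials import SyftSigningKey
from syft.store.db.db import DBConfig
from syft.store.db.sqlite import SQLiteDBConfig
from syft.store.db.sqlite import SQLiteDBManager
from syft.types.uid import UID

# Parse a URL into the matching config subtype...
pg_config = DBConfig.from_connection_string("postgresql://syft:secret@localhost:5432/syft")

# ...or build one explicitly. The SQLite config splits directory and filename.
config = SQLiteDBConfig(path=Path("/tmp"), filename="example.db")
manager = SQLiteDBManager(
    config=config,
    server_uid=UID(),
    root_verify_key=SyftSigningKey.generate().verify_key,
)
manager.init_tables(reset=True)  # drop and recreate all registered tables
```
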
diff --git a/packages/syft/src/syft/store/db/errors.py b/packages/syft/src/syft/store/db/errors.py
new file mode 100644
index 00000000000..8f9a4ca048a
--- /dev/null
+++ b/packages/syft/src/syft/store/db/errors.py
@@ -0,0 +1,31 @@
+# stdlib
+import logging
+
+# third party
+from sqlalchemy.exc import DatabaseError
+from typing_extensions import Self
+
+# relative
+from ..document_store_errors import StashException
+
+logger = logging.getLogger(__name__)
+
+
+class StashDBException(StashException):
+ """
+ See https://docs.sqlalchemy.org/en/20/errors.html#databaseerror
+
+ StashDBException converts a SQLAlchemy DatabaseError into a StashException,
+ DatabaseErrors are errors thrown by the database itself, for example when a
+ query fails because a table is missing.
+ """
+
+ public_message = "There was an error retrieving data. Contact your admin."
+
+ @classmethod
+ def from_sqlalchemy_error(cls, e: DatabaseError) -> Self:
+ logger.exception(e)
+
+ error_type = e.__class__.__name__
+ private_message = f"{error_type}: {str(e)}"
+ return cls(private_message=private_message)
diff --git a/packages/syft/src/syft/store/db/postgres.py b/packages/syft/src/syft/store/db/postgres.py
new file mode 100644
index 00000000000..630155e29de
--- /dev/null
+++ b/packages/syft/src/syft/store/db/postgres.py
@@ -0,0 +1,48 @@
+# third party
+from sqlalchemy import URL
+
+# relative
+from ...serde.serializable import serializable
+from ...server.credentials import SyftVerifyKey
+from ...types.uid import UID
+from .db import DBManager
+from .sqlite import DBConfig
+
+
+@serializable(canonical_name="PostgresDBConfig", version=1)
+class PostgresDBConfig(DBConfig):
+ host: str
+ port: int
+ user: str
+ password: str
+ database: str
+
+ @property
+ def connection_string(self) -> str:
+ return URL.create(
+ "postgresql",
+ username=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ database=self.database,
+ ).render_as_string(hide_password=False)
+
+
+class PostgresDBManager(DBManager[PostgresDBConfig]):
+ def update_settings(self) -> None:
+ return super().update_settings()
+
+ @classmethod
+ def random(
+ cls: type,
+ *,
+ config: PostgresDBConfig,
+ server_uid: UID | None = None,
+ root_verify_key: SyftVerifyKey | None = None,
+ ) -> "PostgresDBManager":
+ root_verify_key = root_verify_key or SyftVerifyKey.generate()
+ server_uid = server_uid or UID()
+ return PostgresDBManager(
+ config=config, server_uid=server_uid, root_verify_key=root_verify_key
+ )
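
`PostgresDBConfig.connection_string` is just a rendered SQLAlchemy `URL`. The standalone equivalent, with made-up credentials:

```python
# Standalone equivalent of PostgresDBConfig.connection_string (illustrative values).
from sqlalchemy import URL

url = URL.create(
    "postgresql",
    username="syft",
    password="secret",
    host="localhost",
    port=5432,
    database="syft",
)
print(url.render_as_string(hide_password=False))
# -> postgresql://syft:secret@localhost:5432/syft
```
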
diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
new file mode 100644
index 00000000000..04864a4a74a
--- /dev/null
+++ b/packages/syft/src/syft/store/db/query.py
@@ -0,0 +1,383 @@
+# stdlib
+from abc import ABC
+from abc import abstractmethod
+import enum
+from typing import Any
+from typing import Literal
+
+# third party
+import sqlalchemy as sa
+from sqlalchemy import Column
+from sqlalchemy import Dialect
+from sqlalchemy import Result
+from sqlalchemy import Select
+from sqlalchemy import Table
+from sqlalchemy import func
+from sqlalchemy.exc import DatabaseError
+from sqlalchemy.orm import Session
+from typing_extensions import Self
+
+# relative
+from ...serde.json_serde import serialize_json
+from ...server.credentials import SyftVerifyKey
+from ...service.action.action_permissions import ActionObjectPermission
+from ...service.action.action_permissions import ActionPermission
+from ...service.user.user_roles import ServiceRole
+from ...types.syft_object import SyftObject
+from ...types.uid import UID
+from .errors import StashDBException
+from .schema import PostgresBase
+from .schema import SQLiteBase
+
+
+class FilterOperator(enum.Enum):
+ EQ = "eq"
+ CONTAINS = "contains"
+
+
+class Query(ABC):
+ def __init__(self, object_type: type[SyftObject]) -> None:
+ self.object_type: type = object_type
+ self.table: Table = self._get_table(object_type)
+ self.stmt: Select = self.table.select()
+
+ @abstractmethod
+ def _get_table(self, object_type: type[SyftObject]) -> Table:
+ raise NotImplementedError
+
+ @staticmethod
+ def get_query_class(dialect: str | Dialect) -> "type[Query]":
+ if isinstance(dialect, Dialect):
+ dialect = dialect.name
+
+ if dialect == "sqlite":
+ return SQLiteQuery
+ elif dialect == "postgresql":
+ return PostgresQuery
+ else:
+ raise ValueError(f"Unsupported dialect {dialect}")
+
+ @classmethod
+ def create(cls, object_type: type[SyftObject], dialect: str | Dialect) -> "Query":
+ """Create a query object for the given object type and dialect."""
+ query_class = cls.get_query_class(dialect)
+ return query_class(object_type)
+
+ def execute(self, session: Session) -> Result:
+ """Execute the query using the given session."""
+ try:
+ return session.execute(self.stmt)
+ except DatabaseError as e:
+ raise StashDBException.from_sqlalchemy_error(e) from e
+
+ def with_permissions(
+ self,
+ credentials: SyftVerifyKey,
+ role: ServiceRole,
+ permission: ActionPermission = ActionPermission.READ,
+ ) -> Self:
+ """Add a permission check to the query.
+
+ If the user has a role below DATA_OWNER, the query will be filtered to only include objects
+ that the user has the specified permission on.
+
+ Args:
+ credentials (SyftVerifyKey): user verify key
+ role (ServiceRole): role of the user
+ permission (ActionPermission, optional): Type of permission to check for.
+ Defaults to ActionPermission.READ.
+
+ Returns:
+ Self: The query object with the permission check applied
+ """
+ if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER):
+ return self
+
+ ao_permission = ActionObjectPermission(
+ uid=UID(), # dummy uid, we just need the permission string
+ credentials=credentials,
+ permission=permission,
+ )
+
+ permission_clause = self._make_permissions_clause(ao_permission)
+ self.stmt = self.stmt.where(permission_clause)
+
+ return self
+
+ def filter(self, field: str, operator: str | FilterOperator, value: Any) -> Self:
+ """Add a filter to the query.
+
+ example usage:
+ Query(User).filter("name", "eq", "Alice")
+ Query(User).filter("friends", "contains", "Bob")
+
+ Args:
+ field (str): Field to filter on
+ operator (str): Operator to use for the filter
+ value (Any): Value to filter on
+
+ Raises:
+ ValueError: If the operator is not supported
+
+ Returns:
+ Self: The query object with the filter applied
+ """
+ filter = self._create_filter_clause(self.table, field, operator, value)
+ self.stmt = self.stmt.where(filter)
+ return self
+
+ def filter_and(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self:
+ """Add filters to the query using an AND clause.
+
+ example usage:
+ Query(User).filter_and(
+ ("name", "eq", "Alice"),
+ ("age", "eq", 30),
+ )
+
+ Args:
+            *filters (tuple[str, str | FilterOperator, Any]): One or more
+                (field, operator, value) tuples, combined with AND. Operators
+                are the same as in `filter` ("eq" or "contains").
+
+ Raises:
+ ValueError: If the operator is not supported
+
+ Returns:
+ Self: The query object with the filter applied
+ """
+ filter_clauses = [
+ self._create_filter_clause(self.table, field, operator, value)
+ for field, operator, value in filters
+ ]
+
+ self.stmt = self.stmt.where(sa.and_(*filter_clauses))
+ return self
+
+ def filter_or(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self:
+ """Add filters to the query using an OR clause.
+
+ example usage:
+ Query(User).filter_or(
+ ("name", "eq", "Alice"),
+ ("age", "eq", 30),
+ )
+
+ Args:
+            *filters (tuple[str, str | FilterOperator, Any]): One or more
+                (field, operator, value) tuples, combined with OR. Operators
+                are the same as in `filter` ("eq" or "contains").
+
+ Raises:
+ ValueError: If the operator is not supported
+
+ Returns:
+ Self: The query object with the filter applied
+ """
+ filter_clauses = [
+ self._create_filter_clause(self.table, field, operator, value)
+ for field, operator, value in filters
+ ]
+
+ self.stmt = self.stmt.where(sa.or_(*filter_clauses))
+ return self
+
+ def _create_filter_clause(
+ self,
+ table: Table,
+ field: str,
+ operator: str | FilterOperator,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ if isinstance(operator, str):
+ try:
+ operator = FilterOperator(operator.lower())
+ except ValueError:
+ raise ValueError(f"Filter operator {operator} not supported")
+
+ if operator == FilterOperator.EQ:
+ return self._eq_filter(table, field, value)
+ elif operator == FilterOperator.CONTAINS:
+ return self._contains_filter(table, field, value)
+
+ def order_by(
+ self,
+ field: str | None = None,
+ order: Literal["asc", "desc"] | None = None,
+ ) -> Self:
+ """Add an order by clause to the query, with sensible defaults if field or order is not provided.
+
+ Args:
+ field (Optional[str]): field to order by. If None, uses the default field.
+ order (Optional[Literal["asc", "desc"]]): Order to use ("asc" or "desc").
+ Defaults to 'asc' if field is provided and order is not, or the default order otherwise.
+
+ Raises:
+ ValueError: If the order is not "asc" or "desc"
+
+ Returns:
+ Self: The query object with the order by clause applied.
+ """
+ # Determine the field and order defaults if not provided
+ if field is None:
+ if hasattr(self.object_type, "__order_by__"):
+ default_field, default_order = self.object_type.__order_by__
+ else:
+ default_field, default_order = "_created_at", "desc"
+ field = default_field
+ else:
+ # If field is provided but order is not, default to 'asc'
+ default_order = "asc"
+ order = order or default_order
+
+ column = self._get_column(field)
+
+ if isinstance(column.type, sa.JSON):
+ column = sa.cast(column, sa.String)
+
+ if order.lower() == "asc":
+ self.stmt = self.stmt.order_by(column.asc())
+
+ elif order.lower() == "desc":
+ self.stmt = self.stmt.order_by(column.desc())
+ else:
+ raise ValueError(f"Invalid sort order {order}")
+
+ return self
+
+ def limit(self, limit: int | None) -> Self:
+ """Add a limit clause to the query."""
+ if limit is None:
+ return self
+
+ if limit < 0:
+ raise ValueError("Limit must be a positive integer")
+ self.stmt = self.stmt.limit(limit)
+
+ return self
+
+ def offset(self, offset: int) -> Self:
+ """Add an offset clause to the query."""
+ if offset < 0:
+ raise ValueError("Offset must be a positive integer")
+
+ self.stmt = self.stmt.offset(offset)
+ return self
+
+ @abstractmethod
+ def _make_permissions_clause(
+ self,
+ permission: ActionObjectPermission,
+ ) -> sa.sql.elements.BinaryExpression:
+ pass
+
+ @abstractmethod
+ def _contains_filter(
+ self,
+ table: Table,
+ field: str,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ pass
+
+ def _get_column(self, column: str) -> Column:
+ if column == "id":
+ return self.table.c.id
+ if column == "created_date" or column == "_created_at":
+ return self.table.c._created_at
+ elif column == "updated_date" or column == "_updated_at":
+ return self.table.c._updated_at
+ elif column == "deleted_date" or column == "_deleted_at":
+ return self.table.c._deleted_at
+
+ return self.table.c.fields[column]
+
+
+class SQLiteQuery(Query):
+ def _make_permissions_clause(
+ self,
+ permission: ActionObjectPermission,
+ ) -> sa.sql.elements.BinaryExpression:
+ permission_string = permission.permission_string
+ compound_permission_string = permission.compound_permission_string
+ return sa.or_(
+ self.table.c.permissions.contains(permission_string),
+ self.table.c.permissions.contains(compound_permission_string),
+ )
+
+ def _get_table(self, object_type: type[SyftObject]) -> Table:
+ cname = object_type.__canonical_name__
+ if cname not in SQLiteBase.metadata.tables:
+ raise ValueError(f"Table for {cname} not found")
+ return SQLiteBase.metadata.tables[cname]
+
+ def _contains_filter(
+ self,
+ table: Table,
+ field: str,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ field_value = serialize_json(value)
+ return table.c.fields[field].contains(func.json_quote(field_value))
+
+ def _eq_filter(
+ self,
+ table: Table,
+ field: str,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ if field == "id":
+ return table.c.id == UID(value)
+
+ if "." in field:
+            # a dotted field like "a.b" becomes a JSON path index into the fields column
+ field = field.split(".") # type: ignore
+
+ json_value = serialize_json(value)
+ return table.c.fields[field] == func.json_quote(json_value)
+
+
+class PostgresQuery(Query):
+ def _make_permissions_clause(
+ self, permission: ActionObjectPermission
+ ) -> sa.sql.elements.BinaryExpression:
+ permission_string = [permission.permission_string]
+ compound_permission_string = [permission.compound_permission_string]
+ return sa.or_(
+ self.table.c.permissions.contains(permission_string),
+ self.table.c.permissions.contains(compound_permission_string),
+ )
+
+ def _contains_filter(
+ self,
+ table: Table,
+ field: str,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ field_value = serialize_json(value)
+ col = sa.cast(table.c.fields[field], sa.Text)
+ val = sa.cast(field_value, sa.Text)
+ return col.contains(val)
+
+ def _get_table(self, object_type: type[SyftObject]) -> Table:
+ cname = object_type.__canonical_name__
+ if cname not in PostgresBase.metadata.tables:
+ raise ValueError(f"Table for {cname} not found")
+ return PostgresBase.metadata.tables[cname]
+
+ def _eq_filter(
+ self,
+ table: Table,
+ field: str,
+ value: Any,
+ ) -> sa.sql.elements.BinaryExpression:
+ if field == "id":
+ return table.c.id == UID(value)
+
+ if "." in field:
+            # a dotted field like "a.b" becomes a JSON path index into the fields column
+ field = field.split(".") # type: ignore
+
+ json_value = serialize_json(value)
+ # NOTE: there might be a bug with casting everything to text
+ return table.c.fields[field].astext == sa.cast(json_value, sa.Text)
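
A typical call site chains the builder methods and executes inside a session; for admins and data owners the permission clause is skipped, as documented above. A hedged sketch where `engine`, `credentials`, and `role` are supplied by the surrounding `DBManager`/stash layer and the `User` table is assumed to be registered:

```python
# Hypothetical usage of the Query builder defined above.
from sqlalchemy.orm import Session

from syft.service.user.user import User
from syft.store.db.query import Query


def newest_matching_users(engine, credentials, role, email: str):
    query = (
        Query.create(User, engine.dialect)
        .with_permissions(credentials, role)   # no-op for ADMIN / DATA_OWNER
        .filter("email", "eq", email)
        .order_by("_created_at", "desc")
        .limit(10)
    )
    with Session(engine) as session:
        result = query.execute(session)
        return result.all()
```
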
diff --git a/packages/syft/src/syft/store/db/schema.py b/packages/syft/src/syft/store/db/schema.py
new file mode 100644
index 00000000000..7f81e39802e
--- /dev/null
+++ b/packages/syft/src/syft/store/db/schema.py
@@ -0,0 +1,87 @@
+# stdlib
+
+# stdlib
+import uuid
+
+# third party
+import sqlalchemy as sa
+from sqlalchemy import Column
+from sqlalchemy import Dialect
+from sqlalchemy import Table
+from sqlalchemy import TypeDecorator
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.types import JSON
+
+# relative
+from ...types.syft_object import SyftObject
+from ...types.uid import UID
+
+
+class SQLiteBase(DeclarativeBase):
+ pass
+
+
+class PostgresBase(DeclarativeBase):
+ pass
+
+
+class UIDTypeDecorator(TypeDecorator):
+ """Converts between Syft UID and UUID."""
+
+ impl = sa.UUID
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect): # type: ignore
+ if value is not None:
+ return value.value
+
+ def process_result_value(self, value, dialect): # type: ignore
+ if value is not None:
+ return UID(value)
+
+
+def create_table(
+ object_type: type[SyftObject],
+ dialect: Dialect,
+) -> Table:
+ """Create a table for a given SYftObject type, and add it to the metadata.
+
+ To create the table on the database, you must call `Base.metadata.create_all(engine)`.
+
+ Args:
+ object_type (type[SyftObject]): The type of the object to create a table for.
+ dialect (Dialect): The dialect of the database.
+
+ Returns:
+ Table: The created table.
+ """
+ table_name = object_type.__canonical_name__
+ dialect_name = dialect.name
+
+ fields_type = JSON if dialect_name == "sqlite" else postgresql.JSON
+ permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB
+ storage_permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB
+
+ Base = SQLiteBase if dialect_name == "sqlite" else PostgresBase
+
+ if table_name not in Base.metadata.tables:
+ Table(
+ object_type.__canonical_name__,
+ Base.metadata,
+ Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4),
+ Column("fields", fields_type, default={}),
+ Column("permissions", permissions_type, default=[]),
+ Column(
+ "storage_permissions",
+ storage_permissions_type,
+ default=[],
+ ),
+ Column(
+ "_created_at", sa.DateTime, server_default=sa.func.now(), index=True
+ ),
+ Column("_updated_at", sa.DateTime, server_onupdate=sa.func.now()),
+ Column("_deleted_at", sa.DateTime, index=True),
+ )
+
+ return Base.metadata.tables[table_name]
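+
+
+# Illustrative usage sketch (the engine and MySyftObject below are hypothetical):
+# create_table only registers the table on the dialect's metadata, the DDL itself
+# is emitted by create_all, e.g.:
+#
+#   engine = sa.create_engine("sqlite://")
+#   table = create_table(MySyftObject, engine.dialect)
+#   SQLiteBase.metadata.create_all(engine)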
diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py
new file mode 100644
index 00000000000..fbcf87ce47b
--- /dev/null
+++ b/packages/syft/src/syft/store/db/sqlite.py
@@ -0,0 +1,61 @@
+# stdlib
+from pathlib import Path
+import tempfile
+import uuid
+
+# third party
+from pydantic import Field
+import sqlalchemy as sa
+
+# relative
+from ...serde.serializable import serializable
+from ...server.credentials import SyftSigningKey
+from ...server.credentials import SyftVerifyKey
+from ...types.uid import UID
+from .db import DBConfig
+from .db import DBManager
+
+
+@serializable(canonical_name="SQLiteDBConfig", version=1)
+class SQLiteDBConfig(DBConfig):
+ filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db")
+ path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()))
+
+ @property
+ def connection_string(self) -> str:
+ """
+ NOTE: an in-memory SQLite database is not shared between connections, so:
+ - using 2 workers (high/low) will not share a db
+ - re-using a connection (e.g. for a Job worker) will not share a db
+ """
+ if self.path == Path("."):
+ # Use in-memory database, only for unittests
+ return "sqlite://"
+ filepath = self.path / self.filename
+ return f"sqlite:///{filepath.resolve()}"
+
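+# Illustrative values only (the paths below are hypothetical):
+#
+#   SQLiteDBConfig(path=Path("/data"), filename="server.db").connection_string
+#   # -> "sqlite:////data/server.db"
+#
+#   SQLiteDBConfig(path=Path(".")).connection_string
+#   # -> "sqlite://" (in-memory, unit tests only)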
+
+class SQLiteDBManager(DBManager[SQLiteDBConfig]):
+ def update_settings(self) -> None:
+ connection = self.engine.connect()
+ connection.execute(sa.text("PRAGMA journal_mode = WAL"))  # write-ahead log for better concurrency
+ connection.execute(sa.text("PRAGMA busy_timeout = 5000"))  # wait up to 5s on a locked database
+ connection.execute(sa.text("PRAGMA temp_store = 2"))  # keep temporary tables in memory
+ connection.execute(sa.text("PRAGMA synchronous = 1"))  # NORMAL fsync level, pairs with WAL
+
+ @classmethod
+ def random(
+ cls,
+ *,
+ config: SQLiteDBConfig | None = None,
+ server_uid: UID | None = None,
+ root_verify_key: SyftVerifyKey | None = None,
+ ) -> "SQLiteDBManager":
+ root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key
+ server_uid = server_uid or UID()
+ config = config or SQLiteDBConfig()
+ return SQLiteDBManager(
+ config=config,
+ server_uid=server_uid,
+ root_verify_key=root_verify_key,
+ )
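+
+
+# Minimal sketch for spinning up a throwaway manager in tests (init_tables is
+# provided by the manager, see its use in stash.py; nothing here is required
+# by this module):
+#
+#   db = SQLiteDBManager.random()
+#   db.init_tables()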
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
new file mode 100644
index 00000000000..aec2a2ed9c5
--- /dev/null
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -0,0 +1,848 @@
+# stdlib
+from collections.abc import Callable
+from functools import wraps
+import inspect
+from typing import Any
+from typing import Generic
+from typing import ParamSpec
+from typing import Set # noqa: UP035
+from typing import cast
+from typing import get_args
+
+# third party
+from pydantic import ValidationError
+import sqlalchemy as sa
+from sqlalchemy import Row
+from sqlalchemy import Table
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from typing_extensions import Self
+from typing_extensions import TypeVar
+
+# relative
+from ...serde.json_serde import deserialize_json
+from ...serde.json_serde import is_json_primitive
+from ...serde.json_serde import serialize_json
+from ...server.credentials import SyftVerifyKey
+from ...service.action.action_permissions import ActionObjectEXECUTE
+from ...service.action.action_permissions import ActionObjectOWNER
+from ...service.action.action_permissions import ActionObjectPermission
+from ...service.action.action_permissions import ActionObjectREAD
+from ...service.action.action_permissions import ActionObjectWRITE
+from ...service.action.action_permissions import ActionPermission
+from ...service.action.action_permissions import StoragePermission
+from ...service.user.user_roles import ServiceRole
+from ...types.errors import SyftException
+from ...types.result import as_result
+from ...types.syft_metaclass import Empty
+from ...types.syft_object import PartialSyftObject
+from ...types.syft_object import SyftObject
+from ...types.uid import UID
+from ...util.telemetry import instrument
+from ..document_store_errors import NotFoundException
+from ..document_store_errors import StashException
+from .db import DBManager
+from .query import Query
+from .schema import PostgresBase
+from .schema import SQLiteBase
+from .schema import create_table
+from .sqlite import SQLiteDBManager
+
+StashT = TypeVar("StashT", bound=SyftObject)
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
+def parse_filters(filter_dict: dict[str, Any] | None) -> list[tuple[str, str, Any]]:
+ # NOTE using django style filters, e.g. {"age__gt": 18}
+ if filter_dict is None:
+ return []
+ filters = []
+ for key, value in filter_dict.items():
+ key_split = key.split("__")
+ # Operator is eq if not specified
+ if len(key_split) == 1:
+ field, operator = key, "eq"
+ elif len(key_split) == 2:
+ field, operator = key_split
+ else:
+ raise ValueError(f"Invalid filter key: {key}")
+ filters.append((field, operator, value))
+ return filters
+
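+# For example (values are illustrative):
+#
+#   parse_filters({"name": "Bob", "age__gt": 18})
+#   # -> [("name", "eq", "Bob"), ("age", "gt", 18)]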
+
+def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
+ """
+ Decorator to inject a session into the function kwargs if it is not provided.
+
+ TODO: This decorator is a temporary fix, we want to move to a DI approach instead:
+ move db connection and session to context, and pass context to all stash methods.
+ """
+
+ # inspect if the function has a session kwarg
+ sig = inspect.signature(func)
+ inject_session: bool = "session" in sig.parameters
+
+ @wraps(func)
+ def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any:
+ if inject_session and kwargs.get("session") is None:
+ with self.sessionmaker() as session:
+ kwargs["session"] = session
+ return func(self, *args, **kwargs)
+ return func(self, *args, **kwargs)
+
+ return wrapper # type: ignore
+
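+# In practice (a sketch; `stash`, `credentials` and `uid` are hypothetical), a
+# decorated method opens its own short-lived session when none is given, or
+# reuses one that is passed in explicitly:
+#
+#   stash.get_by_uid(credentials, uid)
+#   with stash.sessionmaker() as session:
+#       stash.get_by_uid(credentials, uid, session=session)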
+
+@instrument
+class ObjectStash(Generic[StashT]):
+ allow_any_type: bool = False
+
+ def __init__(self, store: DBManager) -> None:
+ self.db = store
+ self.object_type = self.get_object_type()
+ self.table = create_table(self.object_type, self.dialect)
+ self.sessionmaker = self.db.sessionmaker
+
+ @property
+ def dialect(self) -> sa.engine.interfaces.Dialect:
+ return self.db.engine.dialect
+
+ @classmethod
+ def get_object_type(cls) -> type[StashT]:
+ """
+ Get the object type this stash is storing. This is the generic argument of the
+ ObjectStash class.
+ """
+ generic_args = get_args(cls.__orig_bases__[0])
+ if len(generic_args) != 1:
+ raise TypeError("ObjectStash must have a single generic argument")
+ elif not issubclass(generic_args[0], SyftObject):
+ raise TypeError(
+ "ObjectStash generic argument must be a subclass of SyftObject"
+ )
+ return generic_args[0]
+
+ @with_session
+ def __len__(self, session: Session = None) -> int:
+ return session.query(self.table).count()
+
+ @classmethod
+ def random(cls, **kwargs: dict) -> Self:
+ """Create a random stash with a random server_uid and root_verify_key. Useful for development."""
+ db_manager = SQLiteDBManager.random(**kwargs)
+ stash = cls(store=db_manager)
+ stash.db.init_tables()
+ return stash
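+
+ # Illustrative subclassing sketch (UserStash/User stand in for any concrete stash):
+ #
+ #   class UserStash(ObjectStash[User]): ...
+ #
+ #   stash = UserStash.random()   # throwaway stash for tests, backed by a temp SQLite file
+ #   stash.object_type            # -> User, resolved via get_object_type()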
+
+ def _is_sqlite(self) -> bool:
+ return self.db.engine.dialect.name == "sqlite"
+
+ @property
+ def server_uid(self) -> UID:
+ return self.db.server_uid
+
+ @property
+ def root_verify_key(self) -> SyftVerifyKey:
+ return self.db.root_verify_key
+
+ @property
+ def _data(self) -> list[StashT]:
+ return self.get_all(self.root_verify_key, has_permission=True).unwrap()
+
+ def query(self, object_type: type[SyftObject] | None = None) -> Query:
+ """Creates a query for this stash's object type and SQL dialect."""
+ object_type = object_type or self.object_type
+ return Query.create(object_type, self.dialect)
+
+ @as_result(StashException)
+ def check_type(self, obj: T, type_: type) -> T:
+ if not isinstance(obj, type_):
+ raise StashException(f"{type(obj)} does not match required type: {type_}")
+ return cast(T, obj)
+
+ @property
+ def session(self) -> Session:
+ return self.db.session
+
+ def _print_query(self, stmt: sa.sql.select) -> None:
+ print(
+ stmt.compile(
+ compile_kwargs={"literal_binds": True},
+ dialect=self.db.engine.dialect,
+ )
+ )
+
+ @property
+ def unique_fields(self) -> list[str]:
+ return getattr(self.object_type, "__attr_unique__", [])
+
+ @with_session
+ def is_unique(self, obj: StashT, session: Session = None) -> bool:
+ unique_fields = self.unique_fields
+ if not unique_fields:
+ return True
+
+ filters = []
+ for field_name in unique_fields:
+ field_value = getattr(obj, field_name, None)
+ if not is_json_primitive(field_value):
+ raise StashException(
+ f"Cannot check uniqueness of non-primitive field {field_name}"
+ )
+ if field_value is None:
+ continue
+ filters.append((field_name, "eq", field_value))
+
+ query = self.query()
+ query = query.filter_or(
+ *filters,
+ )
+
+ results = query.execute(session).all()
+
+ if len(results) > 1:
+ return False
+ elif len(results) == 1:
+ result = results[0]
+ return result.id == obj.id
+ return True
+
+ @with_session
+ def exists(
+ self, credentials: SyftVerifyKey, uid: UID, session: Session = None
+ ) -> bool:
+ # TODO should be @as_result
+ # TODO needs credentials check?
+ # TODO use COUNT(*) instead of SELECT
+ query = self.query().filter("id", "eq", uid)
+ result = query.execute(session).first()
+ return result is not None
+
+ @as_result(SyftException, StashException, NotFoundException)
+ @with_session
+ def get_by_uid(
+ self,
+ credentials: SyftVerifyKey,
+ uid: UID,
+ has_permission: bool = False,
+ session: Session = None,
+ ) -> StashT:
+ return self.get_one(
+ credentials=credentials,
+ filters={"id": uid},
+ has_permission=has_permission,
+ session=session,
+ ).unwrap()
+
+ def _get_field_filter(
+ self,
+ field_name: str,
+ field_value: Any,
+ table: Table | None = None,
+ ) -> sa.sql.elements.BinaryExpression:
+ table = table if table is not None else self.table
+ if field_name == "id":
+ uid_field_value = UID(field_value)
+ return table.c.id == uid_field_value
+
+ json_value = serialize_json(field_value)
+ if self.db.engine.dialect.name == "sqlite":
+ return table.c.fields[field_name] == func.json_quote(json_value)
+ elif self.db.engine.dialect.name == "postgresql":
+ return table.c.fields[field_name].astext == sa.cast(json_value, sa.String)
+
+ @as_result(SyftException, StashException, NotFoundException)
+ def get_index(
+ self, credentials: SyftVerifyKey, index: int, has_permission: bool = False
+ ) -> StashT:
+ order_by, sort_order = self.object_type.__order_by__
+ if index < 0:
+ index = -1 - index
+ sort_order = "desc" if sort_order == "asc" else "asc"
+
+ items = self.get_all(
+ credentials,
+ has_permission=has_permission,
+ limit=1,
+ offset=index,
+ order_by=order_by,
+ sort_order=sort_order,
+ ).unwrap()
+
+ if len(items) == 0:
+ raise NotFoundException(f"No item found at index {index}")
+ return items[0]
+
+ def row_as_obj(self, row: Row) -> StashT:
+ # TODO make unwrappable serde
+ return deserialize_json(row.fields)
+
+ @with_session
+ def get_role(
+ self, credentials: SyftVerifyKey, session: Session = None
+ ) -> ServiceRole:
+ # relative
+ from ...service.user.user import User
+
+ Base = SQLiteBase if self._is_sqlite() else PostgresBase
+
+ # TODO error handling
+ if Base.metadata.tables.get("User") is None:
+ # if User table does not exist, we assume the user is a guest
+ # this happens when we create stashes in tests
+ return ServiceRole.GUEST
+
+ try:
+ query = self.query(User).filter("verify_key", "eq", credentials)
+ except Exception as e:
+ print("Error getting role", e)
+ raise e
+
+ user = query.execute(session).first()
+ if user is None:
+ return ServiceRole.GUEST
+
+ return self.row_as_obj(user).role
+
+ def _get_permission_filter_from_permisson(
+ self,
+ permission: ActionObjectPermission,
+ ) -> sa.sql.elements.BinaryExpression:
+ permission_string = permission.permission_string
+ compound_permission_string = permission.compound_permission_string
+
+ if self.db.engine.dialect.name == "postgresql":
+ permission_string = [permission_string] # type: ignore
+ compound_permission_string = [compound_permission_string] # type: ignore
+ return sa.or_(
+ self.table.c.permissions.contains(permission_string),
+ self.table.c.permissions.contains(compound_permission_string),
+ )
+
+ @with_session
+ def _apply_permission_filter(
+ self,
+ stmt: T,
+ *,
+ credentials: SyftVerifyKey,
+ permission: ActionPermission = ActionPermission.READ,
+ has_permission: bool = False,
+ session: Session = None,
+ ) -> T:
+ if has_permission:
+ # ignoring permissions
+ return stmt
+ role = self.get_role(credentials, session=session)
+ if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER):
+ # admins and data owners have all permissions
+ return stmt
+
+ action_object_permission = ActionObjectPermission(
+ uid=UID(), # dummy uid, we just need the permission string
+ credentials=credentials,
+ permission=permission,
+ )
+
+ stmt = stmt.where(
+ self._get_permission_filter_from_permisson(
+ permission=action_object_permission
+ )
+ )
+ return stmt
+
+ @as_result(SyftException, StashException)
+ @with_session
+ def set(
+ self,
+ credentials: SyftVerifyKey,
+ obj: StashT,
+ add_permissions: list[ActionObjectPermission] | None = None,
+ add_storage_permission: bool = True, # TODO: check the default value
+ ignore_duplicates: bool = False,
+ session: Session = None,
+ ) -> StashT:
+ if not self.allow_any_type:
+ self.check_type(obj, self.object_type).unwrap()
+ uid = obj.id
+
+ # check if the object already exists
+ if self.exists(credentials, uid) or not self.is_unique(obj):
+ if ignore_duplicates:
+ return obj
+ unique_fields_str = ", ".join(self.unique_fields)
+ raise StashException(
+ public_message=f"Duplication Key Error for {obj}.\n"
+ f"The fields that should be unique are {unique_fields_str}."
+ )
+
+ permissions = self.get_ownership_permissions(uid, credentials)
+ if add_permissions is not None:
+ add_permission_strings = [p.permission_string for p in add_permissions]
+ permissions.extend(add_permission_strings)
+
+ storage_permissions = []
+ if add_storage_permission:
+ storage_permissions.append(
+ self.server_uid.no_dash,
+ )
+
+ fields = serialize_json(obj)
+ try:
+ # check if the fields are deserializable
+ # TODO: Ideally, we want to make sure we don't serialize what we cannot deserialize
+ # and remove this check.
+ deserialize_json(fields)
+ except Exception as e:
+ raise StashException(
+ f"Error serializing object: {e}. Some fields are invalid."
+ )
+
+ # create the object with the permissions
+ stmt = self.table.insert().values(
+ id=uid,
+ fields=fields,
+ permissions=permissions,
+ storage_permissions=storage_permissions,
+ )
+ session.execute(stmt)
+ session.commit()
+ return self.get_by_uid(credentials, uid, session=session).unwrap()
+
+ @as_result(ValidationError, AttributeError)
+ def apply_partial_update(
+ self, original_obj: StashT, update_obj: SyftObject
+ ) -> StashT:
+ for key, value in update_obj.__dict__.items():
+ if value is Empty:
+ continue
+
+ if key in original_obj.__dict__:
+ setattr(original_obj, key, value)
+ else:
+ raise AttributeError(
+ f"{type(update_obj).__name__}.{key} not found in {type(original_obj).__name__}"
+ )
+
+ # validate the new fields
+ self.object_type.model_validate(original_obj)
+ return original_obj
+
+ @as_result(StashException, NotFoundException, AttributeError, ValidationError)
+ @with_session
+ def update(
+ self,
+ credentials: SyftVerifyKey,
+ obj: StashT,
+ has_permission: bool = False,
+ session: Session = None,
+ ) -> StashT:
+ """
+ NOTE: We cannot do partial updates on the database,
+ because we are using computed fields that are not known to the DB:
+ - serialize_json will add computed fields to the JSON stored in the database
+ - If we update a single field in the JSON, the computed fields can get out of sync.
+ - To fix this, we either need db-supported computed fields, or our ORM needs to know which fields should be re-computed.
+ """
+
+ if issubclass(type(obj), PartialSyftObject):
+ original_obj = self.get_by_uid(
+ credentials, obj.id, session=session
+ ).unwrap()
+ obj = self.apply_partial_update(
+ original_obj=original_obj, update_obj=obj
+ ).unwrap()
+
+ # TODO has_permission is not used
+ if not self.is_unique(obj):
+ raise StashException(f"Some fields are not unique for {type(obj).__name__}")
+
+ stmt = self.table.update().where(self._get_field_filter("id", obj.id))
+ stmt = self._apply_permission_filter(
+ stmt,
+ credentials=credentials,
+ permission=ActionPermission.WRITE,
+ has_permission=has_permission,
+ session=session,
+ )
+ fields = serialize_json(obj)
+ try:
+ deserialize_json(fields)
+ except Exception as e:
+ raise StashException(
+ f"Error serializing object: {e}. Some fields are invalid."
+ )
+ stmt = stmt.values(fields=fields)
+
+ result = session.execute(stmt)
+ session.commit()
+ if result.rowcount == 0:
+ raise NotFoundException(
+ f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
+ )
+ return self.get_by_uid(credentials, obj.id).unwrap()
+
+ @as_result(StashException, NotFoundException)
+ @with_session
+ def delete_by_uid(
+ self,
+ credentials: SyftVerifyKey,
+ uid: UID,
+ has_permission: bool = False,
+ session: Session = None,
+ ) -> UID:
+ stmt = self.table.delete().where(self._get_field_filter("id", uid))
+ stmt = self._apply_permission_filter(
+ stmt,
+ credentials=credentials,
+ permission=ActionPermission.WRITE,
+ has_permission=has_permission,
+ session=session,
+ )
+ result = session.execute(stmt)
+ session.commit()
+ if result.rowcount == 0:
+ raise NotFoundException(
+ f"{self.object_type.__name__}: {uid} not found or no permission to delete."
+ )
+ return uid
+
+ @as_result(StashException)
+ @with_session
+ def get_one(
+ self,
+ credentials: SyftVerifyKey,
+ filters: dict[str, Any] | None = None,
+ has_permission: bool = False,
+ order_by: str | None = None,
+ sort_order: str | None = None,
+ offset: int = 0,
+ session: Session = None,
+ ) -> StashT:
+ """
+ Get first objects from the stash, optionally filtered.
+
+ Args:
+ credentials (SyftVerifyKey): credentials of the user
+ filters (dict[str, Any] | None, optional): dictionary of filters,
+ where the key is the field name and the value is the filter value.
+ Operators other than equals can be used in the key,
+ e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None.
+ has_permission (bool, optional): If True, overrides the permission check.
+ Defaults to False.
+ order_by (str | None, optional): If provided, the results will be ordered by this field.
+ If not provided, the default order and field defined on the SyftObject.__order_by__ are used.
+ Defaults to None.
+ sort_order (str | None, optional): "asc" or "desc". If not defined,
+ the default order defined on the SyftObject.__order_by__ is used.
+ Defaults to None.
+ offset (int, optional): offset the results. Defaults to 0.
+
+ Returns:
+ StashT: the first matching object.
+ """
+ query = self.query()
+
+ if not has_permission:
+ role = self.get_role(credentials, session=session)
+ query = query.with_permissions(credentials, role)
+
+ for field_name, operator, field_value in parse_filters(filters):
+ query = query.filter(field_name, operator, field_value)
+
+ query = query.order_by(order_by, sort_order).offset(offset).limit(1)
+ result = query.execute(session).first()
+ if result is None:
+ raise NotFoundException(f"{self.object_type.__name__}: not found")
+
+ return self.row_as_obj(result)
+
+ @as_result(StashException)
+ @with_session
+ def get_all(
+ self,
+ credentials: SyftVerifyKey,
+ filters: dict[str, Any] | None = None,
+ has_permission: bool = False,
+ order_by: str | None = None,
+ sort_order: str | None = None,
+ limit: int | None = None,
+ offset: int = 0,
+ session: Session = None,
+ ) -> list[StashT]:
+ """
+ Get all objects from the stash, optionally filtered.
+
+ Args:
+ credentials (SyftVerifyKey): credentials of the user
+ filters (dict[str, Any] | None, optional): dictionary of filters,
+ where the key is the field name and the value is the filter value.
+ Operators other than equals can be used in the key,
+ e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None.
+ has_permission (bool, optional): If True, overrides the permission check.
+ Defaults to False.
+ order_by (str | None, optional): If provided, the results will be ordered by this field.
+ If not provided, the default order and field defined on the SyftObject.__order_by__ are used.
+ Defaults to None.
+ sort_order (str | None, optional): "asc" or "desc". If not defined,
+ the default order defined on the SyftObject.__order_by__ is used.
+ Defaults to None.
+ limit (int | None, optional): limit the number of results. Defaults to None.
+ offset (int, optional): offset the results. Defaults to 0.
+
+ Returns:
+ list[StashT]: list of objects.
+ """
+ query = self.query()
+
+ if not has_permission:
+ role = self.get_role(credentials, session=session)
+ query = query.with_permissions(credentials, role)
+
+ for field_name, operator, field_value in parse_filters(filters):
+ query = query.filter(field_name, operator, field_value)
+
+ query = query.order_by(order_by, sort_order).limit(limit).offset(offset)
+ result = query.execute(session).all()
+ return [self.row_as_obj(row) for row in result]
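+
+ # Illustrative calls (field names follow the docstring example and are hypothetical):
+ #
+ #   stash.get_all(credentials, filters={"name": "Bob"}).unwrap()
+ #   stash.get_all(credentials, filters={"friends__contains": "Alice"}, limit=10).unwrap()
+ #   stash.get_one(credentials, filters={"id": uid}).unwrap()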
+
+ # PERMISSIONS
+ def get_ownership_permissions(
+ self, uid: UID, credentials: SyftVerifyKey
+ ) -> list[str]:
+ return [
+ ActionObjectOWNER(uid=uid, credentials=credentials).permission_string,
+ ActionObjectWRITE(uid=uid, credentials=credentials).permission_string,
+ ActionObjectREAD(uid=uid, credentials=credentials).permission_string,
+ ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string,
+ ]
+
+ @as_result(NotFoundException)
+ @with_session
+ def add_permission(
+ self,
+ permission: ActionObjectPermission,
+ session: Session = None,
+ ignore_missing: bool = False,
+ ) -> None:
+ try:
+ existing_permissions = self._get_permissions_for_uid(
+ permission.uid, session=session
+ ).unwrap()
+ except NotFoundException:
+ if ignore_missing:
+ return None
+ raise
+
+ existing_permissions.add(permission.permission_string)
+
+ stmt = self.table.update().where(self.table.c.id == permission.uid)
+ stmt = stmt.values(permissions=list(existing_permissions))
+ session.execute(stmt)
+ session.commit()
+
+ return None
+
+ @as_result(NotFoundException)
+ @with_session
+ def add_permissions(
+ self,
+ permissions: list[ActionObjectPermission],
+ ignore_missing: bool = False,
+ session: Session = None,
+ ) -> None:
+ for permission in permissions:
+ self.add_permission(
+ permission, session=session, ignore_missing=ignore_missing
+ ).unwrap()
+ return None
+
+ @with_session
+ def remove_permission(
+ self, permission: ActionObjectPermission, session: Session = None
+ ) -> None:
+ # TODO not threadsafe
+ try:
+ permissions = self._get_permissions_for_uid(permission.uid).unwrap()
+ permissions.remove(permission.permission_string)
+ except (NotFoundException, KeyError):
+ # TODO add error handling to permissions
+ return None
+
+ stmt = (
+ self.table.update()
+ .where(self.table.c.id == permission.uid)
+ .values(permissions=list(permissions))
+ )
+ session.execute(stmt)
+ session.commit()
+ return None
+
+ @with_session
+ def has_permission(
+ self, permission: ActionObjectPermission, session: Session = None
+ ) -> bool:
+ if self.get_role(permission.credentials, session=session) in (
+ ServiceRole.ADMIN,
+ ServiceRole.DATA_OWNER,
+ ):
+ return True
+ return self.has_permissions([permission], session=session)
+
+ @with_session
+ def has_permissions(
+ self, permissions: list[ActionObjectPermission], session: Session = None
+ ) -> bool:
+ # TODO: we should use a permissions table to check all permissions at once
+
+ permission_filters = [
+ sa.and_(
+ self._get_field_filter("id", p.uid),
+ self._get_permission_filter_from_permisson(permission=p),
+ )
+ for p in permissions
+ ]
+
+ stmt = self.table.select().where(
+ sa.and_(
+ *permission_filters,
+ ),
+ )
+ result = session.execute(stmt).first()
+ return result is not None
+
+ @as_result(StashException)
+ @with_session
+ def _get_permissions_for_uid(self, uid: UID, session: Session = None) -> Set[str]: # noqa: UP006
+ stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
+ result = session.execute(stmt).scalar_one_or_none()
+ if result is None:
+ raise NotFoundException(f"No permissions found for uid: {uid}")
+ return set(result)
+
+ @as_result(StashException)
+ @with_session
+ def get_all_permissions(self, session: Session = None) -> dict[UID, Set[str]]: # noqa: UP006
+ stmt = select(self.table.c.id, self.table.c.permissions)
+ results = session.execute(stmt).all()
+ return {UID(row.id): set(row.permissions) for row in results}
+
+ # STORAGE PERMISSIONS
+ @with_session
+ def has_storage_permission(
+ self, permission: StoragePermission, session: Session = None
+ ) -> bool:
+ return self.has_storage_permissions([permission], session=session)
+
+ @with_session
+ def has_storage_permissions(
+ self, permissions: list[StoragePermission], session: Session = None
+ ) -> bool:
+ permission_filters = [
+ sa.and_(
+ self._get_field_filter("id", p.uid),
+ self.table.c.storage_permissions.contains(
+ p.server_uid.no_dash
+ if self._is_sqlite()
+ else [p.server_uid.no_dash]
+ ),
+ )
+ for p in permissions
+ ]
+
+ stmt = self.table.select().where(
+ sa.and_(
+ *permission_filters,
+ )
+ )
+ result = session.execute(stmt).first()
+ return result is not None
+
+ @as_result(StashException)
+ @with_session
+ def get_all_storage_permissions(
+ self, session: Session = None
+ ) -> dict[UID, Set[UID]]: # noqa: UP006
+ stmt = select(self.table.c.id, self.table.c.storage_permissions)
+ results = session.execute(stmt).all()
+
+ return {
+ UID(row.id): {UID(uid) for uid in row.storage_permissions}
+ for row in results
+ }
+
+ @as_result(NotFoundException)
+ @with_session
+ def add_storage_permissions(
+ self,
+ permissions: list[StoragePermission],
+ session: Session = None,
+ ignore_missing: bool = False,
+ ) -> None:
+ for permission in permissions:
+ self.add_storage_permission(
+ permission, session=session, ignore_missing=ignore_missing
+ ).unwrap()
+
+ return None
+
+ @as_result(NotFoundException)
+ @with_session
+ def add_storage_permission(
+ self,
+ permission: StoragePermission,
+ session: Session = None,
+ ignore_missing: bool = False,
+ ) -> None:
+ try:
+ existing_permissions = self._get_storage_permissions_for_uid(
+ permission.uid, session=session
+ ).unwrap()
+ except NotFoundException:
+ if ignore_missing:
+ return None
+ raise
+
+ existing_permissions.add(permission.server_uid)
+
+ stmt = (
+ self.table.update()
+ .where(self.table.c.id == permission.uid)
+ .values(storage_permissions=[str(uid) for uid in existing_permissions])
+ )
+
+ session.execute(stmt)
+ session.commit()
+
+ @with_session
+ def remove_storage_permission(
+ self, permission: StoragePermission, session: Session = None
+ ) -> None:
+ try:
+ permissions = self._get_storage_permissions_for_uid(
+ permission.uid, session=session
+ ).unwrap()
+ permissions.discard(permission.server_uid)
+ except NotFoundException:
+ # TODO add error handling to permissions
+ return None
+
+ stmt = (
+ self.table.update()
+ .where(self.table.c.id == permission.uid)
+ .values(storage_permissions=[str(uid) for uid in permissions])
+ )
+ session.execute(stmt)
+ session.commit()
+ return None
+
+ @as_result(StashException)
+ @with_session
+ def _get_storage_permissions_for_uid(
+ self, uid: UID, session: Session = None
+ ) -> Set[UID]: # noqa: UP006
+ stmt = select(self.table.c.id, self.table.c.storage_permissions).where(
+ self.table.c.id == uid
+ )
+ result = session.execute(stmt).first()
+ if result is None:
+ raise NotFoundException(f"No storage permissions found for uid: {uid}")
+ return {UID(uid) for uid in result.storage_permissions}
diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py
deleted file mode 100644
index ca0f3e1f33a..00000000000
--- a/packages/syft/src/syft/store/dict_document_store.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# future
-from __future__ import annotations
-
-# stdlib
-from typing import Any
-
-# third party
-from pydantic import Field
-
-# relative
-from ..serde.serializable import serializable
-from ..server.credentials import SyftVerifyKey
-from ..types import uid
-from .document_store import DocumentStore
-from .document_store import StoreConfig
-from .kv_document_store import KeyValueBackingStore
-from .kv_document_store import KeyValueStorePartition
-from .locks import LockingConfig
-from .locks import ThreadingLockingConfig
-
-
-@serializable(canonical_name="DictBackingStore", version=1)
-class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc]
- # TODO: fix the mypy issue
- """Dictionary-based Store core logic"""
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__()
- self._ddtype = kwargs.get("ddtype", None)
-
- def __getitem__(self, key: Any) -> Any:
- try:
- value = super().__getitem__(key)
- return value
- except KeyError as e:
- if self._ddtype:
- return self._ddtype()
- raise e
-
-
-@serializable(canonical_name="DictStorePartition", version=1)
-class DictStorePartition(KeyValueStorePartition):
- """Dictionary-based StorePartition
-
- Parameters:
- `settings`: PartitionSettings
- PySyft specific settings, used for indexing and partitioning
- `store_config`: DictStoreConfig
- DictStore specific configuration
- """
-
- def prune(self) -> None:
- self.init_store().unwrap()
-
-
-# the base document store is already a dict but we can change it later
-@serializable(canonical_name="DictDocumentStore", version=1)
-class DictDocumentStore(DocumentStore):
- """Dictionary-based Document Store
-
- Parameters:
- `store_config`: DictStoreConfig
- Dictionary Store specific configuration, containing the store type and the backing store type
- """
-
- partition_type = DictStorePartition
-
- def __init__(
- self,
- server_uid: uid,
- root_verify_key: SyftVerifyKey | None,
- store_config: DictStoreConfig | None = None,
- ) -> None:
- if store_config is None:
- store_config = DictStoreConfig()
- super().__init__(
- server_uid=server_uid,
- root_verify_key=root_verify_key,
- store_config=store_config,
- )
-
- def reset(self) -> None:
- for partition in self.partitions.values():
- partition.prune()
-
-
-@serializable()
-class DictStoreConfig(StoreConfig):
- __canonical_name__ = "DictStoreConfig"
- """Dictionary-based configuration
-
- Parameters:
- `store_type`: Type[DocumentStore]
- The Document type used. Default: DictDocumentStore
- `backing_store`: Type[KeyValueBackingStore]
- The backend type used. Default: DictBackingStore
- locking_config: LockingConfig
- The config used for store locking. Available options:
- * NoLockingConfig: no locking, ideal for single-thread stores.
- * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores.
- Defaults to ThreadingLockingConfig.
- """
-
- store_type: type[DocumentStore] = DictDocumentStore
- backing_store: type[KeyValueBackingStore] = DictBackingStore
- locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig)
diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py
index cc97802a08b..98cccb82568 100644
--- a/packages/syft/src/syft/store/document_store.py
+++ b/packages/syft/src/syft/store/document_store.py
@@ -188,16 +188,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey:
def as_dict(self) -> dict[str, Any]:
return {self.key: self.value}
- @property
- def as_dict_mongo(self) -> dict[str, Any]:
- key = self.key
- if key == "id":
- key = "_id"
- if self.type_list:
- # We want to search inside the list of values
- return {key: {"$in": self.value}}
- return {key: self.value}
-
@serializable(canonical_name="PartitionKeysWithUID", version=1)
class PartitionKeysWithUID(PartitionKeys):
@@ -273,21 +263,6 @@ def as_dict(self) -> dict:
qk_dict[qk_key] = qk_value
return qk_dict
- @property
- def as_dict_mongo(self) -> dict:
- qk_dict = {}
- for qk in self.all:
- qk_key = qk.key
- qk_value = qk.value
- if qk_key == "id":
- qk_key = "_id"
- if qk.type_list:
- # We want to search inside the list of values
- qk_dict[qk_key] = {"$in": qk_value}
- else:
- qk_dict[qk_key] = qk_value
- return qk_dict
-
UIDPartitionKey = PartitionKey(key="id", type_=UID)
@@ -627,30 +602,8 @@ def __init__(
def __has_admin_permissions(
self, settings: PartitionSettings
) -> Callable[[SyftVerifyKey], bool]:
- # relative
- from ..service.user.user import User
- from ..service.user.user_roles import ServiceRole
- from ..service.user.user_stash import UserStash
-
- # leave out UserStash to avoid recursion
- # TODO: pass the callback from BaseStash instead of DocumentStore
- # so that this works with UserStash after the sqlite thread fix is merged
- if settings.object_type is User:
- return lambda credentials: False
-
- user_stash = UserStash(store=self)
-
def has_admin_permissions(credentials: SyftVerifyKey) -> bool:
- res = user_stash.get_by_verify_key(
- credentials=credentials,
- verify_key=credentials,
- )
-
- return (
- res.is_ok()
- and (user := res.ok()) is not None
- and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN)
- )
+ return credentials == self.root_verify_key
return has_admin_permissions
@@ -702,7 +655,6 @@ class NewBaseStash:
partition: StorePartition
def __init__(self, store: DocumentStore) -> None:
- self.store = store
self.partition = store.partition(type(self).settings)
@as_result(StashException)
diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py
index 5d9c29c9d9d..d3e40372842 100644
--- a/packages/syft/src/syft/store/linked_obj.py
+++ b/packages/syft/src/syft/store/linked_obj.py
@@ -76,7 +76,7 @@ def update_with_context(
raise SyftException(public_message=f"context {context}'s server is None")
service = context.server.get_service(self.service_type)
if hasattr(service, "stash"):
- result = service.stash.update(credentials, obj)
+ result = service.stash.update(credentials, obj).unwrap()
else:
raise SyftException(
public_message=f"service {service} does not have a stash"
diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py
deleted file mode 100644
index 9767059a1cb..00000000000
--- a/packages/syft/src/syft/store/mongo_client.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# stdlib
-import logging
-from threading import Lock
-from typing import Any
-
-# third party
-from pymongo.collection import Collection as MongoCollection
-from pymongo.database import Database as MongoDatabase
-from pymongo.errors import ConnectionFailure
-from pymongo.mongo_client import MongoClient as PyMongoClient
-
-# relative
-from ..serde.serializable import serializable
-from ..types.errors import SyftException
-from ..types.result import as_result
-from ..util.telemetry import TRACING_ENABLED
-from .document_store import PartitionSettings
-from .document_store import StoreClientConfig
-from .document_store import StoreConfig
-from .mongo_codecs import SYFT_CODEC_OPTIONS
-
-if TRACING_ENABLED:
- try:
- # third party
- from opentelemetry.instrumentation.pymongo import PymongoInstrumentor
-
- PymongoInstrumentor().instrument()
- message = "> Added OTEL PymongoInstrumentor"
- print(message)
- logger = logging.getLogger(__name__)
- logger.info(message)
- except Exception: # nosec
- pass
-
-
-@serializable(canonical_name="MongoStoreClientConfig", version=1)
-class MongoStoreClientConfig(StoreClientConfig):
- """
- Paramaters:
- `hostname`: optional string
- hostname or IP address or Unix domain socket path of a single mongod or mongos
- instance to connect to, or a mongodb URI, or a list of hostnames (but no more
- than one mongodb URI). If `host` is an IPv6 literal it must be enclosed in '['
- and ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for localhost).
- Multihomed and round robin DNS addresses are **not** supported.
- `port` : optional int
- port number on which to connect
- `directConnection`: bool
- if ``True``, forces this client to connect directly to the specified MongoDB host
- as a standalone. If ``false``, the client connects to the entire replica set of which
- the given MongoDB host(s) is a part. If this is ``True`` and a mongodb+srv:// URI
- or a URI containing multiple seeds is provided, an exception will be raised.
- `maxPoolSize`: int. Default 100
- The maximum allowable number of concurrent connections to each connected server.
- Requests to a server will block if there are `maxPoolSize` outstanding connections
- to the requested server. Defaults to 100. Can be either 0 or None, in which case
- there is no limit on the number of concurrent connections.
- `minPoolSize` : int. Default 0
- The minimum required number of concurrent connections that the pool will maintain
- to each connected server. Default is 0.
- `maxIdleTimeMS`: int
- The maximum number of milliseconds that a connection can remain idle in the pool
- before being removed and replaced. Defaults to `None` (no limit).
- `appname`: string
- The name of the application that created this MongoClient instance. The server will
- log this value upon establishing each connection. It is also recorded in the slow
- query log and profile collections.
- `maxConnecting`: optional int
- The maximum number of connections that each pool can establish concurrently.
- Defaults to `2`.
- `timeoutMS`: (integer or None)
- Controls how long (in milliseconds) the driver will wait when executing an operation
- (including retry attempts) before raising a timeout error. ``0`` or ``None`` means
- no timeout.
- `socketTimeoutMS`: (integer or None)
- Controls how long (in milliseconds) the driver will wait for a response after sending
- an ordinary (non-monitoring) database operation before concluding that a network error
- has occurred. ``0`` or ``None`` means no timeout. Defaults to ``None`` (no timeout).
- `connectTimeoutMS`: (integer or None)
- Controls how long (in milliseconds) the driver will wait during server monitoring when
- connecting a new socket to a server before concluding the server is unavailable.
- ``0`` or ``None`` means no timeout. Defaults to ``20000`` (20 seconds).
- `serverSelectionTimeoutMS`: (integer)
- Controls how long (in milliseconds) the driver will wait to find an available, appropriate
- server to carry out a database operation; while it is waiting, multiple server monitoring
- operations may be carried out, each controlled by `connectTimeoutMS`.
- Defaults to ``120000`` (120 seconds).
- `waitQueueTimeoutMS`: (integer or None)
- How long (in milliseconds) a thread will wait for a socket from the pool if the pool
- has no free sockets. Defaults to ``None`` (no timeout).
- `heartbeatFrequencyMS`: (optional)
- The number of milliseconds between periodic server checks, or None to accept the default
- frequency of 10 seconds.
- # Auth
- username: str
- Database username
- password: str
- Database pass
- authSource: str
- The database to authenticate on.
- Defaults to the database specified in the URI, if provided, or to “admin”.
- tls: bool
- If True, create the connection to the server using transport layer security.
- Defaults to False.
- # Testing and connection reuse
- client: Optional[PyMongoClient]
- If provided, this client is reused. Default = None
-
- """
-
- # Connection
- hostname: str | None = "127.0.0.1"
- port: int | None = None
- directConnection: bool = False
- maxPoolSize: int = 200
- minPoolSize: int = 0
- maxIdleTimeMS: int | None = None
- maxConnecting: int = 3
- timeoutMS: int = 0
- socketTimeoutMS: int = 0
- connectTimeoutMS: int = 20000
- serverSelectionTimeoutMS: int = 120000
- waitQueueTimeoutMS: int | None = None
- heartbeatFrequencyMS: int = 10000
- appname: str = "pysyft"
- # Auth
- username: str | None = None
- password: str | None = None
- authSource: str = "admin"
- tls: bool | None = False
- # Testing and connection reuse
- client: Any = None
-
- # this allows us to have one connection per `Server` object
- # in the MongoClientCache
- server_obj_python_id: int | None = None
-
-
-class MongoClientCache:
- __client_cache__: dict[int, type["MongoClient"] | None] = {}
- _lock: Lock = Lock()
-
- @classmethod
- def from_cache(cls, config: MongoStoreClientConfig) -> PyMongoClient | None:
- return cls.__client_cache__.get(hash(str(config)), None)
-
- @classmethod
- def set_cache(cls, config: MongoStoreClientConfig, client: PyMongoClient) -> None:
- with cls._lock:
- cls.__client_cache__[hash(str(config))] = client
-
-
-class MongoClient:
- client: PyMongoClient = None
-
- def __init__(self, config: MongoStoreClientConfig, cache: bool = True) -> None:
- self.config = config
- if config.client is not None:
- self.client = config.client
- elif cache:
- self.client = MongoClientCache.from_cache(config=config)
-
- if not cache or self.client is None:
- self.connect(config=config).unwrap()
-
- @as_result(SyftException)
- def connect(self, config: MongoStoreClientConfig) -> bool:
- self.client = PyMongoClient(
- # Connection
- host=config.hostname,
- port=config.port,
- directConnection=config.directConnection,
- maxPoolSize=config.maxPoolSize,
- minPoolSize=config.minPoolSize,
- maxIdleTimeMS=config.maxIdleTimeMS,
- maxConnecting=config.maxConnecting,
- timeoutMS=config.timeoutMS,
- socketTimeoutMS=config.socketTimeoutMS,
- connectTimeoutMS=config.connectTimeoutMS,
- serverSelectionTimeoutMS=config.serverSelectionTimeoutMS,
- waitQueueTimeoutMS=config.waitQueueTimeoutMS,
- heartbeatFrequencyMS=config.heartbeatFrequencyMS,
- appname=config.appname,
- # Auth
- username=config.username,
- password=config.password,
- authSource=config.authSource,
- tls=config.tls,
- uuidRepresentation="standard",
- )
- MongoClientCache.set_cache(config=config, client=self.client)
- try:
- # Check if mongo connection is still up
- self.client.admin.command("ping")
- except ConnectionFailure as e:
- self.client = None
- raise SyftException.from_exception(e)
-
- return True
-
- @as_result(SyftException)
- def with_db(self, db_name: str) -> MongoDatabase:
- try:
- return self.client[db_name]
- except BaseException as e:
- raise SyftException.from_exception(e)
-
- @as_result(SyftException)
- def with_collection(
- self,
- collection_settings: PartitionSettings,
- store_config: StoreConfig,
- collection_name: str | None = None,
- ) -> MongoCollection:
- db = self.with_db(db_name=store_config.db_name).unwrap()
-
- try:
- collection_name = (
- collection_name
- if collection_name is not None
- else collection_settings.name
- )
- collection = db.get_collection(
- name=collection_name, codec_options=SYFT_CODEC_OPTIONS
- )
- except BaseException as e:
- raise SyftException.from_exception(e)
-
- return collection
-
- @as_result(SyftException)
- def with_collection_permissions(
- self, collection_settings: PartitionSettings, store_config: StoreConfig
- ) -> MongoCollection:
- """
- For each collection, create a corresponding collection
- that store the permissions to the data in that collection
- """
- db = self.with_db(db_name=store_config.db_name).unwrap()
-
- try:
- collection_permissions_name: str = collection_settings.name + "_permissions"
- collection_permissions = db.get_collection(
- name=collection_permissions_name, codec_options=SYFT_CODEC_OPTIONS
- )
- except BaseException as e:
- raise SyftException.from_exception(e)
- return collection_permissions
-
- @as_result(SyftException)
- def with_collection_storage_permissions(
- self, collection_settings: PartitionSettings, store_config: StoreConfig
- ) -> MongoCollection:
- """
- For each collection, create a corresponding collection
- that store the permissions to the data in that collection
- """
- db = self.with_db(db_name=store_config.db_name).unwrap()
-
- try:
- collection_storage_permissions_name: str = (
- collection_settings.name + "_storage_permissions"
- )
- storage_permissons_collection = db.get_collection(
- name=collection_storage_permissions_name,
- codec_options=SYFT_CODEC_OPTIONS,
- )
- except BaseException as e:
- raise SyftException.from_exception(e)
-
- return storage_permissons_collection
-
- def close(self) -> None:
- self.client.close()
- MongoClientCache.__client_cache__.pop(hash(str(self.config)), None)
diff --git a/packages/syft/src/syft/store/mongo_codecs.py b/packages/syft/src/syft/store/mongo_codecs.py
deleted file mode 100644
index 08b7fa63562..00000000000
--- a/packages/syft/src/syft/store/mongo_codecs.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# stdlib
-from typing import Any
-
-# third party
-from bson import CodecOptions
-from bson.binary import Binary
-from bson.binary import USER_DEFINED_SUBTYPE
-from bson.codec_options import TypeDecoder
-from bson.codec_options import TypeRegistry
-
-# relative
-from ..serde.deserialize import _deserialize
-from ..serde.serialize import _serialize
-
-
-def fallback_syft_encoder(value: object) -> Binary:
- return Binary(_serialize(value, to_bytes=True), USER_DEFINED_SUBTYPE)
-
-
-class SyftMongoBinaryDecoder(TypeDecoder):
- bson_type = Binary
-
- def transform_bson(self, value: Any) -> Any:
- if value.subtype == USER_DEFINED_SUBTYPE:
- return _deserialize(value, from_bytes=True)
- return value
-
-
-syft_codecs = [SyftMongoBinaryDecoder()]
-syft_type_registry = TypeRegistry(syft_codecs, fallback_encoder=fallback_syft_encoder)
-SYFT_CODEC_OPTIONS = CodecOptions(type_registry=syft_type_registry)
diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py
deleted file mode 100644
index 805cc042cdf..00000000000
--- a/packages/syft/src/syft/store/mongo_document_store.py
+++ /dev/null
@@ -1,963 +0,0 @@
-# stdlib
-from collections.abc import Callable
-from typing import Any
-from typing import Set # noqa: UP035
-
-# third party
-from pydantic import Field
-from pymongo import ASCENDING
-from pymongo.collection import Collection as MongoCollection
-from typing_extensions import Self
-
-# relative
-from ..serde.deserialize import _deserialize
-from ..serde.serializable import serializable
-from ..serde.serialize import _serialize
-from ..server.credentials import SyftVerifyKey
-from ..service.action.action_permissions import ActionObjectEXECUTE
-from ..service.action.action_permissions import ActionObjectOWNER
-from ..service.action.action_permissions import ActionObjectPermission
-from ..service.action.action_permissions import ActionObjectREAD
-from ..service.action.action_permissions import ActionObjectWRITE
-from ..service.action.action_permissions import ActionPermission
-from ..service.action.action_permissions import StoragePermission
-from ..service.context import AuthedServiceContext
-from ..service.response import SyftSuccess
-from ..types.errors import SyftException
-from ..types.result import as_result
-from ..types.syft_object import SYFT_OBJECT_VERSION_1
-from ..types.syft_object import StorableObjectType
-from ..types.syft_object import SyftBaseObject
-from ..types.syft_object import SyftObject
-from ..types.transforms import TransformContext
-from ..types.transforms import transform
-from ..types.transforms import transform_method
-from ..types.uid import UID
-from .document_store import DocumentStore
-from .document_store import PartitionKey
-from .document_store import PartitionSettings
-from .document_store import QueryKey
-from .document_store import QueryKeys
-from .document_store import StoreConfig
-from .document_store import StorePartition
-from .document_store_errors import NotFoundException
-from .kv_document_store import KeyValueBackingStore
-from .locks import LockingConfig
-from .locks import NoLockingConfig
-from .mongo_client import MongoClient
-from .mongo_client import MongoStoreClientConfig
-
-
-@serializable()
-class MongoDict(SyftBaseObject):
- __canonical_name__ = "MongoDict"
- __version__ = SYFT_OBJECT_VERSION_1
-
- keys: list[Any]
- values: list[Any]
-
- @property
- def dict(self) -> dict[Any, Any]:
- return dict(zip(self.keys, self.values))
-
- @classmethod
- def from_dict(cls, input: dict) -> Self:
- return cls(keys=list(input.keys()), values=list(input.values()))
-
- def __repr__(self) -> str:
- return self.dict.__repr__()
-
-
-class MongoBsonObject(StorableObjectType, dict):
- pass
-
-
-def _repr_debug_(value: Any) -> str:
- if hasattr(value, "_repr_debug_"):
- return value._repr_debug_()
- return repr(value)
-
-
-def to_mongo(context: TransformContext) -> TransformContext:
- output = {}
- if context.obj:
- unique_keys_dict = context.obj._syft_unique_keys_dict()
- search_keys_dict = context.obj._syft_searchable_keys_dict()
- all_dict = unique_keys_dict
- all_dict.update(search_keys_dict)
- for k in all_dict:
- value = getattr(context.obj, k, "")
- # if the value is a method, store its value
- if callable(value):
- output[k] = value()
- else:
- output[k] = value
-
- output["__canonical_name__"] = context.obj.__canonical_name__
- output["__version__"] = context.obj.__version__
- output["__blob__"] = _serialize(context.obj, to_bytes=True)
- output["__arepr__"] = _repr_debug_(context.obj) # a comes first in alphabet
-
- if context.output and "id" in context.output:
- output["_id"] = context.output["id"]
-
- context.output = output
-
- return context
-
-
-@transform(SyftObject, MongoBsonObject)
-def syft_obj_to_mongo() -> list[Callable]:
- return [to_mongo]
-
-
-@transform_method(MongoBsonObject, SyftObject)
-def from_mongo(
- storage_obj: dict, context: TransformContext | None = None
-) -> SyftObject:
- return _deserialize(storage_obj["__blob__"], from_bytes=True)
-
-
-@serializable(attrs=["storage_type"], canonical_name="MongoStorePartition", version=1)
-class MongoStorePartition(StorePartition):
- """Mongo StorePartition
-
- Parameters:
- `settings`: PartitionSettings
- PySyft specific settings, used for partitioning and indexing.
- `store_config`: MongoStoreConfig
- Mongo specific configuration
- """
-
- storage_type: type[StorableObjectType] = MongoBsonObject
-
- @as_result(SyftException)
- def init_store(self) -> bool:
- super().init_store().unwrap()
- client = MongoClient(config=self.store_config.client_config)
- self._collection = client.with_collection(
- collection_settings=self.settings, store_config=self.store_config
- ).unwrap()
- self._permissions = client.with_collection_permissions(
- collection_settings=self.settings, store_config=self.store_config
- ).unwrap()
- self._storage_permissions = client.with_collection_storage_permissions(
- collection_settings=self.settings, store_config=self.store_config
- ).unwrap()
- return self._create_update_index().unwrap()
-
- # Potentially thread-unsafe methods.
- #
- # CAUTION:
- # * Don't use self.lock here.
- # * Do not call the public thread-safe methods here(with locking).
- # These methods are called from the public thread-safe API, and will hang the process.
-
- @as_result(SyftException)
- def _create_update_index(self) -> bool:
- """Create or update mongo database indexes"""
- collection: MongoCollection = self.collection.unwrap()
-
- def check_index_keys(
- current_keys: list[tuple[str, int]], new_index_keys: list[tuple[str, int]]
- ) -> bool:
- current_keys.sort()
- new_index_keys.sort()
- return current_keys == new_index_keys
-
- syft_obj = self.settings.object_type
-
- unique_attrs = getattr(syft_obj, "__attr_unique__", [])
- object_name = syft_obj.__canonical_name__
-
- new_index_keys = [(attr, ASCENDING) for attr in unique_attrs]
-
- try:
- current_indexes = collection.index_information()
- except BaseException as e:
- raise SyftException.from_exception(e)
- index_name = f"{object_name}_index_name"
-
- current_index_keys = current_indexes.get(index_name, None)
-
- if current_index_keys is not None:
- keys_same = check_index_keys(current_index_keys["key"], new_index_keys)
- if keys_same:
- return True
-
- # Drop current index, since incompatible with current object
- try:
- collection.drop_index(index_or_name=index_name)
- except Exception:
- raise SyftException(
- public_message=(
- f"Failed to drop index for object: {object_name}"
- f" with index keys: {current_index_keys}"
- )
- )
-
- # If no new indexes, then skip index creation
- if len(new_index_keys) == 0:
- return True
-
- try:
- collection.create_index(new_index_keys, unique=True, name=index_name)
- except Exception:
- raise SyftException(
- public_message=f"Failed to create index for {object_name} with index keys: {new_index_keys}"
- )
-
- return True
-
- @property
- @as_result(SyftException)
- def collection(self) -> MongoCollection:
- if not hasattr(self, "_collection"):
- self.init_store().unwrap()
- return self._collection
-
- @property
- @as_result(SyftException)
- def permissions(self) -> MongoCollection:
- if not hasattr(self, "_permissions"):
- self.init_store().unwrap()
- return self._permissions
-
- @property
- @as_result(SyftException)
- def storage_permissions(self) -> MongoCollection:
- if not hasattr(self, "_storage_permissions"):
- self.init_store().unwrap()
- return self._storage_permissions
-
- @as_result(SyftException)
- def set(self, *args: Any, **kwargs: Any) -> SyftObject:
- return self._set(*args, **kwargs).unwrap()
-
- @as_result(SyftException)
- def _set(
- self,
- credentials: SyftVerifyKey,
- obj: SyftObject,
- add_permissions: list[ActionObjectPermission] | None = None,
- add_storage_permission: bool = True,
- ignore_duplicates: bool = False,
- ) -> SyftObject:
- # TODO: Refactor this function since now it's doing both set and
- # update at the same time
- write_permission = ActionObjectWRITE(uid=obj.id, credentials=credentials)
- can_write: bool = self.has_permission(write_permission)
-
- store_query_key: QueryKey = self.settings.store_key.with_obj(obj)
- collection: MongoCollection = self.collection.unwrap()
-
- store_key_exists = (
- collection.find_one(store_query_key.as_dict_mongo) is not None
- )
- if (not store_key_exists) and (not self.item_keys_exist(obj, collection)):
- # attempt to claim ownership for writing
- can_write = self.take_ownership(
- uid=obj.id, credentials=credentials
- ).unwrap()
- elif not ignore_duplicates:
- unique_query_keys: QueryKeys = self.settings.unique_keys.with_obj(obj)
- keys = ", ".join(f"`{key.key}`" for key in unique_query_keys.all)
- raise SyftException(
- public_message=f"Duplication Key Error for {obj}.\nThe fields that should be unique are {keys}."
- )
- else:
- # we are not throwing an error, because we are ignoring duplicates
- # we are also not writing though
- return obj
-
- if not can_write:
- raise SyftException(
- public_message=f"No permission to write object with id {obj.id}"
- )
-
- storage_obj = obj.to(self.storage_type)
-
- collection.insert_one(storage_obj)
-
- # adding permissions
- read_permission = ActionObjectPermission(
- uid=obj.id,
- credentials=credentials,
- permission=ActionPermission.READ,
- )
- self.add_permission(read_permission)
-
- if add_permissions is not None:
- self.add_permissions(add_permissions)
-
- if add_storage_permission:
- self.add_storage_permission(
- StoragePermission(
- uid=obj.id,
- server_uid=self.server_uid,
- )
- )
-
- return obj
-
- def item_keys_exist(self, obj: SyftObject, collection: MongoCollection) -> bool:
- qks: QueryKeys = self.settings.unique_keys.with_obj(obj)
- query = {"$or": [{k: v} for k, v in qks.as_dict_mongo.items()]}
- res = collection.find_one(query)
- return res is not None
-
- @as_result(SyftException)
- def _update(
- self,
- credentials: SyftVerifyKey,
- qk: QueryKey,
- obj: SyftObject,
- has_permission: bool = False,
- overwrite: bool = False,
- allow_missing_keys: bool = False,
- ) -> SyftObject:
- collection: MongoCollection = self.collection.unwrap()
-
- # TODO: optimize the update. The ID should not be overwritten,
- # but the qk doesn't necessarily have to include the `id` field either.
-
- prev_obj = self._get_all_from_store(credentials, QueryKeys(qks=[qk])).unwrap()
- if len(prev_obj) == 0:
- raise SyftException(
- public_message=f"Failed to update missing values for query key: {qk} for type {type(obj)}"
- )
-
- prev_obj = prev_obj[0]
- if has_permission or self.has_permission(
- ActionObjectWRITE(uid=prev_obj.id, credentials=credentials)
- ):
- for key, value in obj.to_dict(exclude_empty=True).items():
- # we don't want to overwrite Mongo's "id_" or Syft's "id" on update
- if key == "id":
- # protected field
- continue
-
- # Overwrite the value if the key is already present
- setattr(prev_obj, key, value)
-
- # Create the Mongo object
- storage_obj = prev_obj.to(self.storage_type)
-
- try:
- collection.update_one(
- filter=qk.as_dict_mongo, update={"$set": storage_obj}
- )
- except Exception:
- raise SyftException(f"Failed to update obj: {obj} with qk: {qk}")
-
- return prev_obj
- else:
- raise SyftException(f"Failed to update obj {obj}, you have no permission")
-
- @as_result(SyftException)
- def _find_index_or_search_keys(
- self,
- credentials: SyftVerifyKey,
- index_qks: QueryKeys,
- search_qks: QueryKeys,
- order_by: PartitionKey | None = None,
- ) -> list[SyftObject]:
- # TODO: pass index as hint to find method
- qks = QueryKeys(qks=(list(index_qks.all) + list(search_qks.all)))
- return self._get_all_from_store(
- credentials=credentials, qks=qks, order_by=order_by
- ).unwrap()
-
- @property
- def data(self) -> dict:
- values: list = self._all(credentials=None, has_permission=True).unwrap()
- return {v.id: v for v in values}
-
- @as_result(SyftException)
- def _get(
- self,
- uid: UID,
- credentials: SyftVerifyKey,
- has_permission: bool | None = False,
- ) -> SyftObject:
- qks = QueryKeys.from_dict({"id": uid})
- res = self._get_all_from_store(
- credentials, qks, order_by=None, has_permission=has_permission
- ).unwrap()
- if len(res) == 0:
- raise NotFoundException
- else:
- return res[0]
-
- @as_result(SyftException)
- def _get_all_from_store(
- self,
- credentials: SyftVerifyKey,
- qks: QueryKeys,
- order_by: PartitionKey | None = None,
- has_permission: bool | None = False,
- ) -> list[SyftObject]:
- collection = self.collection.unwrap()
-
- if order_by is not None:
- storage_objs = collection.find(filter=qks.as_dict_mongo).sort(order_by.key)
- else:
- _default_key = "_id"
- storage_objs = collection.find(filter=qks.as_dict_mongo).sort(_default_key)
-
- syft_objs = []
- for storage_obj in storage_objs:
- obj = self.storage_type(storage_obj)
- transform_context = TransformContext(output={}, obj=obj)
-
- syft_obj = obj.to(self.settings.object_type, transform_context)
- if has_permission or self.has_permission(
- ActionObjectREAD(uid=syft_obj.id, credentials=credentials)
- ):
- syft_objs.append(syft_obj)
-
- return syft_objs
-
- @as_result(SyftException)
- def _delete(
- self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False
- ) -> SyftSuccess:
- if not (
- has_permission
- or self.has_permission(
- ActionObjectWRITE(uid=qk.value, credentials=credentials)
- )
- ):
- raise SyftException(
- public_message=f"You don't have permission to delete object with qk: {qk}"
- )
-
- collection = self.collection.unwrap()
- collection_permissions: MongoCollection = self.permissions.unwrap()
-
- qks = QueryKeys(qks=qk)
- # delete the object
- result = collection.delete_one(filter=qks.as_dict_mongo)
- # delete the object's permission
- result_permission = collection_permissions.delete_one(filter=qks.as_dict_mongo)
- if result.deleted_count == 1 and result_permission.deleted_count == 1:
- return SyftSuccess(message="Object and its permission are deleted")
- elif result.deleted_count == 0:
- raise SyftException(public_message=f"Failed to delete object with qk: {qk}")
- else:
- raise SyftException(
- public_message=f"Object with qk: {qk} was deleted, but failed to delete its corresponding permission"
- )
-
- def has_permission(self, permission: ActionObjectPermission) -> bool:
- """Check if the permission is inside the permission collection"""
- collection_permissions_status = self.permissions
- if collection_permissions_status.is_err():
- return False
- collection_permissions: MongoCollection = collection_permissions_status.ok()
-
- permissions: dict | None = collection_permissions.find_one(
- {"_id": permission.uid}
- )
-
- if permissions is None:
- return False
-
- if (
- permission.credentials
- and self.root_verify_key.verify == permission.credentials.verify
- ):
- return True
-
- if (
- permission.credentials
- and self.has_admin_permissions is not None
- and self.has_admin_permissions(permission.credentials)
- ):
- return True
-
- if permission.permission_string in permissions["permissions"]:
- return True
-
- # check ALL_READ permission
- if (
- permission.permission == ActionPermission.READ
- and ActionObjectPermission(
- permission.uid, ActionPermission.ALL_READ
- ).permission_string
- in permissions["permissions"]
- ):
- return True
-
- return False
-
- @as_result(SyftException)
- def _get_permissions_for_uid(self, uid: UID) -> Set[str]: # noqa: UP006
- collection_permissions = self.permissions.unwrap()
- permissions: dict | None = collection_permissions.find_one({"_id": uid})
- if permissions is None:
- raise SyftException(
- public_message=f"Permissions for object with UID {uid} not found!"
- )
- return set(permissions["permissions"])
-
- @as_result(SyftException)
- def get_all_permissions(self) -> dict[UID, Set[str]]: # noqa: UP006
- # Returns a dictionary of all permissions {object_uid: {*permissions}}
- collection_permissions: MongoCollection = self.permissions.unwrap()
- permissions = collection_permissions.find({})
- permissions_dict = {}
- for permission in permissions:
- permissions_dict[permission["_id"]] = permission["permissions"]
- return permissions_dict
-
- def add_permission(self, permission: ActionObjectPermission) -> None:
- collection_permissions = self.permissions.unwrap()
-
- # find the permissions for the given permission.uid
- # e.g. permissions = {"_id": "7b88fdef6bff42a8991d294c3d66f757",
- #      "permissions": {"permission_str_1", "permission_str_2"}}
- permissions: dict | None = collection_permissions.find_one(
- {"_id": permission.uid}
- )
- if permissions is None:
- # Permission doesn't exist, add a new one
- collection_permissions.insert_one(
- {
- "_id": permission.uid,
- "permissions": {permission.permission_string},
- }
- )
- else:
- # update the permissions with the new permission string
- permission_strings: set = permissions["permissions"]
- permission_strings.add(permission.permission_string)
- collection_permissions.update_one(
- {"_id": permission.uid}, {"$set": {"permissions": permission_strings}}
- )
-
- def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
- for permission in permissions:
- self.add_permission(permission)
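For orientation, a minimal sketch of the round trip the two methods above implement; `partition`, `obj` and `user_key` are hypothetical placeholders, not objects from this diff.

# Each object UID maps to one Mongo document holding a set of permission strings:
#   {"_id": obj.id, "permissions": {"<verify_key> READ", "<verify_key> WRITE"}}
read = ActionObjectPermission(
    uid=obj.id, credentials=user_key, permission=ActionPermission.READ
)
partition.add_permission(read)          # creates or extends the "permissions" set
assert partition.has_permission(read)   # matches read.permission_string against it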
-
- def remove_permission(self, permission: ActionObjectPermission) -> None:
- collection_permissions = self.permissions.unwrap()
- permissions: dict | None = collection_permissions.find_one(
- {"_id": permission.uid}
- )
- if permissions is None:
- raise SyftException(
- public_message=f"permission with UID {permission.uid} not found!"
- )
- permissions_strings: set = permissions["permissions"]
- if permission.permission_string in permissions_strings:
- permissions_strings.remove(permission.permission_string)
- if len(permissions_strings) > 0:
- collection_permissions.update_one(
- {"_id": permission.uid},
- {"$set": {"permissions": permissions_strings}},
- )
- else:
- collection_permissions.delete_one({"_id": permission.uid})
- else:
- raise SyftException(
- public_message=f"the permission {permission.permission_string} does not exist!"
- )
-
- def add_storage_permission(self, storage_permission: StoragePermission) -> None:
- storage_permissions_collection: MongoCollection = (
- self.storage_permissions.unwrap()
- )
- storage_permissions: dict | None = storage_permissions_collection.find_one(
- {"_id": storage_permission.uid}
- )
- if storage_permissions is None:
- # Permission doesn't exist, add a new one
- storage_permissions_collection.insert_one(
- {
- "_id": storage_permission.uid,
- "server_uids": {storage_permission.server_uid},
- }
- )
- else:
- # update the storage permissions set with the new server uid
- server_uids: set = storage_permissions["server_uids"]
- server_uids.add(storage_permission.server_uid)
- storage_permissions_collection.update_one(
- {"_id": storage_permission.uid},
- {"$set": {"server_uids": server_uids}},
- )
-
- def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
- for permission in permissions:
- self.add_storage_permission(permission)
-
- def has_storage_permission(self, permission: StoragePermission) -> bool: # type: ignore
- """Check if the storage_permission is inside the storage_permission collection"""
- storage_permissions_collection: MongoCollection = (
- self.storage_permissions.unwrap()
- )
- storage_permissions: dict | None = storage_permissions_collection.find_one(
- {"_id": permission.uid}
- )
- if storage_permissions is None or "server_uids" not in storage_permissions:
- return False
- return permission.server_uid in storage_permissions["server_uids"]
-
- def remove_storage_permission(self, storage_permission: StoragePermission) -> None:
- storage_permissions_collection = self.storage_permissions.unwrap()
- storage_permissions: dict | None = storage_permissions_collection.find_one(
- {"_id": storage_permission.uid}
- )
- if storage_permissions is None:
- raise SyftException(
- public_message=f"storage permission with UID {storage_permission.uid} not found!"
- )
- server_uids: set = storage_permissions["server_uids"]
- if storage_permission.server_uid in server_uids:
- server_uids.remove(storage_permission.server_uid)
- storage_permissions_collection.update_one(
- {"_id": storage_permission.uid},
- {"$set": {"server_uids": server_uids}},
- )
- else:
- raise SyftException(
- public_message=(
- f"the server_uid {storage_permission.server_uid} does not exist in the storage permission!"
- )
- )
-
- @as_result(SyftException)
- def _get_storage_permissions_for_uid(self, uid: UID) -> Set[UID]: # noqa: UP006
- storage_permissions_collection: MongoCollection = (
- self.storage_permissions.unwrap()
- )
- storage_permissions: dict | None = storage_permissions_collection.find_one(
- {"_id": uid}
- )
- if storage_permissions is None:
- raise SyftException(
- public_message=f"Storage permissions for object with UID {uid} not found!"
- )
- return set(storage_permissions["server_uids"])
-
- @as_result(SyftException)
- def get_all_storage_permissions(
- self,
- ) -> dict[UID, Set[UID]]: # noqa: UP006
- # Returns a dictionary of all storage permissions {object_uid: {*server_uids}}
- storage_permissions_collection: MongoCollection = (
- self.storage_permissions.unwrap()
- )
- storage_permissions = storage_permissions_collection.find({})
- storage_permissions_dict = {}
- for storage_permission in storage_permissions:
- storage_permissions_dict[storage_permission["_id"]] = storage_permission[
- "server_uids"
- ]
- return storage_permissions_dict
-
- @as_result(SyftException)
- def take_ownership(self, uid: UID, credentials: SyftVerifyKey) -> bool:
- collection_permissions: MongoCollection = self.permissions.unwrap()
- collection: MongoCollection = self.collection.unwrap()
- data: dict | None = collection.find_one({"_id": uid})
- permissions: dict | None = collection_permissions.find_one({"_id": uid})
-
- if permissions is not None or data is not None:
- raise SyftException(public_message=f"UID: {uid} already owned.")
-
- # first person using this UID can claim ownership
- self.add_permissions(
- [
- ActionObjectOWNER(uid=uid, credentials=credentials),
- ActionObjectWRITE(uid=uid, credentials=credentials),
- ActionObjectREAD(uid=uid, credentials=credentials),
- ActionObjectEXECUTE(uid=uid, credentials=credentials),
- ]
- )
-
- return True
-
- @as_result(SyftException)
- def _all(
- self,
- credentials: SyftVerifyKey,
- order_by: PartitionKey | None = None,
- has_permission: bool | None = False,
- ) -> list[SyftObject]:
- qks = QueryKeys(qks=())
- return self._get_all_from_store(
- credentials=credentials,
- qks=qks,
- order_by=order_by,
- has_permission=has_permission,
- ).unwrap()
-
- def __len__(self) -> int:
- collection_status = self.collection
- if collection_status.is_err():
- return 0
- collection: MongoCollection = collection_status.ok()
- return collection.count_documents(filter={})
-
- @as_result(SyftException)
- def _migrate_data(
- self, to_klass: SyftObject, context: AuthedServiceContext, has_permission: bool
- ) -> bool:
- credentials = context.credentials
- has_permission = (credentials == self.root_verify_key) or has_permission
- collection: MongoCollection = self.collection.unwrap()
-
- if has_permission:
- storage_objs = collection.find({})
- for storage_obj in storage_objs:
- obj = self.storage_type(storage_obj)
- transform_context = TransformContext(output={}, obj=obj)
- value = obj.to(self.settings.object_type, transform_context)
- key = obj.get("_id")
- try:
- migrated_value = value.migrate_to(to_klass.__version__, context)
- except Exception:
- raise SyftException(
- public_message=f"Failed to migrate data to {to_klass} for qk: {key}"
- )
- qk = self.settings.store_key.with_obj(key)
- self._update(
- credentials,
- qk=qk,
- obj=migrated_value,
- has_permission=has_permission,
- ).unwrap()
- return True
- raise SyftException(
- public_message="You don't have permissions to migrate data."
- )
-
-
-@serializable(canonical_name="MongoDocumentStore", version=1)
-class MongoDocumentStore(DocumentStore):
- """Mongo Document Store
-
- Parameters:
- `store_config`: MongoStoreConfig
- Mongo-specific configuration: connection settings, database name, and client class type.
- """
-
- partition_type = MongoStorePartition
-
-
-@serializable(
- attrs=["index_name", "settings", "store_config"],
- canonical_name="MongoBackingStore",
- version=1,
-)
-class MongoBackingStore(KeyValueBackingStore):
- """
- Core logic for the MongoDB key-value store
-
- Parameters:
- `index_name`: str
- Index name (can be either 'data' or 'permissions')
- `settings`: PartitionSettings
- Syft specific settings
- `store_config`: StoreConfig
- Connection Configuration
- `ddtype`: Type
- Optional; should be None.
- Present only to keep the interface consistent with SQLiteBackingStore.
- """
-
- def __init__(
- self,
- index_name: str,
- settings: PartitionSettings,
- store_config: StoreConfig,
- ddtype: type | None = None,
- ) -> None:
- self.index_name = index_name
- self.settings = settings
- self.store_config = store_config
- self.client: MongoClient
- self.ddtype = ddtype
- self.init_client()
-
- @as_result(SyftException)
- def init_client(self) -> None:
- self.client = MongoClient(config=self.store_config.client_config)
- self._collection: MongoCollection = self.client.with_collection(
- collection_settings=self.settings,
- store_config=self.store_config,
- collection_name=f"{self.settings.name}_{self.index_name}",
- ).unwrap()
-
- @property
- @as_result(SyftException)
- def collection(self) -> MongoCollection:
- if not hasattr(self, "_collection"):
- self.init_client().unwrap()
- return self._collection
-
- def _exist(self, key: UID) -> bool:
- collection: MongoCollection = self.collection.unwrap()
- result: dict | None = collection.find_one({"_id": key})
- if result is not None:
- return True
- return False
-
- def _set(self, key: UID, value: Any) -> None:
- if self._exist(key):
- self._update(key, value)
- else:
- collection: MongoCollection = self.collection.unwrap()
- try:
- bson_data = {
- "_id": key,
- f"{key}": _serialize(value, to_bytes=True),
- "_repr_debug_": _repr_debug_(value),
- }
- collection.insert_one(bson_data)
- except Exception:
- raise SyftException(public_message="Cannot insert data.")
-
- def _update(self, key: UID, value: Any) -> None:
- collection: MongoCollection = self.collection.unwrap()
- try:
- collection.update_one(
- {"_id": key},
- {
- "$set": {
- f"{key}": _serialize(value, to_bytes=True),
- "_repr_debug_": _repr_debug_(value),
- }
- },
- )
- except Exception as e:
- raise SyftException(
- public_message=f"Failed to update obj: {key} with value: {value}. Error: {e}"
- )
-
- def __setitem__(self, key: Any, value: Any) -> None:
- self._set(key, value)
-
- def _get(self, key: UID) -> Any:
- collection: MongoCollection = self.collection.unwrap()
- result: dict | None = collection.find_one({"_id": key})
- if result is not None:
- return _deserialize(result[f"{key}"], from_bytes=True)
- else:
- raise KeyError(f"{key} does not exist")
-
- def __getitem__(self, key: Any) -> Self:
- try:
- return self._get(key)
- except KeyError as e:
- if self.ddtype is not None:
- return self.ddtype()
- raise e
-
- def _len(self) -> int:
- collection: MongoCollection = self.collection.unwrap()
- return collection.count_documents(filter={})
-
- def __len__(self) -> int:
- return self._len()
-
- def _delete(self, key: UID) -> SyftSuccess:
- collection: MongoCollection = self.collection.unwrap()
- result = collection.delete_one({"_id": key})
- if result.deleted_count != 1:
- raise SyftException(public_message=f"{key} does not exist")
- return SyftSuccess(message="Deleted")
-
- def __delitem__(self, key: str) -> None:
- self._delete(key)
-
- def _delete_all(self) -> None:
- collection: MongoCollection = self.collection.unwrap()
- collection.delete_many({})
-
- def clear(self) -> None:
- self._delete_all()
-
- def _get_all(self) -> Any:
- collection_status = self.collection
- if collection_status.is_err():
- return collection_status
- collection: MongoCollection = collection_status.ok()
- result = collection.find()
- keys, values = [], []
- for row in result:
- keys.append(row["_id"])
- values.append(_deserialize(row[f"{row['_id']}"], from_bytes=True))
- return dict(zip(keys, values))
-
- def keys(self) -> Any:
- return self._get_all().keys()
-
- def values(self) -> Any:
- return self._get_all().values()
-
- def items(self) -> Any:
- return self._get_all().items()
-
- def pop(self, key: Any) -> Self:
- value = self._get(key)
- self._delete(key)
- return value
-
- def __contains__(self, key: Any) -> bool:
- return self._exist(key)
-
- def __iter__(self) -> Any:
- return iter(self.keys())
-
- def __repr__(self) -> str:
- return repr(self._get_all())
-
- def copy(self) -> Self:
- # 🟡 TODO
- raise NotImplementedError
-
- def update(self, *args: Any, **kwargs: Any) -> None:
- """
- Inserts the specified items to the dictionary.
- """
- # 🟡 TODO
- raise NotImplementedError
-
- def __del__(self) -> None:
- """
- Close the mongo client connection:
- - Cleanup client resources and disconnect from MongoDB
- - End all server sessions created by this client
- - Close all sockets in the connection pools and stop the monitor threads
- """
- self.client.close()
-
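A short usage sketch of the dict-like surface implemented above; `store` is a hypothetical MongoBackingStore instance and `uid` a UID key.

store[uid] = value            # __setitem__ -> _set / _update (upsert keyed by "_id")
value = store[uid]            # __getitem__ -> _get, or ddtype() when the key is missing
uid in store                  # __contains__ -> _exist
len(store)                    # __len__ -> count_documents({})
del store[uid]                # __delitem__ -> _delete
store.keys(); store.items()   # both built from _get_all(), which deserializes every row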
-
-@serializable()
-class MongoStoreConfig(StoreConfig):
- __canonical_name__ = "MongoStoreConfig"
- """Mongo Store configuration
-
- Parameters:
- `client_config`: MongoStoreClientConfig
- Mongo connection details: hostname, port, user, password etc.
- `store_type`: Type[DocumentStore]
- The type of the DocumentStore. Default: MongoDocumentStore
- `db_name`: str
- Database name
- locking_config: LockingConfig
- The config used for store locking. Available options:
- * NoLockingConfig: no locking, ideal for single-thread stores.
- * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores.
- Defaults to NoLockingConfig.
- """
-
- client_config: MongoStoreClientConfig
- store_type: type[DocumentStore] = MongoDocumentStore
- db_name: str = "app"
- backing_store: type[KeyValueBackingStore] = MongoBackingStore
- # TODO: should use a distributed lock, with RedisLockingConfig
- locking_config: LockingConfig = Field(default_factory=NoLockingConfig)
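For context, a minimal sketch of how this (now removed) config was typically assembled; the MongoStoreClientConfig fields shown are assumptions, not part of this diff.

store_config = MongoStoreConfig(
    client_config=MongoStoreClientConfig(hostname="localhost", port=27017),  # fields assumed
    db_name="app",
)
# store_type defaults to MongoDocumentStore and backing_store to MongoBackingStore,
# so the same config drives both the document store and the key-value store above.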
diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py
index 2cbef952862..82d75d68e6b 100644
--- a/packages/syft/src/syft/store/sqlite_document_store.py
+++ b/packages/syft/src/syft/store/sqlite_document_store.py
@@ -7,6 +7,8 @@
import logging
from pathlib import Path
import sqlite3
+from sqlite3 import Connection
+from sqlite3 import Cursor
import tempfile
from typing import Any
@@ -40,8 +42,8 @@
# by its filename and optionally the thread that its running in
# we keep track of each SQLiteBackingStore init in REF_COUNTS
# when it hits 0 we can close the connection and release the file descriptor
-SQLITE_CONNECTION_POOL_DB: dict[str, sqlite3.Connection] = {}
-SQLITE_CONNECTION_POOL_CUR: dict[str, sqlite3.Cursor] = {}
+SQLITE_CONNECTION_POOL_DB: dict[str, Connection] = {}
+SQLITE_CONNECTION_POOL_CUR: dict[str, Cursor] = {}
REF_COUNTS: dict[str, int] = defaultdict(int)
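A rough sketch of the pooling idea described in the comment above; `cache_key()` and the exact close/teardown logic are assumptions based on that comment, not code shown in this hunk.

key = cache_key(db_filename)                 # filename (and optionally thread id)
REF_COUNTS[key] += 1                         # on each SQLiteBackingStore init
# ... later, on close/teardown:
REF_COUNTS[key] -= 1
if REF_COUNTS[key] <= 0:                     # last store for this file is gone
    SQLITE_CONNECTION_POOL_CUR.pop(key, None)
    SQLITE_CONNECTION_POOL_DB.pop(key).close()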
@@ -114,6 +116,7 @@ def __init__(
self.lock = SyftLock(NoLockingConfig())
self.create_table()
REF_COUNTS[cache_key(self.db_filename)] += 1
+ self.subs_char = r"?"
@property
def table_name(self) -> str:
@@ -134,7 +137,7 @@ def _connect(self) -> None:
connection = sqlite3.connect(
self.file_path,
timeout=self.store_config.client_config.timeout,
- check_same_thread=False, # do we need this if we use the lock?
+ check_same_thread=False, # do we need this if we use the lock
# check_same_thread=self.store_config.client_config.check_same_thread,
)
# Set journal mode to WAL.
@@ -159,13 +162,13 @@ def create_table(self) -> None:
raise SyftException.from_exception(e, public_message=public_message)
@property
- def db(self) -> sqlite3.Connection:
+ def db(self) -> Connection:
if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_DB:
self._connect()
return SQLITE_CONNECTION_POOL_DB[cache_key(self.db_filename)]
@property
- def cur(self) -> sqlite3.Cursor:
+ def cur(self) -> Cursor:
if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_CUR:
SQLITE_CONNECTION_POOL_CUR[cache_key(self.db_filename)] = self.db.cursor()
@@ -191,49 +194,74 @@ def _close(self) -> None:
def _commit(self) -> None:
self.db.commit()
+ @staticmethod
@as_result(SyftException)
- def _execute(self, sql: str, *args: list[Any] | None) -> sqlite3.Cursor:
- with self.lock:
- cursor: sqlite3.Cursor | None = None
- # err = None
+ def _execute(
+ lock: SyftLock,
+ cursor: Cursor,
+ db: Connection,
+ table_name: str,
+ sql: str,
+ args: list[Any] | None,
+ ) -> Cursor:
+ with lock:
+ cur: Cursor | None = None
try:
- cursor = self.cur.execute(sql, *args)
+ cur = cursor.execute(sql, args)
except Exception as e:
- public_message = special_exception_public_message(self.table_name, e)
+ public_message = special_exception_public_message(table_name, e)
raise SyftException.from_exception(e, public_message=public_message)
- # TODO: Which exception is safe to rollback on?
+ # TODO: Which exception is safe to rollback on
# we should map out some more clear exceptions that can be returned
# rather than halting the program like disk I/O error etc
# self.db.rollback() # Roll back all changes if an exception occurs.
# err = Err(str(e))
- self.db.commit() # Commit if everything went ok
- return cursor
+ db.commit() # Commit if everything went ok
+ return cur
def _set(self, key: UID, value: Any) -> None:
if self._exists(key):
self._update(key, value)
else:
insert_sql = (
- f"insert into {self.table_name} (uid, repr, value) VALUES (?, ?, ?)" # nosec
+ f"insert into {self.table_name} (uid, repr, value) VALUES " # nosec
+ f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec
)
data = _serialize(value, to_bytes=True)
- self._execute(insert_sql, [str(key), _repr_debug_(value), data]).unwrap()
+ self._execute(
+ self.lock,
+ self.cur,
+ self.db,
+ self.table_name,
+ insert_sql,
+ [str(key), _repr_debug_(value), data],
+ ).unwrap()
def _update(self, key: UID, value: Any) -> None:
insert_sql = (
- f"update {self.table_name} set uid = ?, repr = ?, value = ? where uid = ?" # nosec
+ f"update {self.table_name} set uid = {self.subs_char}, " # nosec
+ f"repr = {self.subs_char}, value = {self.subs_char} " # nosec
+ f"where uid = {self.subs_char}" # nosec
)
data = _serialize(value, to_bytes=True)
self._execute(
- insert_sql, [str(key), _repr_debug_(value), data, str(key)]
+ self.lock,
+ self.cur,
+ self.db,
+ self.table_name,
+ insert_sql,
+ [str(key), _repr_debug_(value), data, str(key)],
).unwrap()
def _get(self, key: UID) -> Any:
- select_sql = f"select * from {self.table_name} where uid = ? order by sqltime" # nosec
- cursor = self._execute(select_sql, [str(key)]).unwrap(
- public_message=f"Query {select_sql} failed"
+ select_sql = (
+ f"select * from {self.table_name} where uid = {self.subs_char} " # nosec
+ "order by sqltime"
)
+ cursor = self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)]
+ ).unwrap(public_message=f"Query {select_sql} failed")
row = cursor.fetchone()
if row is None or len(row) == 0:
raise KeyError(f"{key} not in {type(self)}")
@@ -241,13 +269,10 @@ def _get(self, key: UID) -> Any:
return _deserialize(data, from_bytes=True)
def _exists(self, key: UID) -> bool:
- select_sql = f"select uid from {self.table_name} where uid = ?" # nosec
-
- res = self._execute(select_sql, [str(key)])
- if res.is_err():
- return False
- cursor = res.ok()
-
+ select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec
+ cursor = self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)]
+ ).unwrap()
row = cursor.fetchone() # type: ignore
if row is None:
return False
@@ -259,13 +284,11 @@ def _get_all(self) -> Any:
keys = []
data = []
- res = self._execute(select_sql)
- if res.is_err():
- return {}
- cursor = res.ok()
-
+ cursor = self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, []
+ ).unwrap()
rows = cursor.fetchall() # type: ignore
- if rows is None:
+ if not rows:
return {}
for row in rows:
@@ -276,29 +299,33 @@ def _get_all(self) -> Any:
def _get_all_keys(self) -> Any:
select_sql = f"select uid from {self.table_name} order by sqltime" # nosec
- res = self._execute(select_sql)
- if res.is_err():
- return []
- cursor = res.ok()
-
+ cursor = self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, []
+ ).unwrap()
rows = cursor.fetchall() # type: ignore
- if rows is None:
+ if not rows:
return []
keys = [UID(row[0]) for row in rows]
return keys
def _delete(self, key: UID) -> None:
- select_sql = f"delete from {self.table_name} where uid = ?" # nosec
- self._execute(select_sql, [str(key)]).unwrap()
+ select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec
+ self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)]
+ ).unwrap()
def _delete_all(self) -> None:
select_sql = f"delete from {self.table_name}" # nosec
- self._execute(select_sql).unwrap()
+ self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, []
+ ).unwrap()
def _len(self) -> int:
select_sql = f"select count(uid) from {self.table_name}" # nosec
- cursor = self._execute(select_sql).unwrap()
+ cursor = self._execute(
+ self.lock, self.cur, self.db, self.table_name, select_sql, []
+ ).unwrap()
cnt = cursor.fetchone()[0]
return cnt
@@ -369,7 +396,7 @@ class SQLiteStorePartition(KeyValueStorePartition):
def close(self) -> None:
self.lock.acquire()
try:
- # I think we don't want these now, because of the REF_COUNT?
+ # I think we don't want these now, because of the REF_COUNT
# self.data._close()
# self.unique_keys._close()
# self.searchable_keys._close()
diff --git a/packages/syft/src/syft/types/datetime.py b/packages/syft/src/syft/types/datetime.py
index c78dc04f5a2..93dd4ffc65e 100644
--- a/packages/syft/src/syft/types/datetime.py
+++ b/packages/syft/src/syft/types/datetime.py
@@ -67,6 +67,15 @@ def timedelta(self, other: "DateTime") -> timedelta:
utc_timestamp_delta = self.utc_timestamp - other.utc_timestamp
return timedelta(seconds=utc_timestamp_delta)
+ @classmethod
+ def from_timestamp(cls, ts: float) -> "DateTime":
+ return cls(utc_timestamp=ts)
+
+ @classmethod
+ def from_datetime(cls, dt: datetime) -> "DateTime":
+ utc_datetime = dt.astimezone(timezone.utc)
+ return cls(utc_timestamp=utc_datetime.timestamp())
+
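A brief usage sketch of the two constructors added above (stdlib `datetime`/`timezone` assumed imported in this module).

DateTime.from_datetime(datetime.now(timezone.utc))  # any datetime, normalized to UTC
DateTime.from_timestamp(1700000000.0)               # a POSIX timestamp, already UTC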
def format_timedelta(local_timedelta: timedelta) -> str:
total_seconds = int(local_timedelta.total_seconds())
diff --git a/packages/syft/src/syft/types/result.py b/packages/syft/src/syft/types/result.py
index 52d392e48cc..bba6e80777b 100644
--- a/packages/syft/src/syft/types/result.py
+++ b/packages/syft/src/syft/types/result.py
@@ -114,7 +114,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, BE]:
if isinstance(output, Ok) or isinstance(output, Err):
raise _AsResultError(
f"Functions decorated with `as_result` should not return Result.\n"
- f"Did you forget to unwrap() the result?\n"
+ f"Did you forget to unwrap() the result in {func.__name__}?\n"
f"result: {output}"
)
return Ok(output)
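Illustrative only: the calling pattern this error message guards against, with signatures assumed from how `as_result`/`unwrap` are used throughout this diff.

@as_result(SyftException)
def inner() -> int:
    return 42

@as_result(SyftException)
def outer() -> int:
    # returning inner() directly would return a Result and trip the check above;
    # unwrap() extracts the value (or re-raises the wrapped SyftException)
    return inner().unwrap()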
diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py
index 20dedae88c6..f6a4d3233cb 100644
--- a/packages/syft/src/syft/types/syft_object.py
+++ b/packages/syft/src/syft/types/syft_object.py
@@ -396,7 +396,6 @@ class SyftObject(SyftObjectVersioned):
# all objects have a UID
id: UID
-
created_date: BaseDateTime | None = None
updated_date: BaseDateTime | None = None
deleted_date: BaseDateTime | None = None
@@ -411,6 +410,7 @@ def make_id(cls, values: Any) -> Any:
values["id"] = id_field.annotation()
return values
+ __order_by__: ClassVar[tuple[str, str]] = ("_created_at", "asc")
__attr_searchable__: ClassVar[
list[str]
] = [] # keys which can be searched in the ORM
diff --git a/packages/syft/src/syft/types/uid.py b/packages/syft/src/syft/types/uid.py
index 96bc5af31b8..de364d7b10a 100644
--- a/packages/syft/src/syft/types/uid.py
+++ b/packages/syft/src/syft/types/uid.py
@@ -154,6 +154,10 @@ def is_valid_uuid(value: Any) -> bool:
def no_dash(self) -> str:
return str(self.value).replace("-", "")
+ @property
+ def hex(self) -> str:
+ return self.value.hex
+
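Assuming `self.value` is a standard `uuid.UUID`, the new property returns the 32-character undashed hex string, i.e. the same text `no_dash` produces:

uid = UID()
uid.hex      # e.g. "7b88fdef6bff42a8991d294c3d66f757"
uid.no_dash  # same string, built via str(...).replace("-", "")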
def __repr__(self) -> str:
"""Returns a human-readable version of the ID
diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py
index 4e93e82c45a..538614b4cb8 100644
--- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py
+++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py
@@ -24,7 +24,7 @@
def make_links(text: str) -> str:
file_pattern = re.compile(r"([\w/.-]+\.py)\", line (\d+)")
- return file_pattern.sub(r'\1, line \2', text)
+ return file_pattern.sub(r'\1, line \2', text)
DEFAULT_ID_WIDTH = 110
diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py
index eca68d13b12..56506b43fad 100644
--- a/packages/syft/tests/conftest.py
+++ b/packages/syft/tests/conftest.py
@@ -7,6 +7,7 @@
import sys
from tempfile import gettempdir
from unittest import mock
+from uuid import uuid4
# third party
from faker import Faker
@@ -22,24 +23,10 @@
from syft.protocol.data_protocol import protocol_release_dir
from syft.protocol.data_protocol import stage_protocol_changes
from syft.server.worker import Worker
+from syft.service.queue.queue_stash import QueueStash
from syft.service.user import user
-# relative
# our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support
-from .mongomock.mongo_client import MongoClient
-from .syft.stores.store_fixtures_test import dict_action_store
-from .syft.stores.store_fixtures_test import dict_document_store
-from .syft.stores.store_fixtures_test import dict_queue_stash
-from .syft.stores.store_fixtures_test import dict_store_partition
-from .syft.stores.store_fixtures_test import mongo_action_store
-from .syft.stores.store_fixtures_test import mongo_document_store
-from .syft.stores.store_fixtures_test import mongo_queue_stash
-from .syft.stores.store_fixtures_test import mongo_store_partition
-from .syft.stores.store_fixtures_test import sqlite_action_store
-from .syft.stores.store_fixtures_test import sqlite_document_store
-from .syft.stores.store_fixtures_test import sqlite_queue_stash
-from .syft.stores.store_fixtures_test import sqlite_store_partition
-from .syft.stores.store_fixtures_test import sqlite_workspace
def patch_protocol_file(filepath: Path):
@@ -129,7 +116,12 @@ def faker():
@pytest.fixture(scope="function")
def worker() -> Worker:
- worker = sy.Worker.named(name=token_hex(8))
+ """
+ NOTE in-memory sqlite is not shared between connections, so:
+ - using 2 workers (high/low) will not share a db
+ - re-using a connection (e.g. for a Job worker) will not share a db
+ """
+ worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://")
yield worker
worker.cleanup()
del worker
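A quick, self-contained illustration of the NOTE above: every connection to an in-memory SQLite database gets its own private database.

import sqlite3

a = sqlite3.connect(":memory:")
b = sqlite3.connect(":memory:")
a.execute("CREATE TABLE t (x INTEGER)")
try:
    b.execute("SELECT * FROM t")
except sqlite3.OperationalError as e:
    print(e)  # "no such table: t" -- the table exists only on connection `a`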
@@ -138,7 +130,7 @@ def worker() -> Worker:
@pytest.fixture(scope="function")
def second_worker() -> Worker:
# Used in server syncing tests
- worker = sy.Worker.named(name=token_hex(8))
+ worker = sy.Worker.named(name=uuid4().hex, db_url="sqlite://")
yield worker
worker.cleanup()
del worker
@@ -147,7 +139,7 @@ def second_worker() -> Worker:
@pytest.fixture(scope="function")
def high_worker() -> Worker:
worker = sy.Worker.named(
- name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE
+ name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE, db_url="sqlite://"
)
yield worker
worker.cleanup()
@@ -157,7 +149,10 @@ def high_worker() -> Worker:
@pytest.fixture(scope="function")
def low_worker() -> Worker:
worker = sy.Worker.named(
- name=token_hex(8), server_side_type=ServerSideType.LOW_SIDE, dev_mode=True
+ name=token_hex(8),
+ server_side_type=ServerSideType.LOW_SIDE,
+ dev_mode=True,
+ db_url="sqlite://",
)
yield worker
worker.cleanup()
@@ -212,8 +207,7 @@ def ds_verify_key(ds_client: DatasiteClient):
@pytest.fixture
def document_store(worker):
- yield worker.document_store
- worker.document_store.reset()
+ yield worker.db
@pytest.fixture
@@ -221,31 +215,6 @@ def action_store(worker):
yield worker.action_store
-@pytest.fixture(scope="session")
-def mongo_client(testrun_uid):
- """
- A race-free fixture that starts a MongoDB server for an entire pytest session.
- Cleans up the server when the session ends, or when the last client disconnects.
- """
- db_name = f"pytest_mongo_{testrun_uid}"
-
- # rand conn str
- conn_str = f"mongodb://localhost:27017/{db_name}"
-
- # create a client, and test the connection
- client = MongoClient(conn_str)
- assert client.server_info().get("ok") == 1.0
-
- yield client
-
- # stop_mongo_server(db_name)
-
-
-@pytest.fixture(autouse=True)
-def patched_mongo_client(monkeypatch):
- monkeypatch.setattr("pymongo.mongo_client.MongoClient", MongoClient)
-
-
@pytest.fixture(autouse=True)
def patched_session_cache(monkeypatch):
# patching compute heavy hashing to speed up tests
@@ -307,21 +276,18 @@ def big_dataset() -> Dataset:
yield dataset
-__all__ = [
- "mongo_store_partition",
- "mongo_document_store",
- "mongo_queue_stash",
- "mongo_action_store",
- "sqlite_store_partition",
- "sqlite_workspace",
- "sqlite_document_store",
- "sqlite_queue_stash",
- "sqlite_action_store",
- "dict_store_partition",
- "dict_action_store",
- "dict_document_store",
- "dict_queue_stash",
-]
+@pytest.fixture(
+ scope="function",
+ params=[
+ "TODOsqlite_address",
+ # "TODOpostgres_address", # will be used when we have postgres CI tests
+ ],
+)
+def queue_stash(request):
+ _ = request.param
+ stash = QueueStash.random()
+ yield stash
+
pytest_plugins = [
"tests.syft.users.fixtures",
diff --git a/packages/syft/tests/mongomock/__init__.py b/packages/syft/tests/mongomock/__init__.py
deleted file mode 100644
index 6ce7670902b..00000000000
--- a/packages/syft/tests/mongomock/__init__.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# stdlib
-import os
-
-try:
- # third party
- from pymongo.errors import PyMongoError
-except ImportError:
-
- class PyMongoError(Exception):
- pass
-
-
-try:
- # third party
- from pymongo.errors import OperationFailure
-except ImportError:
-
- class OperationFailure(PyMongoError):
- def __init__(self, message, code=None, details=None):
- super(OperationFailure, self).__init__()
- self._message = message
- self._code = code
- self._details = details
-
- code = property(lambda self: self._code)
- details = property(lambda self: self._details)
-
- def __str__(self):
- return self._message
-
-
-try:
- # third party
- from pymongo.errors import WriteError
-except ImportError:
-
- class WriteError(OperationFailure):
- pass
-
-
-try:
- # third party
- from pymongo.errors import DuplicateKeyError
-except ImportError:
-
- class DuplicateKeyError(WriteError):
- pass
-
-
-try:
- # third party
- from pymongo.errors import BulkWriteError
-except ImportError:
-
- class BulkWriteError(OperationFailure):
- def __init__(self, results):
- super(BulkWriteError, self).__init__(
- "batch op errors occurred", 65, results
- )
-
-
-try:
- # third party
- from pymongo.errors import CollectionInvalid
-except ImportError:
-
- class CollectionInvalid(PyMongoError):
- pass
-
-
-try:
- # third party
- from pymongo.errors import InvalidName
-except ImportError:
-
- class InvalidName(PyMongoError):
- pass
-
-
-try:
- # third party
- from pymongo.errors import InvalidOperation
-except ImportError:
-
- class InvalidOperation(PyMongoError):
- pass
-
-
-try:
- # third party
- from pymongo.errors import ConfigurationError
-except ImportError:
-
- class ConfigurationError(PyMongoError):
- pass
-
-
-try:
- # third party
- from pymongo.errors import InvalidURI
-except ImportError:
-
- class InvalidURI(ConfigurationError):
- pass
-
-
-from .helpers import ObjectId, utcnow # noqa
-
-
-__all__ = [
- "Database",
- "DuplicateKeyError",
- "Collection",
- "CollectionInvalid",
- "InvalidName",
- "MongoClient",
- "ObjectId",
- "OperationFailure",
- "WriteConcern",
- "ignore_feature",
- "patch",
- "warn_on_feature",
- "SERVER_VERSION",
-]
-
-# relative
-from .collection import Collection
-from .database import Database
-from .mongo_client import MongoClient
-from .not_implemented import ignore_feature
-from .not_implemented import warn_on_feature
-from .patch import patch
-from .write_concern import WriteConcern
-
-# The version of the server faked by mongomock. Callers may patch it before creating connections to
-# update the behavior of mongomock.
-# Keep the default version in sync with docker-compose.yml and travis.yml.
-SERVER_VERSION = os.getenv("MONGODB", "5.0.5")
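Per the comment above, callers could emulate a different server version before creating clients; a minimal sketch against the upstream mongomock package this vendored copy derives from:

import mongomock

mongomock.SERVER_VERSION = "6.0.0"   # affects clients created afterwards
client = mongomock.MongoClient()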
diff --git a/packages/syft/tests/mongomock/__init__.pyi b/packages/syft/tests/mongomock/__init__.pyi
deleted file mode 100644
index b7ba5e4c03c..00000000000
--- a/packages/syft/tests/mongomock/__init__.pyi
+++ /dev/null
@@ -1,30 +0,0 @@
-# stdlib
-from typing import Any
-from typing import Callable
-from typing import Literal
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from unittest import mock
-
-# third party
-from bson.objectid import ObjectId
-from pymongo import MongoClient
-from pymongo.collection import Collection
-from pymongo.database import Database
-from pymongo.errors import CollectionInvalid
-from pymongo.errors import DuplicateKeyError
-from pymongo.errors import InvalidName
-from pymongo.errors import OperationFailure
-
-def patch(
- servers: Union[str, Tuple[str, int], Sequence[Union[str, Tuple[str, int]]]] = ...,
- on_new: Literal["error", "create", "timeout", "pymongo"] = ...,
-) -> mock._patch: ...
-
-_FeatureName = Literal["collation", "session"]
-
-def ignore_feature(feature: _FeatureName) -> None: ...
-def warn_on_feature(feature: _FeatureName) -> None: ...
-
-SERVER_VERSION: str = ...
diff --git a/packages/syft/tests/mongomock/__version__.py b/packages/syft/tests/mongomock/__version__.py
deleted file mode 100644
index 14863a3db29..00000000000
--- a/packages/syft/tests/mongomock/__version__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# stdlib
-from platform import python_version_tuple
-
-python_version = python_version_tuple()
-
-if (int(python_version[0]), int(python_version[1])) >= (3, 8):
- # stdlib
- from importlib.metadata import version
-
- __version__ = version("mongomock")
-else:
- # third party
- import pkg_resources
-
- __version__ = pkg_resources.get_distribution("mongomock").version
diff --git a/packages/syft/tests/mongomock/aggregate.py b/packages/syft/tests/mongomock/aggregate.py
deleted file mode 100644
index 243720b690f..00000000000
--- a/packages/syft/tests/mongomock/aggregate.py
+++ /dev/null
@@ -1,1811 +0,0 @@
-"""Module to handle the operations within the aggregate pipeline."""
-
-# stdlib
-import bisect
-import collections
-import copy
-import datetime
-import decimal
-import functools
-import itertools
-import math
-import numbers
-import random
-import re
-import sys
-import warnings
-
-# third party
-from packaging import version
-import pytz
-
-# relative
-from . import OperationFailure
-from . import command_cursor
-from . import filtering
-from . import helpers
-
-try:
- # third party
- from bson import Regex
- from bson import decimal128
- from bson.errors import InvalidDocument
-
- decimal_support = True
- _RE_TYPES = (helpers.RE_TYPE, Regex)
-except ImportError:
- InvalidDocument = OperationFailure
- decimal_support = False
- _RE_TYPES = helpers.RE_TYPE
-
-_random = random.Random()
-
-
-group_operators = [
- "$addToSet",
- "$avg",
- "$first",
- "$last",
- "$max",
- "$mergeObjects",
- "$min",
- "$push",
- "$stdDevPop",
- "$stdDevSamp",
- "$sum",
-]
-unary_arithmetic_operators = {
- "$abs",
- "$ceil",
- "$exp",
- "$floor",
- "$ln",
- "$log10",
- "$sqrt",
- "$trunc",
-}
-binary_arithmetic_operators = {
- "$divide",
- "$log",
- "$mod",
- "$pow",
- "$subtract",
-}
-arithmetic_operators = (
- unary_arithmetic_operators
- | binary_arithmetic_operators
- | {
- "$add",
- "$multiply",
- }
-)
-project_operators = [
- "$max",
- "$min",
- "$avg",
- "$sum",
- "$stdDevPop",
- "$stdDevSamp",
- "$arrayElemAt",
- "$first",
- "$last",
-]
-control_flow_operators = [
- "$switch",
-]
-projection_operators = [
- "$let",
- "$literal",
-]
-date_operators = [
- "$dateFromString",
- "$dateToString",
- "$dateFromParts",
- "$dayOfMonth",
- "$dayOfWeek",
- "$dayOfYear",
- "$hour",
- "$isoDayOfWeek",
- "$isoWeek",
- "$isoWeekYear",
- "$millisecond",
- "$minute",
- "$month",
- "$second",
- "$week",
- "$year",
-]
-conditional_operators = ["$cond", "$ifNull"]
-array_operators = [
- "$concatArrays",
- "$filter",
- "$indexOfArray",
- "$map",
- "$range",
- "$reduce",
- "$reverseArray",
- "$size",
- "$slice",
- "$zip",
-]
-object_operators = [
- "$mergeObjects",
-]
-text_search_operators = ["$meta"]
-string_operators = [
- "$concat",
- "$indexOfBytes",
- "$indexOfCP",
- "$regexMatch",
- "$split",
- "$strcasecmp",
- "$strLenBytes",
- "$strLenCP",
- "$substr",
- "$substrBytes",
- "$substrCP",
- "$toLower",
- "$toUpper",
- "$trim",
-]
-comparison_operators = [
- "$cmp",
- "$eq",
- "$ne",
-] + list(filtering.SORTING_OPERATOR_MAP.keys())
-boolean_operators = ["$and", "$or", "$not"]
-set_operators = [
- "$in",
- "$setEquals",
- "$setIntersection",
- "$setDifference",
- "$setUnion",
- "$setIsSubset",
- "$anyElementTrue",
- "$allElementsTrue",
-]
-
-type_convertion_operators = [
- "$convert",
- "$toString",
- "$toInt",
- "$toDecimal",
- "$toLong",
- "$arrayToObject",
- "$objectToArray",
-]
-type_operators = [
- "$isNumber",
- "$isArray",
-]
-
-
-def _avg_operation(values):
- values_list = list(v for v in values if isinstance(v, numbers.Number))
- if not values_list:
- return None
- return sum(values_list) / float(len(list(values_list)))
-
-
-def _group_operation(values, operator):
- values_list = list(v for v in values if v is not None)
- if not values_list:
- return None
- return operator(values_list)
-
-
-def _sum_operation(values):
- values_list = list()
- if decimal_support:
- for v in values:
- if isinstance(v, numbers.Number):
- values_list.append(v)
- elif isinstance(v, decimal128.Decimal128):
- values_list.append(v.to_decimal())
- else:
- values_list = list(v for v in values if isinstance(v, numbers.Number))
- sum_value = sum(values_list)
- return (
- decimal128.Decimal128(sum_value)
- if isinstance(sum_value, decimal.Decimal)
- else sum_value
- )
-
-
-def _merge_objects_operation(values):
- merged_doc = dict()
- for v in values:
- if isinstance(v, dict):
- merged_doc.update(v)
- return merged_doc
-
-
-_GROUPING_OPERATOR_MAP = {
- "$sum": _sum_operation,
- "$avg": _avg_operation,
- "$mergeObjects": _merge_objects_operation,
- "$min": lambda values: _group_operation(values, min),
- "$max": lambda values: _group_operation(values, max),
- "$first": lambda values: values[0] if values else None,
- "$last": lambda values: values[-1] if values else None,
-}
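Illustrative behaviour of the grouping helpers mapped above:

_sum_operation([1, 2, "skip-me", None])                # -> 3   (non-numeric values ignored)
_avg_operation([2, 4, None])                           # -> 3.0 (average over numbers only)
_merge_objects_operation([{"a": 1}, {"b": 2}, None])   # -> {"a": 1, "b": 2}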
-
-
-class _Parser(object):
- """Helper to parse expressions within the aggregate pipeline."""
-
- def __init__(self, doc_dict, user_vars=None, ignore_missing_keys=False):
- self._doc_dict = doc_dict
- self._ignore_missing_keys = ignore_missing_keys
- self._user_vars = user_vars or {}
-
- def parse(self, expression):
- """Parse a MongoDB expression."""
- if not isinstance(expression, dict):
- # May raise a KeyError despite the ignore missing key.
- return self._parse_basic_expression(expression)
-
- if len(expression) > 1 and any(key.startswith("$") for key in expression):
- raise OperationFailure(
- "an expression specification must contain exactly one field, "
- "the name of the expression. Found %d fields in %s"
- % (len(expression), expression)
- )
-
- value_dict = {}
- for k, v in expression.items():
- if k in arithmetic_operators:
- return self._handle_arithmetic_operator(k, v)
- if k in project_operators:
- return self._handle_project_operator(k, v)
- if k in projection_operators:
- return self._handle_projection_operator(k, v)
- if k in comparison_operators:
- return self._handle_comparison_operator(k, v)
- if k in date_operators:
- return self._handle_date_operator(k, v)
- if k in array_operators:
- return self._handle_array_operator(k, v)
- if k in conditional_operators:
- return self._handle_conditional_operator(k, v)
- if k in control_flow_operators:
- return self._handle_control_flow_operator(k, v)
- if k in set_operators:
- return self._handle_set_operator(k, v)
- if k in string_operators:
- return self._handle_string_operator(k, v)
- if k in type_convertion_operators:
- return self._handle_type_convertion_operator(k, v)
- if k in type_operators:
- return self._handle_type_operator(k, v)
- if k in boolean_operators:
- return self._handle_boolean_operator(k, v)
- if k in text_search_operators + projection_operators + object_operators:
- raise NotImplementedError(
- "'%s' is a valid operation but it is not supported by Mongomock yet."
- % k
- )
- if k.startswith("$"):
- raise OperationFailure("Unrecognized expression '%s'" % k)
- try:
- value = self.parse(v)
- except KeyError:
- if self._ignore_missing_keys:
- continue
- raise
- value_dict[k] = value
-
- return value_dict
-
- def parse_many(self, values):
- for value in values:
- try:
- yield self.parse(value)
- except KeyError:
- if self._ignore_missing_keys:
- yield None
- else:
- raise
-
- def _parse_to_bool(self, expression):
- """Parse a MongoDB expression and then convert it to bool"""
- # handles converting `undefined` (in form of KeyError) to False
- try:
- return helpers.mongodb_to_bool(self.parse(expression))
- except KeyError:
- return False
-
- def _parse_or_None(self, expression):
- try:
- return self.parse(expression)
- except KeyError:
- return None
-
- def _parse_basic_expression(self, expression):
- if isinstance(expression, str) and expression.startswith("$"):
- if expression.startswith("$$"):
- return helpers.get_value_by_dot(
- dict(
- {
- "ROOT": self._doc_dict,
- "CURRENT": self._doc_dict,
- },
- **self._user_vars,
- ),
- expression[2:],
- can_generate_array=True,
- )
- return helpers.get_value_by_dot(
- self._doc_dict, expression[1:], can_generate_array=True
- )
- return expression
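A small usage sketch of the expression parsing above (assuming `helpers.get_value_by_dot` performs dotted-path lookup, as `_parse_basic_expression` relies on):

p = _Parser({"a": {"b": 2}})
p.parse("$a.b")                  # -> 2, a field path into the document
p.parse({"$add": [1, "$a.b"]})   # -> 3, an operator expression
p.parse("$$ROOT")                # -> the whole input document
p.parse("plain string")          # -> "plain string"; non-"$" literals pass through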
-
- def _handle_boolean_operator(self, operator, values):
- if operator == "$and":
- return all([self._parse_to_bool(value) for value in values])
- if operator == "$or":
- return any(self._parse_to_bool(value) for value in values)
- if operator == "$not":
- return not self._parse_to_bool(values)
- # This should never happen: it is only a safe fallback if something went wrong.
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid boolean operator for the "
- "aggregation pipeline, it is currently not implemented"
- " in Mongomock." % operator
- )
-
- def _handle_arithmetic_operator(self, operator, values):
- if operator in unary_arithmetic_operators:
- try:
- number = self.parse(values)
- except KeyError:
- return None
- if number is None:
- return None
- if not isinstance(number, numbers.Number):
- raise OperationFailure(
- "Parameter to %s must evaluate to a number, got '%s'"
- % (operator, type(number))
- )
-
- if operator == "$abs":
- return abs(number)
- if operator == "$ceil":
- return math.ceil(number)
- if operator == "$exp":
- return math.exp(number)
- if operator == "$floor":
- return math.floor(number)
- if operator == "$ln":
- return math.log(number)
- if operator == "$log10":
- return math.log10(number)
- if operator == "$sqrt":
- return math.sqrt(number)
- if operator == "$trunc":
- return math.trunc(number)
-
- if operator in binary_arithmetic_operators:
- if not isinstance(values, (tuple, list)):
- raise OperationFailure(
- "Parameter to %s must evaluate to a list, got '%s'"
- % (operator, type(values))
- )
-
- if len(values) != 2:
- raise OperationFailure("%s must have only 2 parameters" % operator)
- number_0, number_1 = self.parse_many(values)
- if number_0 is None or number_1 is None:
- return None
-
- if operator == "$divide":
- return number_0 / number_1
- if operator == "$log":
- return math.log(number_0, number_1)
- if operator == "$mod":
- return math.fmod(number_0, number_1)
- if operator == "$pow":
- return math.pow(number_0, number_1)
- if operator == "$subtract":
- if isinstance(number_0, datetime.datetime) and isinstance(
- number_1, (int, float)
- ):
- number_1 = datetime.timedelta(milliseconds=number_1)
- res = number_0 - number_1
- if isinstance(res, datetime.timedelta):
- return round(res.total_seconds() * 1000)
- return res
-
- assert isinstance(values, (tuple, list)), (
- "Parameter to %s must evaluate to a list, got '%s'"
- % (
- operator,
- type(values),
- )
- )
-
- parsed_values = list(self.parse_many(values))
- assert parsed_values, "%s must have at least one parameter" % operator
- for value in parsed_values:
- if value is None:
- return None
- assert isinstance(value, numbers.Number), "%s only uses numbers" % operator
- if operator == "$add":
- return sum(parsed_values)
- if operator == "$multiply":
- return functools.reduce(lambda x, y: x * y, parsed_values)
-
- # This should never happen: it is only a safe fallback if something went wrong.
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid arithmetic operator for the aggregation "
- "pipeline, it is currently not implemented in Mongomock." % operator
- )
-
- def _handle_project_operator(self, operator, values):
- if operator in _GROUPING_OPERATOR_MAP:
- values = (
- self.parse(values)
- if isinstance(values, str)
- else self.parse_many(values)
- )
- return _GROUPING_OPERATOR_MAP[operator](values)
- if operator == "$arrayElemAt":
- key, value = values
- array = self.parse(key)
- index = self.parse(value)
- try:
- return array[index]
- except IndexError as error:
- raise KeyError("Array has fewer elements than the index value") from error
-
- raise NotImplementedError(
- "Although '%s' is a valid project operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
-
- def _handle_projection_operator(self, operator, value):
- if operator == "$literal":
- return value
- if operator == "$let":
- if not isinstance(value, dict):
- raise InvalidDocument("$let only supports an object as its argument")
- for field in ("vars", "in"):
- if field not in value:
- raise OperationFailure(
- "Missing '{}' parameter to $let".format(field)
- )
- if not isinstance(value["vars"], dict):
- raise OperationFailure("invalid parameter: expected an object (vars)")
- user_vars = {
- var_key: self.parse(var_value)
- for var_key, var_value in value["vars"].items()
- }
- return _Parser(
- self._doc_dict,
- dict(self._user_vars, **user_vars),
- ignore_missing_keys=self._ignore_missing_keys,
- ).parse(value["in"])
- raise NotImplementedError(
- "Although '%s' is a valid project operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
-
- def _handle_comparison_operator(self, operator, values):
- assert len(values) == 2, "Comparison requires two expressions"
- a = self.parse(values[0])
- b = self.parse(values[1])
- if operator == "$eq":
- return a == b
- if operator == "$ne":
- return a != b
- if operator in filtering.SORTING_OPERATOR_MAP:
- return filtering.bson_compare(
- filtering.SORTING_OPERATOR_MAP[operator], a, b
- )
- raise NotImplementedError(
- "Although '%s' is a valid comparison operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator
- )
-
- def _handle_string_operator(self, operator, values):
- if operator == "$toLower":
- parsed = self.parse(values)
- return str(parsed).lower() if parsed is not None else ""
- if operator == "$toUpper":
- parsed = self.parse(values)
- return str(parsed).upper() if parsed is not None else ""
- if operator == "$concat":
- parsed_list = list(self.parse_many(values))
- return (
- None if None in parsed_list else "".join([str(x) for x in parsed_list])
- )
- if operator == "$split":
- if len(values) != 2:
- raise OperationFailure("split must have 2 items")
- try:
- string = self.parse(values[0])
- delimiter = self.parse(values[1])
- except KeyError:
- return None
-
- if string is None or delimiter is None:
- return None
- if not isinstance(string, str):
- raise TypeError("split first argument must evaluate to string")
- if not isinstance(delimiter, str):
- raise TypeError("split second argument must evaluate to string")
- return string.split(delimiter)
- if operator == "$substr":
- if len(values) != 3:
- raise OperationFailure("substr must have 3 items")
- string = str(self.parse(values[0]))
- first = self.parse(values[1])
- length = self.parse(values[2])
- if string is None:
- return ""
- if first < 0:
- warnings.warn(
- "Negative starting point given to $substr is accepted only until "
- "MongoDB 3.7. This behavior will change in the future."
- )
- return ""
- if length < 0:
- warnings.warn(
- "Negative length given to $substr is accepted only until "
- "MongoDB 3.7. This behavior will change in the future."
- )
- second = len(string) if length < 0 else first + length
- return string[first:second]
- if operator == "$strcasecmp":
- if len(values) != 2:
- raise OperationFailure("strcasecmp must have 2 items")
- a, b = str(self.parse(values[0])), str(self.parse(values[1]))
- return 0 if a == b else -1 if a < b else 1
- if operator == "$regexMatch":
- if not isinstance(values, dict):
- raise OperationFailure(
- "$regexMatch expects an object of named arguments but found: %s"
- % type(values)
- )
- for field in ("input", "regex"):
- if field not in values:
- raise OperationFailure(
- "$regexMatch requires '%s' parameter" % field
- )
- unknown_args = set(values) - {"input", "regex", "options"}
- if unknown_args:
- raise OperationFailure(
- "$regexMatch found an unknown argument: %s" % list(unknown_args)[0]
- )
-
- try:
- input_value = self.parse(values["input"])
- except KeyError:
- return False
- if not isinstance(input_value, str):
- raise OperationFailure("$regexMatch needs 'input' to be of type string")
-
- try:
- regex_val = self.parse(values["regex"])
- except KeyError:
- return False
- options = None
- for option in values.get("options", ""):
- if option not in "imxs":
- raise OperationFailure(
- "$regexMatch invalid flag in regex options: %s" % option
- )
- re_option = getattr(re, option.upper())
- if options is None:
- options = re_option
- else:
- options |= re_option
- if isinstance(regex_val, str):
- if options is None:
- regex = re.compile(regex_val)
- else:
- regex = re.compile(regex_val, options)
- elif "options" in values and regex_val.flags:
- raise OperationFailure(
- "$regexMatch: regex option(s) specified in both 'regex' and 'option' fields"
- )
- elif isinstance(regex_val, helpers.RE_TYPE):
- if options and not regex_val.flags:
- regex = re.compile(regex_val.pattern, options)
- elif regex_val.flags & ~(re.I | re.M | re.X | re.S):
- raise OperationFailure(
- "$regexMatch invalid flag in regex options: %s"
- % regex_val.flags
- )
- else:
- regex = regex_val
- elif isinstance(regex_val, _RE_TYPES):
- # bson.Regex
- if regex_val.flags & ~(re.I | re.M | re.X | re.S):
- raise OperationFailure(
- "$regexMatch invalid flag in regex options: %s"
- % regex_val.flags
- )
- regex = re.compile(regex_val.pattern, regex_val.flags or options)
- else:
- raise OperationFailure(
- "$regexMatch needs 'regex' to be of type string or regex"
- )
-
- return bool(regex.search(input_value))
-
- # This should never happen: it is only a safe fallback if something went wrong.
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid string operator for the aggregation "
- "pipeline, it is currently not implemented in Mongomock." % operator
- )
-
- def _handle_date_operator(self, operator, values):
- if isinstance(values, dict) and values.keys() == {"date", "timezone"}:
- value = self.parse(values["date"])
- target_tz = pytz.timezone(values["timezone"])
- out_value = value.replace(tzinfo=pytz.utc).astimezone(target_tz)
- else:
- out_value = self.parse(values)
-
- if operator == "$dayOfYear":
- return out_value.timetuple().tm_yday
- if operator == "$dayOfMonth":
- return out_value.day
- if operator == "$dayOfWeek":
- return (out_value.isoweekday() % 7) + 1
- if operator == "$year":
- return out_value.year
- if operator == "$month":
- return out_value.month
- if operator == "$week":
- return int(out_value.strftime("%U"))
- if operator == "$hour":
- return out_value.hour
- if operator == "$minute":
- return out_value.minute
- if operator == "$second":
- return out_value.second
- if operator == "$millisecond":
- return int(out_value.microsecond / 1000)
- if operator == "$dateToString":
-            if not isinstance(values, dict) or not {"format", "date"} <= set(values):
-                raise OperationFailure(
-                    "$dateToString operator must correspond to a dict "
-                    'that has "format" and "date" fields.'
-                )
- if "%L" in out_value["format"]:
- raise NotImplementedError(
- "Although %L is a valid date format for the "
- "$dateToString operator, it is currently not implemented "
-                    "in Mongomock."
- )
- if "onNull" in values:
- raise NotImplementedError(
- "Although onNull is a valid field for the "
- "$dateToString operator, it is currently not implemented "
-                    "in Mongomock."
- )
- if "timezone" in values.keys():
- raise NotImplementedError(
- "Although timezone is a valid field for the "
- "$dateToString operator, it is currently not implemented "
-                    "in Mongomock."
- )
- return out_value["date"].strftime(out_value["format"])
- if operator == "$dateFromParts":
- if not isinstance(out_value, dict):
- raise OperationFailure(
-                    f"{operator} operator must correspond to a dict "
- 'that has "year" or "isoWeekYear" field.'
- )
- if len(set(out_value) & {"year", "isoWeekYear"}) != 1:
- raise OperationFailure(
-                    f"{operator} operator must correspond to a dict "
- 'that has "year" or "isoWeekYear" field.'
- )
- for field in ("isoWeekYear", "isoWeek", "isoDayOfWeek", "timezone"):
- if field in out_value:
- raise NotImplementedError(
- f"Although {field} is a valid field for the "
- f"{operator} operator, it is currently not implemented "
- "in Mongomock."
- )
-
- year = out_value["year"]
- month = out_value.get("month", 1) or 1
- day = out_value.get("day", 1) or 1
- hour = out_value.get("hour", 0) or 0
- minute = out_value.get("minute", 0) or 0
- second = out_value.get("second", 0) or 0
- millisecond = out_value.get("millisecond", 0) or 0
-
- return datetime.datetime(
- year=year,
- month=month,
- day=day,
- hour=hour,
- minute=minute,
- second=second,
- microsecond=millisecond,
- )
-
- raise NotImplementedError(
- "Although '%s' is a valid date operator for the "
- "aggregation pipeline, it is currently not implemented "
-            "in Mongomock." % operator
- )
-
- def _handle_array_operator(self, operator, value):
- if operator == "$concatArrays":
- if not isinstance(value, (list, tuple)):
- value = [value]
-
- parsed_list = list(self.parse_many(value))
- for parsed_item in parsed_list:
- if parsed_item is not None and not isinstance(
- parsed_item, (list, tuple)
- ):
- raise OperationFailure(
- "$concatArrays only supports arrays, not {}".format(
- type(parsed_item)
- )
- )
-
- return (
- None
- if None in parsed_list
- else list(itertools.chain.from_iterable(parsed_list))
- )
-
- if operator == "$map":
- if not isinstance(value, dict):
- raise OperationFailure("$map only supports an object as its argument")
-
- # NOTE: while the two validations below could be achieved with
- # one-liner set operations (e.g. set(value) - {'input', 'as',
- # 'in'}), we prefer the iteration-based approaches in order to
- # mimic MongoDB's behavior regarding the order of evaluation. For
- # example, MongoDB complains about 'input' parameter missing before
- # 'in'.
- for k in ("input", "in"):
- if k not in value:
- raise OperationFailure("Missing '%s' parameter to $map" % k)
-
- for k in value:
- if k not in {"input", "as", "in"}:
- raise OperationFailure("Unrecognized parameter to $map: %s" % k)
-
- input_array = self._parse_or_None(value["input"])
-
-            if input_array is None:
- return None
-
- if not isinstance(input_array, (list, tuple)):
- raise OperationFailure(
- "input to $map must be an array not %s" % type(input_array)
- )
-
- fieldname = value.get("as", "this")
- in_expr = value["in"]
- return [
- _Parser(
- self._doc_dict,
- dict(self._user_vars, **{fieldname: item}),
- ignore_missing_keys=self._ignore_missing_keys,
- ).parse(in_expr)
- for item in input_array
- ]
-
- if operator == "$size":
- if isinstance(value, list):
- if len(value) != 1:
- raise OperationFailure(
- "Expression $size takes exactly 1 arguments. "
- "%d were passed in." % len(value)
- )
- value = value[0]
- array_value = self._parse_or_None(value)
- if not isinstance(array_value, (list, tuple)):
- raise OperationFailure(
- "The argument to $size must be an array, but was of type: %s"
- % ("missing" if array_value is None else type(array_value))
- )
- return len(array_value)
-
- if operator == "$filter":
- if not isinstance(value, dict):
- raise OperationFailure(
- "$filter only supports an object as its argument"
- )
- extra_params = set(value) - {"input", "cond", "as"}
- if extra_params:
- raise OperationFailure(
- "Unrecognized parameter to $filter: %s" % extra_params.pop()
- )
- missing_params = {"input", "cond"} - set(value)
- if missing_params:
- raise OperationFailure(
- "Missing '%s' parameter to $filter" % missing_params.pop()
- )
-
- input_array = self.parse(value["input"])
- fieldname = value.get("as", "this")
- cond = value["cond"]
- return [
- item
- for item in input_array
- if _Parser(
- self._doc_dict,
- dict(self._user_vars, **{fieldname: item}),
- ignore_missing_keys=self._ignore_missing_keys,
- ).parse(cond)
- ]
- if operator == "$slice":
- if not isinstance(value, list):
- raise OperationFailure("$slice only supports a list as its argument")
- if len(value) < 2 or len(value) > 3:
- raise OperationFailure(
- "Expression $slice takes at least 2 arguments, and at most "
- "3, but {} were passed in".format(len(value))
- )
- array_value = self.parse(value[0])
- if not isinstance(array_value, list):
- raise OperationFailure(
- "First argument to $slice must be an array, but is of type: {}".format(
- type(array_value)
- )
- )
- for num, v in zip(("Second", "Third"), value[1:]):
- if not isinstance(v, int):
- raise OperationFailure(
- "{} argument to $slice must be numeric, but is of type: {}".format(
- num, type(v)
- )
- )
- if len(value) > 2 and value[2] <= 0:
- raise OperationFailure(
- "Third argument to $slice must be " "positive: {}".format(value[2])
- )
-
- start = value[1]
- if start < 0:
- if len(value) > 2:
- stop = len(array_value) + start + value[2]
- else:
- stop = None
- elif len(value) > 2:
- stop = start + value[2]
- else:
- stop = start
- start = 0
- return array_value[start:stop]
-
- raise NotImplementedError(
- "Although '%s' is a valid array operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
-
- def _handle_type_convertion_operator(self, operator, values):
- if operator == "$toString":
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
- if isinstance(parsed, bool):
- return str(parsed).lower()
- if isinstance(parsed, datetime.datetime):
- return parsed.isoformat()[:-3] + "Z"
- return str(parsed)
-
- if operator == "$toInt":
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
- if decimal_support:
- if isinstance(parsed, decimal128.Decimal128):
- return int(parsed.to_decimal())
- return int(parsed)
- raise NotImplementedError(
- "You need to import the pymongo library to support decimal128 type."
- )
-
- if operator == "$toLong":
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
- if decimal_support:
- if isinstance(parsed, decimal128.Decimal128):
- return int(parsed.to_decimal())
- return int(parsed)
- raise NotImplementedError(
- "You need to import the pymongo library to support decimal128 type."
- )
-
- # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/toDecimal/
- if operator == "$toDecimal":
- if not decimal_support:
- raise NotImplementedError(
- "You need to import the pymongo library to support decimal128 type."
- )
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
- if isinstance(parsed, bool):
- parsed = "1" if parsed is True else "0"
- decimal_value = decimal128.Decimal128(parsed)
- elif isinstance(parsed, int):
- decimal_value = decimal128.Decimal128(str(parsed))
- elif isinstance(parsed, float):
- exp = decimal.Decimal(".00000000000000")
- decimal_value = decimal.Decimal(str(parsed)).quantize(exp)
- decimal_value = decimal128.Decimal128(decimal_value)
- elif isinstance(parsed, decimal128.Decimal128):
- decimal_value = parsed
- elif isinstance(parsed, str):
- try:
- decimal_value = decimal128.Decimal128(parsed)
- except decimal.InvalidOperation as err:
- raise OperationFailure(
- "Failed to parse number '%s' in $convert with no onError value:"
- "Failed to parse string to decimal" % parsed
- ) from err
- elif isinstance(parsed, datetime.datetime):
- epoch = datetime.datetime.utcfromtimestamp(0)
- string_micro_seconds = str(
- (parsed - epoch).total_seconds() * 1000
- ).split(".", 1)[0]
- decimal_value = decimal128.Decimal128(string_micro_seconds)
- else:
- raise TypeError("'%s' type is not supported" % type(parsed))
- return decimal_value
-
- # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/arrayToObject/
- if operator == "$arrayToObject":
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
-
- if parsed is None:
- return None
-
- if not isinstance(parsed, (list, tuple)):
- raise OperationFailure(
- "$arrayToObject requires an array input, found: {}".format(
- type(parsed)
- )
- )
-
- if all(isinstance(x, dict) and set(x.keys()) == {"k", "v"} for x in parsed):
- return {d["k"]: d["v"] for d in parsed}
-
- if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in parsed):
- return dict(parsed)
-
- raise OperationFailure(
- "arrays used with $arrayToObject must contain documents "
- "with k and v fields or two-element arrays"
- )
-
- # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/objectToArray/
- if operator == "$objectToArray":
- try:
- parsed = self.parse(values)
- except KeyError:
- return None
-
- if parsed is None:
- return None
-
- if not isinstance(parsed, (dict, collections.OrderedDict)):
- raise OperationFailure(
- "$objectToArray requires an object input, found: {}".format(
- type(parsed)
- )
- )
-
- if len(parsed) > 1 and sys.version_info < (3, 6):
- raise NotImplementedError(
- "Although '%s' is a valid type conversion, it is not implemented for Python 2 "
- "and Python 3.5 in Mongomock yet." % operator
- )
-
- return [{"k": k, "v": v} for k, v in parsed.items()]
-
- raise NotImplementedError(
- "Although '%s' is a valid type conversion operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
-
- def _handle_type_operator(self, operator, values):
- # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isNumber/
- if operator == "$isNumber":
- try:
- parsed = self.parse(values)
- except KeyError:
- return False
- return (
- False
- if isinstance(parsed, bool)
- else isinstance(parsed, numbers.Number)
- )
-
- # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isArray/
- if operator == "$isArray":
- try:
- parsed = self.parse(values)
- except KeyError:
- return False
- return isinstance(parsed, (tuple, list))
-
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid type operator for the aggregation pipeline, it is currently "
- "not implemented in Mongomock." % operator
- )
-
- def _handle_conditional_operator(self, operator, values):
- # relative
- from . import SERVER_VERSION
-
- if operator == "$ifNull":
- fields = values[:-1]
- if len(fields) > 1 and version.parse(SERVER_VERSION) <= version.parse(
- "4.4"
- ):
- raise OperationFailure(
- "$ifNull supports only one input expression "
-                    "in MongoDB v4.4 and lower"
- )
- fallback = values[-1]
- for field in fields:
- try:
- out_value = self.parse(field)
- if out_value is not None:
- return out_value
- except KeyError:
- pass
- return self.parse(fallback)
- if operator == "$cond":
- if isinstance(values, list):
- condition, true_case, false_case = values
- elif isinstance(values, dict):
- condition = values["if"]
- true_case = values["then"]
- false_case = values["else"]
- condition_value = self._parse_to_bool(condition)
- expression = true_case if condition_value else false_case
- return self.parse(expression)
- # This should never happen: it is only a safe fallback if something went wrong.
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid conditional operator for the "
- "aggregation pipeline, it is currently not implemented "
-            "in Mongomock." % operator
- )
-
- def _handle_control_flow_operator(self, operator, values):
- if operator == "$switch":
- if not isinstance(values, dict):
- raise OperationFailure(
- "$switch requires an object as an argument, "
- "found: %s" % type(values)
- )
-
- branches = values.get("branches", [])
- if not isinstance(branches, (list, tuple)):
- raise OperationFailure(
- "$switch expected an array for 'branches', "
- "found: %s" % type(branches)
- )
- if not branches:
- raise OperationFailure("$switch requires at least one branch.")
-
- for branch in branches:
- if not isinstance(branch, dict):
- raise OperationFailure(
- "$switch expected each branch to be an object, "
- "found: %s" % type(branch)
- )
- if "case" not in branch:
- raise OperationFailure(
- "$switch requires each branch have a 'case' expression"
- )
- if "then" not in branch:
- raise OperationFailure(
- "$switch requires each branch have a 'then' expression."
- )
-
- for branch in branches:
- if self._parse_to_bool(branch["case"]):
- return self.parse(branch["then"])
-
- if "default" not in values:
- raise OperationFailure(
- "$switch could not find a matching branch for an input, "
- "and no default was specified."
- )
- return self.parse(values["default"])
-
- # This should never happen: it is only a safe fallback if something went wrong.
- raise NotImplementedError( # pragma: no cover
- "Although '%s' is a valid control flow operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
-
- def _handle_set_operator(self, operator, values):
- if operator == "$in":
- expression, array = values
- return self.parse(expression) in self.parse(array)
- if operator == "$setUnion":
- result = []
- for set_value in values:
- for value in self.parse(set_value):
- if value not in result:
- result.append(value)
- return result
- if operator == "$setEquals":
- set_values = [set(self.parse(value)) for value in values]
- for set1, set2 in itertools.combinations(set_values, 2):
- if set1 != set2:
- return False
- return True
- raise NotImplementedError(
- "Although '%s' is a valid set operator for the aggregation "
- "pipeline, it is currently not implemented in Mongomock." % operator
- )
-
-
-def _parse_expression(expression, doc_dict, ignore_missing_keys=False):
- """Parse an expression.
-
- Args:
- expression: an Aggregate Expression, see
- https://docs.mongodb.com/manual/meta/aggregation-quick-reference/#aggregation-expressions.
- doc_dict: the document on which to evaluate the expression.
-        ignore_missing_keys: if True, missing keys evaluated by the expression are ignored
-            silently when possible.
- """
- return _Parser(doc_dict, ignore_missing_keys=ignore_missing_keys).parse(expression)
-
-
-filtering.register_parse_expression(_parse_expression)
-
-
-def _accumulate_group(output_fields, group_list):
- doc_dict = {}
- for field, value in output_fields.items():
- if field == "_id":
- continue
- for operator, key in value.items():
- values = []
- for doc in group_list:
- try:
- values.append(_parse_expression(key, doc))
- except KeyError:
- continue
- if operator in _GROUPING_OPERATOR_MAP:
- doc_dict[field] = _GROUPING_OPERATOR_MAP[operator](values)
- elif operator == "$addToSet":
- value = []
- val_it = (val or None for val in values)
-                # Don't use set in case elt is not hashable (like dicts).
- for elt in val_it:
- if elt not in value:
- value.append(elt)
- doc_dict[field] = value
- elif operator == "$push":
- if field not in doc_dict:
- doc_dict[field] = values
- else:
- doc_dict[field].extend(values)
- elif operator in group_operators:
- raise NotImplementedError(
- "Although %s is a valid group operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator
- )
- else:
- raise NotImplementedError(
- "%s is not a valid group operator for the aggregation "
- "pipeline. See http://docs.mongodb.org/manual/meta/"
- "aggregation-quick-reference/ for a complete list of "
- "valid operators." % operator
- )
- return doc_dict
-
-
-def _fix_sort_key(key_getter):
- def fixed_getter(doc):
- key = key_getter(doc)
- # Convert dictionaries to make sorted() work in Python 3.
- if isinstance(key, dict):
- return [(k, v) for (k, v) in sorted(key.items())]
- return key
-
- return fixed_getter
-
-
-def _handle_lookup_stage(in_collection, database, options):
- for operator in ("let", "pipeline"):
- if operator in options:
- raise NotImplementedError(
- "Although '%s' is a valid lookup operator for the "
- "aggregation pipeline, it is currently not "
- "implemented in Mongomock." % operator
- )
- for operator in ("from", "localField", "foreignField", "as"):
- if operator not in options:
- raise OperationFailure("Must specify '%s' field for a $lookup" % operator)
- if not isinstance(options[operator], str):
- raise OperationFailure("Arguments to $lookup must be strings")
- if operator in ("as", "localField", "foreignField") and options[
- operator
- ].startswith("$"):
- raise OperationFailure("FieldPath field names may not start with '$'")
- if operator == "as" and "." in options[operator]:
- raise NotImplementedError(
- "Although '.' is valid in the 'as' "
- "parameters for the lookup stage of the aggregation "
- "pipeline, it is currently not implemented in Mongomock."
- )
-
- foreign_name = options["from"]
- local_field = options["localField"]
- foreign_field = options["foreignField"]
- local_name = options["as"]
- foreign_collection = database.get_collection(foreign_name)
- for doc in in_collection:
- try:
- query = helpers.get_value_by_dot(doc, local_field)
- except KeyError:
- query = None
- if isinstance(query, list):
- query = {"$in": query}
- matches = foreign_collection.find({foreign_field: query})
- doc[local_name] = [foreign_doc for foreign_doc in matches]
-
- return in_collection
-
-
-def _recursive_get(match, nested_fields):
- head = match.get(nested_fields[0])
- remaining_fields = nested_fields[1:]
- if not remaining_fields:
- # Final/last field reached.
- yield head
- return
- # More fields to go, must be list, tuple, or dict.
- if isinstance(head, (list, tuple)):
- for m in head:
- # Yield from _recursive_get(m, remaining_fields).
- for answer in _recursive_get(m, remaining_fields):
- yield answer
- elif isinstance(head, dict):
- # Yield from _recursive_get(head, remaining_fields).
- for answer in _recursive_get(head, remaining_fields):
- yield answer
-
-
-def _handle_graph_lookup_stage(in_collection, database, options):
- if not isinstance(options.get("maxDepth", 0), int):
- raise OperationFailure("Argument 'maxDepth' to $graphLookup must be a number")
- if not isinstance(options.get("restrictSearchWithMatch", {}), dict):
- raise OperationFailure(
- "Argument 'restrictSearchWithMatch' to $graphLookup must be a Dictionary"
- )
- if not isinstance(options.get("depthField", ""), str):
-        raise OperationFailure("Argument 'depthField' to $graphLookup must be a string")
- if "startWith" not in options:
- raise OperationFailure("Must specify 'startWith' field for a $graphLookup")
- for operator in ("as", "connectFromField", "connectToField", "from"):
- if operator not in options:
- raise OperationFailure(
- "Must specify '%s' field for a $graphLookup" % operator
- )
- if not isinstance(options[operator], str):
- raise OperationFailure(
- "Argument '%s' to $graphLookup must be string" % operator
- )
- if options[operator].startswith("$"):
- raise OperationFailure("FieldPath field names may not start with '$'")
- if operator == "as" and "." in options[operator]:
- raise NotImplementedError(
- "Although '.' is valid in the '%s' "
- "parameter for the $graphLookup stage of the aggregation "
- "pipeline, it is currently not implemented in Mongomock." % operator
- )
-
- foreign_name = options["from"]
- start_with = options["startWith"]
- connect_from_field = options["connectFromField"]
- connect_to_field = options["connectToField"]
- local_name = options["as"]
- max_depth = options.get("maxDepth", None)
- depth_field = options.get("depthField", None)
- restrict_search_with_match = options.get("restrictSearchWithMatch", {})
- foreign_collection = database.get_collection(foreign_name)
-    out_doc = copy.deepcopy(in_collection)  # TODO(pascal): speed up the deep copy
-
- def _find_matches_for_depth(query):
- if isinstance(query, list):
- query = {"$in": query}
- matches = foreign_collection.find({connect_to_field: query})
- new_matches = []
- for new_match in matches:
- if (
- filtering.filter_applies(restrict_search_with_match, new_match)
- and new_match["_id"] not in found_items
- ):
- if depth_field is not None:
- new_match = collections.OrderedDict(
- new_match, **{depth_field: depth}
- )
- new_matches.append(new_match)
- found_items.add(new_match["_id"])
- return new_matches
-
- for doc in out_doc:
- found_items = set()
- depth = 0
- try:
- result = _parse_expression(start_with, doc)
- except KeyError:
- continue
- origin_matches = doc[local_name] = _find_matches_for_depth(result)
- while origin_matches and (max_depth is None or depth < max_depth):
- depth += 1
- newly_discovered_matches = []
- for match in origin_matches:
- nested_fields = connect_from_field.split(".")
- for match_target in _recursive_get(match, nested_fields):
- newly_discovered_matches += _find_matches_for_depth(match_target)
- doc[local_name] += newly_discovered_matches
- origin_matches = newly_discovered_matches
- return out_doc
-
-
-def _handle_group_stage(in_collection, unused_database, options):
- grouped_collection = []
- _id = options["_id"]
- if _id:
-
- def _key_getter(doc):
- try:
- return _parse_expression(_id, doc, ignore_missing_keys=True)
- except KeyError:
- return None
-
- def _sort_key_getter(doc):
- return filtering.BsonComparable(_key_getter(doc))
-
- # Sort the collection only for the itertools.groupby.
-        # $group does not order its output documents.
- sorted_collection = sorted(in_collection, key=_sort_key_getter)
- grouped = itertools.groupby(sorted_collection, _key_getter)
- else:
- grouped = [(None, in_collection)]
-
- for doc_id, group in grouped:
- group_list = [x for x in group]
- doc_dict = _accumulate_group(options, group_list)
- doc_dict["_id"] = doc_id
- grouped_collection.append(doc_dict)
-
- return grouped_collection
-
-
-def _handle_bucket_stage(in_collection, unused_database, options):
- unknown_options = set(options) - {"groupBy", "boundaries", "output", "default"}
- if unknown_options:
- raise OperationFailure(
- "Unrecognized option to $bucket: %s." % unknown_options.pop()
- )
- if "groupBy" not in options or "boundaries" not in options:
- raise OperationFailure(
- "$bucket requires 'groupBy' and 'boundaries' to be specified."
- )
- group_by = options["groupBy"]
- boundaries = options["boundaries"]
- if not isinstance(boundaries, list):
- raise OperationFailure(
- "The $bucket 'boundaries' field must be an array, but found type: %s"
- % type(boundaries)
- )
- if len(boundaries) < 2:
- raise OperationFailure(
- "The $bucket 'boundaries' field must have at least 2 values, but "
- "found %d value(s)." % len(boundaries)
- )
- if sorted(boundaries) != boundaries:
- raise OperationFailure(
- "The 'boundaries' option to $bucket must be sorted in ascending order"
- )
- output_fields = options.get("output", {"count": {"$sum": 1}})
- default_value = options.get("default", None)
- try:
- is_default_last = default_value >= boundaries[-1]
- except TypeError:
- is_default_last = True
-
- def _get_default_bucket():
- try:
- return options["default"]
- except KeyError as err:
- raise OperationFailure(
- "$bucket could not find a matching branch for "
- "an input, and no default was specified."
- ) from err
-
- def _get_bucket_id(doc):
- """Get the bucket ID for a document.
-
- Note that it actually returns a tuple with the first
- param being a sort key to sort the default bucket even
- if it's not the same type as the boundaries.
- """
- try:
- value = _parse_expression(group_by, doc)
- except KeyError:
- return (is_default_last, _get_default_bucket())
- index = bisect.bisect_right(boundaries, value)
- if index and index < len(boundaries):
- return (False, boundaries[index - 1])
- return (is_default_last, _get_default_bucket())
-
- in_collection = ((_get_bucket_id(doc), doc) for doc in in_collection)
- out_collection = sorted(in_collection, key=lambda kv: kv[0])
- grouped = itertools.groupby(out_collection, lambda kv: kv[0])
-
- out_collection = []
- for (unused_key, doc_id), group in grouped:
- group_list = [kv[1] for kv in group]
- doc_dict = _accumulate_group(output_fields, group_list)
- doc_dict["_id"] = doc_id
- out_collection.append(doc_dict)
- return out_collection
-
-
-def _handle_sample_stage(in_collection, unused_database, options):
- if not isinstance(options, dict):
- raise OperationFailure("the $sample stage specification must be an object")
- size = options.pop("size", None)
- if size is None:
- raise OperationFailure("$sample stage must specify a size")
- if options:
- raise OperationFailure(
- "unrecognized option to $sample: %s" % set(options).pop()
- )
- shuffled = list(in_collection)
- _random.shuffle(shuffled)
- return shuffled[:size]
-
-
-def _handle_sort_stage(in_collection, unused_database, options):
- sort_array = reversed([{x: y} for x, y in options.items()])
- sorted_collection = in_collection
- for sort_pair in sort_array:
- for sortKey, sortDirection in sort_pair.items():
- sorted_collection = sorted(
- sorted_collection,
- key=lambda x: filtering.resolve_sort_key(sortKey, x),
- reverse=sortDirection < 0,
- )
- return sorted_collection
-
-
-def _handle_unwind_stage(in_collection, unused_database, options):
- if not isinstance(options, dict):
- options = {"path": options}
- path = options["path"]
- if not isinstance(path, str) or path[0] != "$":
- raise ValueError(
- "$unwind failed: exception: field path references must be prefixed "
- "with a '$' '%s'" % path
- )
- path = path[1:]
- should_preserve_null_and_empty = options.get("preserveNullAndEmptyArrays")
- include_array_index = options.get("includeArrayIndex")
- unwound_collection = []
- for doc in in_collection:
- try:
- array_value = helpers.get_value_by_dot(doc, path)
- except KeyError:
- if should_preserve_null_and_empty:
- unwound_collection.append(doc)
- continue
- if array_value is None:
- if should_preserve_null_and_empty:
- unwound_collection.append(doc)
- continue
- if array_value == []:
- if should_preserve_null_and_empty:
- new_doc = copy.deepcopy(doc)
- # We just ran a get_value_by_dot so we know the value exists.
- helpers.delete_value_by_dot(new_doc, path)
- unwound_collection.append(new_doc)
- continue
- if isinstance(array_value, list):
- iter_array = enumerate(array_value)
- else:
- iter_array = [(None, array_value)]
- for index, field_item in iter_array:
- new_doc = copy.deepcopy(doc)
- new_doc = helpers.set_value_by_dot(new_doc, path, field_item)
- if include_array_index:
- new_doc = helpers.set_value_by_dot(new_doc, include_array_index, index)
- unwound_collection.append(new_doc)
-
- return unwound_collection
-
-
-# TODO(pascal): Combine with the equivalent function in collection, but check
-# which overrides are allowed.
-def _combine_projection_spec(filter_list, original_filter, prefix=""):
- """Re-format a projection fields spec into a nested dictionary.
-
- e.g: ['a', 'b.c', 'b.d'] => {'a': 1, 'b': {'c': 1, 'd': 1}}
- """
- if not isinstance(filter_list, list):
- return filter_list
-
- filter_dict = collections.OrderedDict()
-
- for key in filter_list:
- field, separator, subkey = key.partition(".")
- if not separator:
- if isinstance(filter_dict.get(field), list):
- other_key = field + "." + filter_dict[field][0]
- raise OperationFailure(
- "Invalid $project :: caused by :: specification contains two conflicting paths."
- " Cannot specify both %s and %s: %s"
- % (repr(prefix + field), repr(prefix + other_key), original_filter)
- )
- filter_dict[field] = 1
- continue
- if not isinstance(filter_dict.get(field, []), list):
- raise OperationFailure(
- "Invalid $project :: caused by :: specification contains two conflicting paths."
- " Cannot specify both %s and %s: %s"
- % (repr(prefix + field), repr(prefix + key), original_filter)
- )
- filter_dict[field] = filter_dict.get(field, []) + [subkey]
-
- return collections.OrderedDict(
- (k, _combine_projection_spec(v, original_filter, prefix="%s%s." % (prefix, k)))
- for k, v in filter_dict.items()
- )
-
-
-def _project_by_spec(doc, proj_spec, is_include):
- output = {}
- for key, value in doc.items():
- if key not in proj_spec:
- if not is_include:
- output[key] = value
- continue
-
- if not isinstance(proj_spec[key], dict):
- if is_include:
- output[key] = value
- continue
-
- if isinstance(value, dict):
- output[key] = _project_by_spec(value, proj_spec[key], is_include)
- elif isinstance(value, list):
- output[key] = [
- _project_by_spec(array_value, proj_spec[key], is_include)
- for array_value in value
- if isinstance(array_value, dict)
- ]
- elif not is_include:
- output[key] = value
-
- return output
-
-
-def _handle_replace_root_stage(in_collection, unused_database, options):
- if "newRoot" not in options:
- raise OperationFailure(
- "Parameter 'newRoot' is missing for $replaceRoot operation."
- )
- new_root = options["newRoot"]
- out_collection = []
- for doc in in_collection:
- try:
- new_doc = _parse_expression(new_root, doc, ignore_missing_keys=True)
- except KeyError:
- new_doc = None
- if not isinstance(new_doc, dict):
- raise OperationFailure(
- "'newRoot' expression must evaluate to an object, but resulting value was: {}".format(
- new_doc
- )
- )
- out_collection.append(new_doc)
- return out_collection
-
-
-def _handle_project_stage(in_collection, unused_database, options):
- filter_list = []
- method = None
- include_id = options.get("_id")
- # Compute new values for each field, except inclusion/exclusions that are
- # handled in one final step.
- new_fields_collection = None
- for field, value in options.items():
- if method is None and (field != "_id" or value):
- method = "include" if value else "exclude"
- elif method == "include" and not value and field != "_id":
- raise OperationFailure(
- "Bad projection specification, cannot exclude fields "
- "other than '_id' in an inclusion projection: %s" % options
- )
- elif method == "exclude" and value:
- raise OperationFailure(
- "Bad projection specification, cannot include fields "
- "or add computed fields during an exclusion projection: %s" % options
- )
- if value in (0, 1, True, False):
- if field != "_id":
- filter_list.append(field)
- continue
- if not new_fields_collection:
- new_fields_collection = [{} for unused_doc in in_collection]
-
- for in_doc, out_doc in zip(in_collection, new_fields_collection):
- try:
- out_doc[field] = _parse_expression(
- value, in_doc, ignore_missing_keys=True
- )
- except KeyError:
- # Ignore missing key.
- pass
- if (method == "include") == (include_id is not False and include_id != 0):
- filter_list.append("_id")
-
- if not filter_list:
- return new_fields_collection
-
- # Final steps: include or exclude fields and merge with newly created fields.
- projection_spec = _combine_projection_spec(filter_list, original_filter=options)
- out_collection = [
- _project_by_spec(doc, projection_spec, is_include=(method == "include"))
- for doc in in_collection
- ]
- if new_fields_collection:
- return [dict(a, **b) for a, b in zip(out_collection, new_fields_collection)]
- return out_collection
-
-
-def _handle_add_fields_stage(in_collection, unused_database, options):
- if not options:
- raise OperationFailure(
- "Invalid $addFields :: caused by :: specification must have at least one field"
- )
- out_collection = [dict(doc) for doc in in_collection]
- for field, value in options.items():
- for in_doc, out_doc in zip(in_collection, out_collection):
- try:
- out_value = _parse_expression(value, in_doc, ignore_missing_keys=True)
- except KeyError:
- continue
- parts = field.split(".")
- for subfield in parts[:-1]:
- out_doc[subfield] = out_doc.get(subfield, {})
- if not isinstance(out_doc[subfield], dict):
- out_doc[subfield] = {}
- out_doc = out_doc[subfield]
- out_doc[parts[-1]] = out_value
- return out_collection
-
-
-def _handle_out_stage(in_collection, database, options):
- # TODO(MetrodataTeam): should leave the origin collection unchanged
- out_collection = database.get_collection(options)
- if out_collection.find_one():
- out_collection.drop()
- if in_collection:
- out_collection.insert_many(in_collection)
- return in_collection
-
-
-def _handle_count_stage(in_collection, database, options):
- if not isinstance(options, str) or options == "":
- raise OperationFailure("the count field must be a non-empty string")
- elif options.startswith("$"):
- raise OperationFailure("the count field cannot be a $-prefixed path")
- elif "." in options:
- raise OperationFailure("the count field cannot contain '.'")
- return [{options: len(in_collection)}]
-
-
-def _handle_facet_stage(in_collection, database, options):
- out_collection_by_pipeline = {}
- for pipeline_title, pipeline in options.items():
- out_collection_by_pipeline[pipeline_title] = list(
- process_pipeline(in_collection, database, pipeline, None)
- )
- return [out_collection_by_pipeline]
-
-
-def _handle_match_stage(in_collection, database, options):
- spec = helpers.patch_datetime_awareness_in_document(options)
- return [
- doc
- for doc in in_collection
- if filtering.filter_applies(
- spec, helpers.patch_datetime_awareness_in_document(doc)
- )
- ]
-
-
-_PIPELINE_HANDLERS = {
- "$addFields": _handle_add_fields_stage,
- "$bucket": _handle_bucket_stage,
- "$bucketAuto": None,
- "$collStats": None,
- "$count": _handle_count_stage,
- "$currentOp": None,
- "$facet": _handle_facet_stage,
- "$geoNear": None,
- "$graphLookup": _handle_graph_lookup_stage,
- "$group": _handle_group_stage,
- "$indexStats": None,
- "$limit": lambda c, d, o: c[:o],
- "$listLocalSessions": None,
- "$listSessions": None,
- "$lookup": _handle_lookup_stage,
- "$match": _handle_match_stage,
- "$merge": None,
- "$out": _handle_out_stage,
- "$planCacheStats": None,
- "$project": _handle_project_stage,
- "$redact": None,
- "$replaceRoot": _handle_replace_root_stage,
- "$replaceWith": None,
- "$sample": _handle_sample_stage,
- "$set": _handle_add_fields_stage,
- "$skip": lambda c, d, o: c[o:],
- "$sort": _handle_sort_stage,
- "$sortByCount": None,
- "$unset": None,
- "$unwind": _handle_unwind_stage,
-}
-
-
-def process_pipeline(collection, database, pipeline, session):
- if session:
- raise NotImplementedError("Mongomock does not handle sessions yet")
-
- for stage in pipeline:
- for operator, options in stage.items():
- try:
- handler = _PIPELINE_HANDLERS[operator]
- except KeyError as err:
- raise NotImplementedError(
- "%s is not a valid operator for the aggregation pipeline. "
- "See http://docs.mongodb.org/manual/meta/aggregation-quick-reference/ "
- "for a complete list of valid operators." % operator
- ) from err
- if not handler:
- raise NotImplementedError(
- "Although '%s' is a valid operator for the aggregation pipeline, it is "
- "currently not implemented in Mongomock." % operator
- )
- collection = handler(collection, database, options)
-
- return command_cursor.CommandCursor(collection)
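The aggregate module deleted above implements the aggregation pipeline in pure Python: process_pipeline resolves each stage name through _PIPELINE_HANDLERS and threads the document list through the matching handler. For orientation, a minimal sketch of the behaviour this module provided, using only mongomock's public API (the database and field names are illustrative, not taken from the codebase):

    import mongomock

    client = mongomock.MongoClient()
    collection = client["test"]["orders"]
    collection.insert_many([{"item": "a", "qty": 2}, {"item": "b", "qty": 5}])

    # $match and $group are dispatched to _handle_match_stage and _handle_group_stage above.
    result = list(
        collection.aggregate(
            [
                {"$match": {"qty": {"$gte": 2}}},
                {"$group": {"_id": None, "total": {"$sum": "$qty"}}},
            ]
        )
    )
    # result == [{"_id": None, "total": 7}]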
diff --git a/packages/syft/tests/mongomock/codec_options.py b/packages/syft/tests/mongomock/codec_options.py
deleted file mode 100644
index e71eb41d672..00000000000
--- a/packages/syft/tests/mongomock/codec_options.py
+++ /dev/null
@@ -1,135 +0,0 @@
-"""Tools for specifying BSON codec options."""
-
-# stdlib
-import collections
-
-# third party
-from packaging import version
-
-# relative
-from . import helpers
-
-try:
- # third party
- from bson import codec_options
- from pymongo.common import _UUID_REPRESENTATIONS
-except ImportError:
- codec_options = None
- _UUID_REPRESENTATIONS = None
-
-
-class TypeRegistry(object):
- pass
-
-
-_FIELDS = (
- "document_class",
- "tz_aware",
- "uuid_representation",
- "unicode_decode_error_handler",
- "tzinfo",
-)
-
-if codec_options and helpers.PYMONGO_VERSION >= version.parse("3.8"):
- _DEFAULT_TYPE_REGISTRY = codec_options.TypeRegistry()
- _FIELDS = _FIELDS + ("type_registry",)
-else:
- _DEFAULT_TYPE_REGISTRY = TypeRegistry()
-
-if codec_options and helpers.PYMONGO_VERSION >= version.parse("4.3.0"):
- _DATETIME_CONVERSION_VALUES = codec_options.DatetimeConversion._value2member_map_
- _DATETIME_CONVERSION_DEFAULT_VALUE = codec_options.DatetimeConversion.DATETIME
- _FIELDS = _FIELDS + ("datetime_conversion",)
-else:
- _DATETIME_CONVERSION_VALUES = ()
- _DATETIME_CONVERSION_DEFAULT_VALUE = None
-
-# New default in Pymongo v4:
-# https://pymongo.readthedocs.io/en/stable/examples/uuid.html#unspecified
-if helpers.PYMONGO_VERSION >= version.parse("4.0"):
- _DEFAULT_UUID_REPRESENTATION = 0
-else:
- _DEFAULT_UUID_REPRESENTATION = 3
-
-
-class CodecOptions(collections.namedtuple("CodecOptions", _FIELDS)):
- def __new__(
- cls,
- document_class=dict,
- tz_aware=False,
- uuid_representation=None,
- unicode_decode_error_handler="strict",
- tzinfo=None,
- type_registry=None,
- datetime_conversion=_DATETIME_CONVERSION_DEFAULT_VALUE,
- ):
- if document_class != dict:
- raise NotImplementedError(
- "Mongomock does not implement custom document_class yet: %r"
- % document_class
- )
-
- if not isinstance(tz_aware, bool):
- raise TypeError("tz_aware must be True or False")
-
- if uuid_representation is None:
- uuid_representation = _DEFAULT_UUID_REPRESENTATION
-
- if unicode_decode_error_handler not in ("strict", None):
- raise NotImplementedError(
- "Mongomock does not handle custom unicode_decode_error_handler yet"
- )
-
- if tzinfo:
- raise NotImplementedError("Mongomock does not handle custom tzinfo yet")
-
- values = (
- document_class,
- tz_aware,
- uuid_representation,
- unicode_decode_error_handler,
- tzinfo,
- )
-
- if "type_registry" in _FIELDS:
- if not type_registry:
- type_registry = _DEFAULT_TYPE_REGISTRY
- values = values + (type_registry,)
-
- if "datetime_conversion" in _FIELDS:
- if (
- datetime_conversion
- and datetime_conversion not in _DATETIME_CONVERSION_VALUES
- ):
- raise TypeError(
- "datetime_conversion must be member of DatetimeConversion"
- )
- values = values + (datetime_conversion,)
-
- return tuple.__new__(cls, values)
-
- def with_options(self, **kwargs):
- opts = self._asdict()
- opts.update(kwargs)
- return CodecOptions(**opts)
-
- def to_pymongo(self):
- if not codec_options:
- return None
-
- uuid_representation = self.uuid_representation
- if _UUID_REPRESENTATIONS and isinstance(self.uuid_representation, str):
- uuid_representation = _UUID_REPRESENTATIONS[uuid_representation]
-
- return codec_options.CodecOptions(
- uuid_representation=uuid_representation,
- unicode_decode_error_handler=self.unicode_decode_error_handler,
- type_registry=self.type_registry,
- )
-
-
-def is_supported(custom_codec_options):
- if not custom_codec_options:
- return None
-
- return CodecOptions(**custom_codec_options._asdict())
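The codec_options module deleted above mirrors pymongo's CodecOptions as a namedtuple whose field set grows with the detected pymongo version, and with_options rebuilds the tuple from updated keyword arguments. A small sketch of that behaviour against upstream mongomock (assuming mongomock is installed; attribute names follow the code above):

    from mongomock import codec_options as mongomock_codec_options

    opts = mongomock_codec_options.CodecOptions()  # defaults: document_class=dict, tz_aware=False
    tz_opts = opts.with_options(tz_aware=True)     # namedtuple rebuilt via _asdict() + update
    assert tz_opts.tz_aware is True
    assert opts.tz_aware is False                  # the original options object is unchanged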
diff --git a/packages/syft/tests/mongomock/collection.py b/packages/syft/tests/mongomock/collection.py
deleted file mode 100644
index 8a677300355..00000000000
--- a/packages/syft/tests/mongomock/collection.py
+++ /dev/null
@@ -1,2596 +0,0 @@
-# future
-from __future__ import division
-
-# stdlib
-import collections
-from collections import OrderedDict
-from collections.abc import Iterable
-from collections.abc import Mapping
-from collections.abc import MutableMapping
-import copy
-import functools
-import itertools
-import json
-import math
-import time
-import warnings
-
-# third party
-from packaging import version
-
-try:
- # third party
- from bson import BSON
- from bson import SON
- from bson import json_util
- from bson.codec_options import CodecOptions
- from bson.errors import InvalidDocument
-except ImportError:
-    json_util = SON = BSON = None
- CodecOptions = None
-try:
- # third party
- import execjs
-except ImportError:
- execjs = None
-
-try:
- # third party
- from pymongo import ReadPreference
- from pymongo import ReturnDocument
- from pymongo.operations import IndexModel
-
- _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY
-except ImportError:
-
- class IndexModel(object):
- pass
-
- class ReturnDocument(object):
- BEFORE = False
- AFTER = True
-
- from .read_preferences import PRIMARY as _READ_PREFERENCE_PRIMARY
-
-# relative
-from . import BulkWriteError
-from . import ConfigurationError
-from . import DuplicateKeyError
-from . import InvalidOperation
-from . import ObjectId
-from . import OperationFailure
-from . import WriteError
-from . import aggregate
-from . import codec_options as mongomock_codec_options
-from . import filtering
-from . import helpers
-from . import utcnow
-from .filtering import filter_applies
-from .not_implemented import raise_for_feature as raise_not_implemented
-from .results import BulkWriteResult
-from .results import DeleteResult
-from .results import InsertManyResult
-from .results import InsertOneResult
-from .results import UpdateResult
-from .write_concern import WriteConcern
-
-try:
- # third party
- from pymongo.read_concern import ReadConcern
-except ImportError:
- # relative
- from .read_concern import ReadConcern
-
-_KwargOption = collections.namedtuple("KwargOption", ["typename", "default", "attrs"])
-
-_WITH_OPTIONS_KWARGS = {
- "read_preference": _KwargOption(
- "pymongo.read_preference.ReadPreference",
- _READ_PREFERENCE_PRIMARY,
- ("document", "mode", "mongos_mode", "max_staleness"),
- ),
- "write_concern": _KwargOption(
- "pymongo.write_concern.WriteConcern",
- WriteConcern(),
- ("acknowledged", "document"),
- ),
-}
-
-
-def _bson_encode(document, codec_options):
- if CodecOptions:
- if isinstance(codec_options, mongomock_codec_options.CodecOptions):
- codec_options = codec_options.to_pymongo()
- if isinstance(codec_options, CodecOptions):
- BSON.encode(document, check_keys=True, codec_options=codec_options)
- else:
- BSON.encode(document, check_keys=True)
-
-
-def validate_is_mapping(option, value):
- if not isinstance(value, Mapping):
- raise TypeError(
- "%s must be an instance of dict, bson.son.SON, or "
- "other type that inherits from "
- "collections.Mapping" % (option,)
- )
-
-
-def validate_is_mutable_mapping(option, value):
- if not isinstance(value, MutableMapping):
- raise TypeError(
- "%s must be an instance of dict, bson.son.SON, or "
- "other type that inherits from "
- "collections.MutableMapping" % (option,)
- )
-
-
-def validate_ok_for_replace(replacement):
- validate_is_mapping("replacement", replacement)
- if replacement:
- first = next(iter(replacement))
- if first.startswith("$"):
- raise ValueError("replacement can not include $ operators")
-
-
-def validate_ok_for_update(update):
- validate_is_mapping("update", update)
- if not update:
- raise ValueError("update only works with $ operators")
- first = next(iter(update))
- if not first.startswith("$"):
- raise ValueError("update only works with $ operators")
-
-
-def validate_write_concern_params(**params):
- if params:
- WriteConcern(**params)
-
-
-class BulkWriteOperation(object):
- def __init__(self, builder, selector, is_upsert=False):
- self.builder = builder
- self.selector = selector
- self.is_upsert = is_upsert
-
- def upsert(self):
- assert not self.is_upsert
- return BulkWriteOperation(self.builder, self.selector, is_upsert=True)
-
- def register_remove_op(self, multi, hint=None):
- collection = self.builder.collection
- selector = self.selector
-
- def exec_remove():
- if multi:
- op_result = collection.delete_many(selector, hint=hint).raw_result
- else:
- op_result = collection.delete_one(selector, hint=hint).raw_result
- if op_result.get("ok"):
- return {"nRemoved": op_result.get("n")}
- err = op_result.get("err")
- if err:
- return {"writeErrors": [err]}
- return {}
-
- self.builder.executors.append(exec_remove)
-
- def remove(self):
- assert not self.is_upsert
- self.register_remove_op(multi=True)
-
- def remove_one(
- self,
- ):
- assert not self.is_upsert
- self.register_remove_op(multi=False)
-
- def register_update_op(self, document, multi, **extra_args):
- if not extra_args.get("remove"):
- validate_ok_for_update(document)
-
- collection = self.builder.collection
- selector = self.selector
-
- def exec_update():
- result = collection._update(
- spec=selector,
- document=document,
- multi=multi,
- upsert=self.is_upsert,
- **extra_args,
- )
- ret_val = {}
- if result.get("upserted"):
- ret_val["upserted"] = result.get("upserted")
- ret_val["nUpserted"] = result.get("n")
- else:
- matched = result.get("n")
- if matched is not None:
- ret_val["nMatched"] = matched
- modified = result.get("nModified")
- if modified is not None:
- ret_val["nModified"] = modified
- if result.get("err"):
- ret_val["err"] = result.get("err")
- return ret_val
-
- self.builder.executors.append(exec_update)
-
- def update(self, document, hint=None):
- self.register_update_op(document, multi=True, hint=hint)
-
- def update_one(self, document, hint=None):
- self.register_update_op(document, multi=False, hint=hint)
-
- def replace_one(self, document, hint=None):
- self.register_update_op(document, multi=False, remove=True, hint=hint)
-
-
-def _combine_projection_spec(projection_fields_spec):
- """Re-format a projection fields spec into a nested dictionary.
-
- e.g: {'a': 1, 'b.c': 1, 'b.d': 1} => {'a': 1, 'b': {'c': 1, 'd': 1}}
- """
-
- tmp_spec = OrderedDict()
- for f, v in projection_fields_spec.items():
- if "." not in f:
- if isinstance(tmp_spec.get(f), dict):
- if not v:
- raise NotImplementedError(
- "Mongomock does not support overriding excluding projection: %s"
- % projection_fields_spec
- )
- raise OperationFailure("Path collision at %s" % f)
- tmp_spec[f] = v
- else:
- split_field = f.split(".", 1)
- base_field, new_field = tuple(split_field)
- if not isinstance(tmp_spec.get(base_field), dict):
- if base_field in tmp_spec:
- raise OperationFailure(
- "Path collision at %s remaining portion %s" % (f, new_field)
- )
- tmp_spec[base_field] = OrderedDict()
- tmp_spec[base_field][new_field] = v
-
- combined_spec = OrderedDict()
- for f, v in tmp_spec.items():
- if isinstance(v, dict):
- combined_spec[f] = _combine_projection_spec(v)
- else:
- combined_spec[f] = v
-
- return combined_spec
-
-
-def _project_by_spec(doc, combined_projection_spec, is_include, container):
- if "$" in combined_projection_spec:
- if is_include:
- raise NotImplementedError(
- "Positional projection is not implemented in mongomock"
- )
- raise OperationFailure(
- "Cannot exclude array elements with the positional operator"
- )
-
- doc_copy = container()
-
- for key, val in doc.items():
- spec = combined_projection_spec.get(key, None)
- if isinstance(spec, dict):
- if isinstance(val, (list, tuple)):
- doc_copy[key] = [
- _project_by_spec(sub_doc, spec, is_include, container)
- for sub_doc in val
- ]
- elif isinstance(val, dict):
- doc_copy[key] = _project_by_spec(val, spec, is_include, container)
- elif (is_include and spec is not None) or (not is_include and spec is None):
- doc_copy[key] = _copy_field(val, container)
-
- return doc_copy
-
-
-def _copy_field(obj, container):
- if isinstance(obj, list):
- new = []
- for item in obj:
- new.append(_copy_field(item, container))
- return new
- if isinstance(obj, dict):
- new = container()
- for key, value in obj.items():
- new[key] = _copy_field(value, container)
- return new
- return copy.copy(obj)
-
-
-def _recursive_key_check_null_character(data):
- for key, value in data.items():
- if "\0" in key:
- raise InvalidDocument(
- f"Field names cannot contain the null character (found: {key})"
- )
- if isinstance(value, Mapping):
- _recursive_key_check_null_character(value)
-
-
-def _validate_data_fields(data):
- _recursive_key_check_null_character(data)
- for key in data.keys():
- if key.startswith("$"):
- raise InvalidDocument(
- f'Top-level field names cannot start with the "$" sign '
- f"(found: {key})"
- )
-
-
-class BulkOperationBuilder(object):
- def __init__(self, collection, ordered=False, bypass_document_validation=False):
- self.collection = collection
- self.ordered = ordered
- self.results = {}
- self.executors = []
- self.done = False
- self._insert_returns_nModified = True
- self._update_returns_nModified = True
- self._bypass_document_validation = bypass_document_validation
-
- def find(self, selector):
- return BulkWriteOperation(self, selector)
-
- def insert(self, doc):
- def exec_insert():
- self.collection.insert_one(
- doc, bypass_document_validation=self._bypass_document_validation
- )
- return {"nInserted": 1}
-
- self.executors.append(exec_insert)
-
- def __aggregate_operation_result(self, total_result, key, value):
- agg_val = total_result.get(key)
- assert agg_val is not None, (
-            "Unknown operation result %s=%s (unrecognized key)" % (key, value)
- )
- if isinstance(agg_val, int):
- total_result[key] += value
- elif isinstance(agg_val, list):
- if key == "upserted":
- new_element = {"index": len(agg_val), "_id": value}
- agg_val.append(new_element)
- else:
- agg_val.append(value)
- else:
- assert False, (
-                "Fixme: missing aggregation rule for type: %s for"
- " key {%s=%s}"
- % (
- type(agg_val),
- key,
- agg_val,
- )
- )
-
- def _set_nModified_policy(self, insert, update):
- self._insert_returns_nModified = insert
- self._update_returns_nModified = update
-
- def execute(self, write_concern=None):
- if not self.executors:
- raise InvalidOperation("Bulk operation empty!")
- if self.done:
- raise InvalidOperation("Bulk operation already executed!")
- self.done = True
- result = {
- "nModified": 0,
- "nUpserted": 0,
- "nMatched": 0,
- "writeErrors": [],
- "upserted": [],
- "writeConcernErrors": [],
- "nRemoved": 0,
- "nInserted": 0,
- }
-
- has_update = False
- has_insert = False
- broken_nModified_info = False
- for index, execute_func in enumerate(self.executors):
- exec_name = execute_func.__name__
- try:
- op_result = execute_func()
- except WriteError as error:
- result["writeErrors"].append(
- {
- "index": index,
- "code": error.code,
- "errmsg": str(error),
- }
- )
- if self.ordered:
- break
- continue
- for key, value in op_result.items():
- self.__aggregate_operation_result(result, key, value)
- if exec_name == "exec_update":
- has_update = True
- if "nModified" not in op_result:
- broken_nModified_info = True
- has_insert |= exec_name == "exec_insert"
-
- if broken_nModified_info:
- result.pop("nModified")
- elif has_insert and self._insert_returns_nModified:
- pass
- elif has_update and self._update_returns_nModified:
- pass
- elif self._update_returns_nModified and self._insert_returns_nModified:
- pass
- else:
- result.pop("nModified")
-
- if result.get("writeErrors"):
- raise BulkWriteError(result)
-
- return result
-
- def add_insert(self, doc):
- self.insert(doc)
-
- def add_update(
- self,
- selector,
- doc,
- multi=False,
- upsert=False,
- collation=None,
- array_filters=None,
- hint=None,
- ):
- if array_filters:
- raise_not_implemented(
- "array_filters", "Array filters are not implemented in mongomock yet."
- )
- write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
- write_operation.register_update_op(doc, multi, hint=hint)
-
- def add_replace(self, selector, doc, upsert, collation=None, hint=None):
- write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
- write_operation.replace_one(doc, hint=hint)
-
- def add_delete(self, selector, just_one, collation=None, hint=None):
- write_operation = BulkWriteOperation(self, selector, is_upsert=False)
- write_operation.register_remove_op(not just_one, hint=hint)
-
-
-class Collection(object):
- def __init__(
- self,
- database,
- name,
- _db_store,
- write_concern=None,
- read_concern=None,
- read_preference=None,
- codec_options=None,
- ):
- self.database = database
- self._name = name
- self._db_store = _db_store
- self._write_concern = write_concern or WriteConcern()
- if read_concern and not isinstance(read_concern, ReadConcern):
- raise TypeError(
- "read_concern must be an instance of pymongo.read_concern.ReadConcern"
- )
- self._read_concern = read_concern or ReadConcern()
- self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY
- self._codec_options = codec_options or mongomock_codec_options.CodecOptions()
-
- def __repr__(self):
- return "Collection({0}, '{1}')".format(self.database, self.name)
-
- def __getitem__(self, name):
- return self.database[self.name + "." + name]
-
- def __getattr__(self, attr):
- if attr.startswith("_"):
- raise AttributeError(
- "%s has no attribute '%s'. To access the %s.%s collection, use database['%s.%s']."
- % (self.__class__.__name__, attr, self.name, attr, self.name, attr)
- )
- return self.__getitem__(attr)
-
- def __call__(self, *args, **kwargs):
- name = self._name if "." not in self._name else self._name.split(".")[-1]
- raise TypeError(
- "'Collection' object is not callable. If you meant to call the '%s' method on a "
- "'Collection' object it is failing because no such method exists." % name
- )
-
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- return self.database == other.database and self.name == other.name
- return NotImplemented
-
- if helpers.PYMONGO_VERSION >= version.parse("3.12"):
-
- def __hash__(self):
- return hash((self.database, self.name))
-
- @property
- def full_name(self):
- return "{0}.{1}".format(self.database.name, self._name)
-
- @property
- def name(self):
- return self._name
-
- @property
- def write_concern(self):
- return self._write_concern
-
- @property
- def read_concern(self):
- return self._read_concern
-
- @property
- def read_preference(self):
- return self._read_preference
-
- @property
- def codec_options(self):
- return self._codec_options
-
- def initialize_unordered_bulk_op(self, bypass_document_validation=False):
- return BulkOperationBuilder(
- self, ordered=False, bypass_document_validation=bypass_document_validation
- )
-
- def initialize_ordered_bulk_op(self, bypass_document_validation=False):
- return BulkOperationBuilder(
- self, ordered=True, bypass_document_validation=bypass_document_validation
- )
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def insert(
- self,
- data,
- manipulate=True,
- check_keys=True,
- continue_on_error=False,
- **kwargs,
- ):
- warnings.warn(
- "insert is deprecated. Use insert_one or insert_many " "instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- validate_write_concern_params(**kwargs)
- return self._insert(data)
-
- def insert_one(self, document, bypass_document_validation=False, session=None):
- if not bypass_document_validation:
- validate_is_mutable_mapping("document", document)
- return InsertOneResult(self._insert(document, session), acknowledged=True)
-
- def insert_many(
- self, documents, ordered=True, bypass_document_validation=False, session=None
- ):
- if not isinstance(documents, Iterable) or not documents:
- raise TypeError("documents must be a non-empty list")
- documents = list(documents)
- if not bypass_document_validation:
- for document in documents:
- validate_is_mutable_mapping("document", document)
- return InsertManyResult(
- self._insert(documents, session, ordered=ordered), acknowledged=True
- )
-
- @property
- def _store(self):
- return self._db_store[self._name]
-
- def _insert(self, data, session=None, ordered=True):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- if not isinstance(data, Mapping):
- results = []
- write_errors = []
- num_inserted = 0
- for index, item in enumerate(data):
- try:
- results.append(self._insert(item))
- except WriteError as error:
- write_errors.append(
- {
- "index": index,
- "code": error.code,
- "errmsg": str(error),
- "op": item,
- }
- )
- if ordered:
- break
- else:
- continue
- num_inserted += 1
- if write_errors:
- raise BulkWriteError(
- {
- "writeErrors": write_errors,
- "nInserted": num_inserted,
- }
- )
- return results
-
- if not all(isinstance(k, str) for k in data):
- raise ValueError("Document keys must be strings")
-
- if BSON:
- # bson validation
- check_keys = helpers.PYMONGO_VERSION < version.parse("3.6")
- if not check_keys:
- _validate_data_fields(data)
-
- _bson_encode(data, self._codec_options)
-
- # Like pymongo, we should fill the _id in the inserted dict (odd behavior,
- # but we need to stick to it), so we must patch in-place the data dict
- if "_id" not in data:
- data["_id"] = ObjectId()
-
- object_id = data["_id"]
- if isinstance(object_id, dict):
- object_id = helpers.hashdict(object_id)
- if object_id in self._store:
- raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
-
- data = helpers.patch_datetime_awareness_in_document(data)
-
- self._store[object_id] = data
- try:
- self._ensure_uniques(data)
- except DuplicateKeyError:
- # Rollback
- del self._store[object_id]
- raise
- return data["_id"]
-
- def _ensure_uniques(self, new_data):
-        # Note: we assume new_data has already been inserted into the db
- for index in self._store.indexes.values():
- if not index.get("unique"):
- continue
- unique = index.get("key")
- is_sparse = index.get("sparse")
- partial_filter_expression = index.get("partialFilterExpression")
- find_kwargs = {}
- for key, _ in unique:
- try:
- find_kwargs[key] = helpers.get_value_by_dot(new_data, key)
- except KeyError:
- find_kwargs[key] = None
- if is_sparse and set(find_kwargs.values()) == {None}:
- continue
- if partial_filter_expression is not None:
- find_kwargs = {"$and": [partial_filter_expression, find_kwargs]}
- answer_count = len(list(self._iter_documents(find_kwargs)))
- if answer_count > 1:
- raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
-
- def _internalize_dict(self, d):
- return {k: copy.deepcopy(v) for k, v in d.items()}
-
- def _has_key(self, doc, key):
- key_parts = key.split(".")
- sub_doc = doc
- for part in key_parts:
- if part not in sub_doc:
- return False
- sub_doc = sub_doc[part]
- return True
-
- def update_one(
- self,
- filter,
- update,
- upsert=False,
- bypass_document_validation=False,
- collation=None,
- array_filters=None,
- hint=None,
- session=None,
- let=None,
- ):
- if not bypass_document_validation:
- validate_ok_for_update(update)
- return UpdateResult(
- self._update(
- filter,
- update,
- upsert=upsert,
- hint=hint,
- session=session,
- collation=collation,
- array_filters=array_filters,
- let=let,
- ),
- acknowledged=True,
- )
-
- def update_many(
- self,
- filter,
- update,
- upsert=False,
- array_filters=None,
- bypass_document_validation=False,
- collation=None,
- hint=None,
- session=None,
- let=None,
- ):
- if not bypass_document_validation:
- validate_ok_for_update(update)
- return UpdateResult(
- self._update(
- filter,
- update,
- upsert=upsert,
- multi=True,
- hint=hint,
- session=session,
- collation=collation,
- array_filters=array_filters,
- let=let,
- ),
- acknowledged=True,
- )
-
- def replace_one(
- self,
- filter,
- replacement,
- upsert=False,
- bypass_document_validation=False,
- session=None,
- hint=None,
- ):
- if not bypass_document_validation:
- validate_ok_for_replace(replacement)
- return UpdateResult(
- self._update(
- filter, replacement, upsert=upsert, hint=hint, session=session
- ),
- acknowledged=True,
- )
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def update(
- self,
- spec,
- document,
- upsert=False,
- manipulate=False,
- multi=False,
- check_keys=False,
- **kwargs,
- ):
- warnings.warn(
- "update is deprecated. Use replace_one, update_one or "
- "update_many instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return self._update(
- spec, document, upsert, manipulate, multi, check_keys, **kwargs
- )
-
- def _update(
- self,
- spec,
- document,
- upsert=False,
- manipulate=False,
- multi=False,
- check_keys=False,
- hint=None,
- session=None,
- collation=None,
- let=None,
- array_filters=None,
- **kwargs,
- ):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- if hint:
- raise NotImplementedError(
- "The hint argument of update is valid but has not been implemented in "
- "mongomock yet"
- )
- if collation:
- raise_not_implemented(
- "collation",
- "The collation argument of update is valid but has not been implemented in "
- "mongomock yet",
- )
- if array_filters:
- raise_not_implemented(
- "array_filters", "Array filters are not implemented in mongomock yet."
- )
- if let:
- raise_not_implemented(
- "let",
- "The let argument of update is valid but has not been implemented in mongomock "
- "yet",
- )
- spec = helpers.patch_datetime_awareness_in_document(spec)
- document = helpers.patch_datetime_awareness_in_document(document)
- validate_is_mapping("spec", spec)
- validate_is_mapping("document", document)
-
- if self.database.client.server_info()["versionArray"] < [5]:
- for operator in _updaters:
- if not document.get(operator, True):
- raise WriteError(
-                        "'%s' is empty. You must specify a field like so: {%s: {<field>: ...}}"
- % (operator, operator),
- )
-
- updated_existing = False
- upserted_id = None
- num_updated = 0
- num_matched = 0
- for existing_document in itertools.chain(self._iter_documents(spec), [None]):
- # we need was_insert for the setOnInsert update operation
- was_insert = False
- # the sentinel document means we should do an upsert
- if existing_document is None:
- if not upsert or num_matched:
- continue
-                # For an upsert we first have to create a fake existing_document,
-                # update it like a regular one, and finally insert it
- if spec.get("_id") is not None:
- _id = spec["_id"]
- elif document.get("_id") is not None:
- _id = document["_id"]
- else:
- _id = ObjectId()
- to_insert = dict(spec, _id=_id)
- to_insert = self._expand_dots(to_insert)
- to_insert, _ = self._discard_operators(to_insert)
- existing_document = to_insert
- was_insert = True
- else:
- original_document_snapshot = copy.deepcopy(existing_document)
- updated_existing = True
- num_matched += 1
- first = True
- subdocument = None
- for k, v in document.items():
- if k in _updaters:
- updater = _updaters[k]
- subdocument = (
- self._update_document_fields_with_positional_awareness(
- existing_document, v, spec, updater, subdocument
- )
- )
-
- elif k == "$rename":
- for src, dst in v.items():
- if "." in src or "." in dst:
- raise NotImplementedError(
- "Using the $rename operator with dots is a valid MongoDB "
- "operation, but it is not yet supported by mongomock"
- )
- if self._has_key(existing_document, src):
- existing_document[dst] = existing_document.pop(src)
-
- elif k == "$setOnInsert":
- if not was_insert:
- continue
- subdocument = (
- self._update_document_fields_with_positional_awareness(
- existing_document, v, spec, _set_updater, subdocument
- )
- )
-
- elif k == "$currentDate":
- subdocument = (
- self._update_document_fields_with_positional_awareness(
- existing_document,
- v,
- spec,
- _current_date_updater,
- subdocument,
- )
- )
-
- elif k == "$addToSet":
- for field, value in v.items():
- nested_field_list = field.rsplit(".")
- if len(nested_field_list) == 1:
- if field not in existing_document:
- existing_document[field] = []
-                            # the field should be a list; append to it
- if isinstance(value, dict):
- if "$each" in value:
- # append the list to the field
- existing_document[field] += [
- obj
- for obj in list(value["$each"])
- if obj not in existing_document[field]
- ]
- continue
- if value not in existing_document[field]:
- existing_document[field].append(value)
- continue
- # push to array in a nested attribute
- else:
- # create nested attributes if they do not exist
- subdocument = existing_document
- for field_part in nested_field_list[:-1]:
- if field_part == "$":
- break
- if field_part not in subdocument:
- subdocument[field_part] = {}
-
- subdocument = subdocument[field_part]
-
-                        # get subdocument with $ operator support
- subdocument, _ = self._get_subdocument(
- existing_document, spec, nested_field_list
- )
-
- # we're pushing a list
- push_results = []
- if nested_field_list[-1] in subdocument:
- # if the list exists, then use that list
- push_results = subdocument[nested_field_list[-1]]
-
- if isinstance(value, dict) and "$each" in value:
- push_results += [
- obj
- for obj in list(value["$each"])
- if obj not in push_results
- ]
- elif value not in push_results:
- push_results.append(value)
-
- subdocument[nested_field_list[-1]] = push_results
- elif k == "$pull":
- for field, value in v.items():
- nested_field_list = field.rsplit(".")
-                    # nested fields include a positional element
-                    # need to find that element
- if "$" in nested_field_list:
- if not subdocument:
- subdocument, _ = self._get_subdocument(
- existing_document, spec, nested_field_list
- )
-
- # value should be a dictionary since we're pulling
- pull_results = []
- # and the last subdoc should be an array
- for obj in subdocument[nested_field_list[-1]]:
- if isinstance(obj, dict):
- for pull_key, pull_value in value.items():
- if obj[pull_key] != pull_value:
- pull_results.append(obj)
- continue
- if obj != value:
- pull_results.append(obj)
-
- # cannot write to doc directly as it doesn't save to
- # existing_document
- subdocument[nested_field_list[-1]] = pull_results
- else:
- arr = existing_document
- for field_part in nested_field_list:
- if field_part not in arr:
- break
- arr = arr[field_part]
- if not isinstance(arr, list):
- continue
-
- arr_copy = copy.deepcopy(arr)
- if isinstance(value, dict):
- for obj in arr_copy:
- try:
- is_matching = filter_applies(value, obj)
- except OperationFailure:
- is_matching = False
- if is_matching:
- arr.remove(obj)
- continue
-
- if filter_applies({"field": value}, {"field": obj}):
- arr.remove(obj)
- else:
- for obj in arr_copy:
- if value == obj:
- arr.remove(obj)
- elif k == "$pullAll":
- for field, value in v.items():
- nested_field_list = field.rsplit(".")
- if len(nested_field_list) == 1:
- if field in existing_document:
- arr = existing_document[field]
- existing_document[field] = [
- obj for obj in arr if obj not in value
- ]
- continue
- else:
- subdocument, _ = self._get_subdocument(
- existing_document, spec, nested_field_list
- )
-
- if nested_field_list[-1] in subdocument:
- arr = subdocument[nested_field_list[-1]]
- subdocument[nested_field_list[-1]] = [
- obj for obj in arr if obj not in value
- ]
- elif k == "$push":
- for field, value in v.items():
- # Find the place where to push.
- nested_field_list = field.rsplit(".")
- subdocument, field = self._get_subdocument(
- existing_document, spec, nested_field_list
- )
-
- # Push the new element or elements.
- if isinstance(subdocument, dict) and field not in subdocument:
- subdocument[field] = []
- push_results = subdocument[field]
- if isinstance(value, dict) and "$each" in value:
- if "$position" in value:
- push_results = (
- push_results[0 : value["$position"]]
- + list(value["$each"])
- + push_results[value["$position"] :]
- )
- else:
- push_results += list(value["$each"])
-
- if "$sort" in value:
- sort_spec = value["$sort"]
- if isinstance(sort_spec, dict):
- sort_key = set(sort_spec.keys()).pop()
- push_results = sorted(
- push_results,
- key=lambda d: helpers.get_value_by_dot(
- d, sort_key
- ),
- reverse=set(sort_spec.values()).pop() < 0,
- )
- else:
- push_results = sorted(
- push_results, reverse=sort_spec < 0
- )
-
- if "$slice" in value:
- slice_value = value["$slice"]
- if slice_value < 0:
- push_results = push_results[slice_value:]
- elif slice_value == 0:
- push_results = []
- else:
- push_results = push_results[:slice_value]
-
- unused_modifiers = set(value.keys()) - {
- "$each",
- "$slice",
- "$position",
- "$sort",
- }
- if unused_modifiers:
- raise WriteError(
- "Unrecognized clause in $push: "
- + unused_modifiers.pop()
- )
- else:
- push_results.append(value)
- subdocument[field] = push_results
- else:
- if first:
- # replace entire document
- for key in document.keys():
- if key.startswith("$"):
- # can't mix modifiers with non-modifiers in
- # update
- raise ValueError(
-                                "field names cannot start with $ [{}]".format(key)
- )
- _id = spec.get("_id", existing_document.get("_id"))
- existing_document.clear()
- if _id is not None:
- existing_document["_id"] = _id
- if BSON:
- # bson validation
- check_keys = helpers.PYMONGO_VERSION < version.parse("3.6")
- if not check_keys:
- _validate_data_fields(document)
- _bson_encode(document, self.codec_options)
- existing_document.update(self._internalize_dict(document))
- if existing_document["_id"] != _id:
- raise OperationFailure(
- "The _id field cannot be changed from {0} to {1}".format(
- existing_document["_id"], _id
- )
- )
- break
- else:
- # can't mix modifiers with non-modifiers in update
- raise ValueError("Invalid modifier specified: {}".format(k))
- first = False
-            # if an empty replacement document was provided
- if not document:
- _id = spec.get("_id", existing_document.get("_id"))
- existing_document.clear()
- if _id:
- existing_document["_id"] = _id
-
- if was_insert:
- upserted_id = self._insert(existing_document)
- num_updated += 1
- elif existing_document != original_document_snapshot:
- # Document has been modified in-place.
-
-                # Make sure the ID was not changed.
- if original_document_snapshot.get("_id") != existing_document.get(
- "_id"
- ):
- # Rollback.
- self._store[original_document_snapshot["_id"]] = (
- original_document_snapshot
- )
- raise WriteError(
- "After applying the update, the (immutable) field '_id' was found to have "
- "been altered to _id: {}".format(existing_document.get("_id"))
- )
-
-                # Make sure it still respects the unique indexes and, if not,
-                # revert the modifications
- try:
- self._ensure_uniques(existing_document)
- num_updated += 1
- except DuplicateKeyError:
- # Rollback.
- self._store[original_document_snapshot["_id"]] = (
- original_document_snapshot
- )
- raise
-
- if not multi:
- break
-
- return {
- "connectionId": self.database.client._id,
- "err": None,
- "n": num_matched,
- "nModified": num_updated if updated_existing else 0,
- "ok": 1,
- "upserted": upserted_id,
- "updatedExisting": updated_existing,
- }
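
The raw dict returned here is what backs `UpdateResult`, so `matched_count`, `modified_count` and `upserted_id` map onto the `n`, `nModified` and `upserted` keys. A hedged sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.items
coll.insert_one({"_id": 1, "qty": 5})

res = coll.update_one({"_id": 1}, {"$set": {"qty": 10}})
print(res.matched_count, res.modified_count)  # 1 1

# An upsert reports no match but exposes the new _id taken from the filter.
res = coll.update_one({"_id": 2}, {"$set": {"qty": 3}}, upsert=True)
print(res.matched_count, res.upserted_id)  # 0 2
```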
-
- def _get_subdocument(self, existing_document, spec, nested_field_list):
- """This method retrieves the subdocument of the existing_document.nested_field_list.
-
- It uses the spec to filter through the items. It will continue to grab nested documents
- until it can go no further. It will then return the subdocument that was last saved.
- '$' is the positional operator, so we use the $elemMatch in the spec to find the right
- subdocument in the array.
- """
- # Current document in view.
- doc = existing_document
- # Previous document in view.
- parent_doc = existing_document
- # Current spec in view.
- subspec = spec
- # Whether spec is following the document.
- is_following_spec = True
- # Walk down the dictionary.
- for index, subfield in enumerate(nested_field_list):
- if subfield == "$":
- if not is_following_spec:
- raise WriteError(
- "The positional operator did not find the match needed from the query"
- )
- # Positional element should have the equivalent elemMatch in the query.
- subspec = subspec["$elemMatch"]
- is_following_spec = False
- # Iterate through.
- for spec_index, item in enumerate(doc):
- if filter_applies(subspec, item):
- subfield = spec_index
- break
- else:
- raise WriteError(
- "The positional operator did not find the match needed from the query"
- )
-
- parent_doc = doc
- if isinstance(parent_doc, list):
- subfield = int(subfield)
- if is_following_spec and (subfield < 0 or subfield >= len(subspec)):
- is_following_spec = False
-
- if index == len(nested_field_list) - 1:
- return parent_doc, subfield
-
- if not isinstance(parent_doc, list):
- if subfield not in parent_doc:
- parent_doc[subfield] = {}
- if is_following_spec and subfield not in subspec:
- is_following_spec = False
-
- doc = parent_doc[subfield]
- if is_following_spec:
- subspec = subspec[subfield]
-
- def _expand_dots(self, doc):
- expanded = {}
- paths = {}
- for k, v in doc.items():
-
- def _raise_incompatible(subkey):
- raise WriteError(
- "cannot infer query fields to set, both paths '%s' and '%s' are matched"
- % (k, paths[subkey])
- )
-
- if k in paths:
- _raise_incompatible(k)
-
- key_parts = k.split(".")
- sub_expanded = expanded
-
- paths[k] = k
- for i, key_part in enumerate(key_parts[:-1]):
- if key_part not in sub_expanded:
- sub_expanded[key_part] = {}
- sub_expanded = sub_expanded[key_part]
- key = ".".join(key_parts[: i + 1])
- if not isinstance(sub_expanded, dict):
- _raise_incompatible(key)
- paths[key] = k
- sub_expanded[key_parts[-1]] = v
- return expanded
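
`_expand_dots` is what lets an upsert build a nested document out of dotted filter keys. A small sketch of the observable behaviour, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.config
coll.update_one({"settings.theme": "dark"}, {"$set": {"active": True}}, upsert=True)
print(coll.find_one({}, {"_id": 0}))
# {'settings': {'theme': 'dark'}, 'active': True}
```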
-
- def _discard_operators(self, doc):
- if not doc or not isinstance(doc, dict):
- return doc, False
- new_doc = OrderedDict()
- for k, v in doc.items():
- if k == "$eq":
- return v, False
- if k.startswith("$"):
- continue
- new_v, discarded = self._discard_operators(v)
- if not discarded:
- new_doc[k] = new_v
- return new_doc, not bool(new_doc)
-
- def find(
- self,
- filter=None,
- projection=None,
- skip=0,
- limit=0,
- no_cursor_timeout=False,
- cursor_type=None,
- sort=None,
- allow_partial_results=False,
- oplog_replay=False,
- modifiers=None,
- batch_size=0,
- manipulate=True,
- collation=None,
- session=None,
- max_time_ms=None,
- allow_disk_use=False,
- **kwargs,
- ):
- spec = filter
- if spec is None:
- spec = {}
- validate_is_mapping("filter", spec)
- for kwarg, value in kwargs.items():
- if value:
- raise OperationFailure("Unrecognized field '%s'" % kwarg)
- return (
- Cursor(self, spec, sort, projection, skip, limit, collation=collation)
- .max_time_ms(max_time_ms)
- .allow_disk_use(allow_disk_use)
- )
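
`find` builds a lazy `Cursor`; sort, skip and limit are applied when it is iterated. A minimal sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.events
coll.insert_many([{"level": i} for i in range(5)])

docs = list(coll.find({"level": {"$gte": 2}}, sort=[("level", -1)], limit=2))
print([d["level"] for d in docs])  # [4, 3]
```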
-
- def _get_dataset(self, spec, sort, fields, as_class):
- dataset = self._iter_documents(spec)
- if sort:
- for sort_key, sort_direction in reversed(sort):
- if sort_key == "$natural":
- if sort_direction < 0:
- dataset = iter(reversed(list(dataset)))
- continue
- if sort_key.startswith("$"):
- raise NotImplementedError(
- "Sorting by {} is not implemented in mongomock yet".format(
- sort_key
- )
- )
- dataset = iter(
- sorted(
- dataset,
- key=lambda x: filtering.resolve_sort_key(sort_key, x),
- reverse=sort_direction < 0,
- )
- )
- for document in dataset:
- yield self._copy_only_fields(document, fields, as_class)
-
- def _extract_projection_operators(self, fields):
- """Removes and returns fields with projection operators."""
- result = {}
- allowed_projection_operators = {"$elemMatch", "$slice"}
- for key, value in fields.items():
- if isinstance(value, dict):
- for op in value:
- if op not in allowed_projection_operators:
- raise ValueError("Unsupported projection option: {}".format(op))
- result[key] = value
-
- for key in result:
- del fields[key]
-
- return result
-
- def _apply_projection_operators(self, ops, doc, doc_copy):
- """Applies projection operators to copied document."""
- for field, op in ops.items():
- if field not in doc_copy:
- if field in doc:
- # field was not copied yet (since we are in include mode)
- doc_copy[field] = doc[field]
- else:
- # field doesn't exist in original document, no work to do
- continue
-
- if "$slice" in op:
- if not isinstance(doc_copy[field], list):
- raise OperationFailure(
- "Unsupported type {} for slicing operation: {}".format(
- type(doc_copy[field]), op
- )
- )
- op_value = op["$slice"]
- slice_ = None
- if isinstance(op_value, list):
- if len(op_value) != 2:
- raise OperationFailure(
- "Unsupported slice format {} for slicing operation: {}".format(
- op_value, op
- )
- )
- skip, limit = op_value
- if skip < 0:
- skip = len(doc_copy[field]) + skip
- last = min(skip + limit, len(doc_copy[field]))
- slice_ = slice(skip, last)
- elif isinstance(op_value, int):
- count = op_value
- start = 0
- end = len(doc_copy[field])
- if count < 0:
- start = max(0, len(doc_copy[field]) + count)
- else:
- end = min(count, len(doc_copy[field]))
- slice_ = slice(start, end)
-
- if slice_:
- doc_copy[field] = doc_copy[field][slice_]
- else:
- raise OperationFailure(
- "Unsupported slice value {} for slicing operation: {}".format(
- op_value, op
- )
- )
-
- if "$elemMatch" in op:
- if isinstance(doc_copy[field], list):
- # find the first item that matches
- matched = False
- for item in doc_copy[field]:
- if filter_applies(op["$elemMatch"], item):
- matched = True
- doc_copy[field] = [item]
- break
-
- # None have matched
- if not matched:
- del doc_copy[field]
-
- else:
-                # remove the field since there is nothing to iterate over
- del doc_copy[field]
-
- def _copy_only_fields(self, doc, fields, container):
- """Copy only the specified fields."""
-
- # https://pymongo.readthedocs.io/en/stable/migrate-to-pymongo4.html#collection-find-returns-entire-document-with-empty-projection
- if (
- fields is None
- or not fields
- and helpers.PYMONGO_VERSION >= version.parse("4.0")
- ):
- return _copy_field(doc, container)
-
- if not fields:
- fields = {"_id": 1}
- if not isinstance(fields, dict):
- fields = helpers.fields_list_to_dict(fields)
-
- # we can pass in something like {'_id':0, 'field':1}, so pull the id
- # value out and hang on to it until later
- id_value = fields.pop("_id", 1)
-
- # filter out fields with projection operators, we will take care of them later
- projection_operators = self._extract_projection_operators(fields)
-
-        # other than the _id field, all fields must be either inclusions or
-        # exclusions; this can evaluate to 0
- if len(set(list(fields.values()))) > 1:
- raise ValueError("You cannot currently mix including and excluding fields.")
-
-        # if we have no values passed in, make a doc_copy based on the
- # id_value
- if not fields:
- if id_value == 1:
- doc_copy = container()
- else:
- doc_copy = _copy_field(doc, container)
- else:
- doc_copy = _project_by_spec(
- doc,
- _combine_projection_spec(fields),
- is_include=list(fields.values())[0],
- container=container,
- )
-
- # set the _id value if we requested it, otherwise remove it
- if id_value == 0:
- doc_copy.pop("_id", None)
- else:
- if "_id" in doc:
- doc_copy["_id"] = doc["_id"]
-
- fields["_id"] = id_value # put _id back in fields
-
- # time to apply the projection operators and put back their fields
- self._apply_projection_operators(projection_operators, doc, doc_copy)
- for field, op in projection_operators.items():
- fields[field] = op
- return doc_copy
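
Projections are either inclusive or exclusive (with `_id` handled separately), and mixing the two modes raises the `ValueError` above. A hedged sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.people
coll.insert_one({"name": "dana", "age": 33, "city": "Lisbon"})

print(coll.find_one({}, {"name": 1, "_id": 0}))  # {'name': 'dana'}
print(coll.find_one({}, {"city": 0, "_id": 0}))  # {'name': 'dana', 'age': 33}

try:
    coll.find_one({}, {"name": 1, "city": 0})
except ValueError as exc:
    print(exc)  # cannot mix including and excluding fields
```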
-
- def _update_document_fields(self, doc, fields, updater):
- """Implements the $set behavior on an existing document"""
- for k, v in fields.items():
- self._update_document_single_field(doc, k, v, updater)
-
- def _update_document_fields_positional(
- self, doc, fields, spec, updater, subdocument=None
- ):
- """Implements the $set behavior on an existing document"""
- for k, v in fields.items():
- if "$" in k:
- field_name_parts = k.split(".")
- if not subdocument:
- current_doc = doc
- subspec = spec
- for part in field_name_parts[:-1]:
- if part == "$":
- subspec_dollar = subspec.get("$elemMatch", subspec)
- for item in current_doc:
- if filter_applies(subspec_dollar, item):
- current_doc = item
- break
- continue
-
- new_spec = {}
- for el in subspec:
- if el.startswith(part):
- if len(el.split(".")) > 1:
- new_spec[".".join(el.split(".")[1:])] = subspec[el]
- else:
- new_spec = subspec[el]
- subspec = new_spec
- current_doc = current_doc[part]
-
- subdocument = current_doc
- if field_name_parts[-1] == "$" and isinstance(subdocument, list):
- for i, doc in enumerate(subdocument):
- subspec_dollar = subspec.get("$elemMatch", subspec)
- if filter_applies(subspec_dollar, doc):
- subdocument[i] = v
- break
- continue
-
- updater(subdocument, field_name_parts[-1], v)
- continue
- # otherwise, we handle it the standard way
- self._update_document_single_field(doc, k, v, updater)
-
- return subdocument
-
- def _update_document_fields_with_positional_awareness(
- self, existing_document, v, spec, updater, subdocument
- ):
- positional = any("$" in key for key in v.keys())
-
- if positional:
- return self._update_document_fields_positional(
- existing_document, v, spec, updater, subdocument
- )
- self._update_document_fields(existing_document, v, updater)
- return subdocument
-
- def _update_document_single_field(self, doc, field_name, field_value, updater):
- field_name_parts = field_name.split(".")
- for part in field_name_parts[:-1]:
- if isinstance(doc, list):
- try:
- if part == "$":
- doc = doc[0]
- else:
- doc = doc[int(part)]
- continue
- except ValueError:
- pass
- elif isinstance(doc, dict):
- if updater is _unset_updater and part not in doc:
-                    # If the parent doesn't exist, neither does its child.
- return
- doc = doc.setdefault(part, {})
- else:
- return
- field_name = field_name_parts[-1]
- updater(doc, field_name, field_value, codec_options=self._codec_options)
-
- def _iter_documents(self, filter):
- # Validate the filter even if no documents can be returned.
- if self._store.is_empty:
- filter_applies(filter, {})
-
- return (
- document
- for document in list(self._store.documents)
- if filter_applies(filter, document)
- )
-
- def find_one(self, filter=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
- # Allow calling find_one with a non-dict argument that gets used as
- # the id for the query.
- if filter is None:
- filter = {}
- if not isinstance(filter, Mapping):
- filter = {"_id": filter}
-
- try:
- return next(self.find(filter, *args, **kwargs))
- except StopIteration:
- return None
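
As the comment notes, a non-mapping argument to `find_one` is treated as the `_id` to look up. A minimal sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.users
coll.insert_one({"_id": "u-1", "name": "eve"})

print(coll.find_one("u-1")["name"])       # eve
print(coll.find_one({"name": "nobody"}))  # None
```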
-
- def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs):
- kwargs["remove"] = True
- validate_is_mapping("filter", filter)
- return self._find_and_modify(filter, projection, sort=sort, **kwargs)
-
- def find_one_and_replace(
- self,
- filter,
- replacement,
- projection=None,
- sort=None,
- upsert=False,
- return_document=ReturnDocument.BEFORE,
- **kwargs,
- ):
- validate_is_mapping("filter", filter)
- validate_ok_for_replace(replacement)
- return self._find_and_modify(
- filter, projection, replacement, upsert, sort, return_document, **kwargs
- )
-
- def find_one_and_update(
- self,
- filter,
- update,
- projection=None,
- sort=None,
- upsert=False,
- return_document=ReturnDocument.BEFORE,
- **kwargs,
- ):
- validate_is_mapping("filter", filter)
- validate_ok_for_update(update)
- return self._find_and_modify(
- filter, projection, update, upsert, sort, return_document, **kwargs
- )
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def find_and_modify(
- self,
- query={},
- update=None,
- upsert=False,
- sort=None,
- full_response=False,
- manipulate=False,
- fields=None,
- **kwargs,
- ):
- warnings.warn(
- "find_and_modify is deprecated, use find_one_and_delete"
- ", find_one_and_replace, or find_one_and_update instead",
- DeprecationWarning,
- stacklevel=2,
- )
- if "projection" in kwargs:
- raise TypeError(
- "find_and_modify() got an unexpected keyword argument 'projection'"
- )
- return self._find_and_modify(
- query,
- update=update,
- upsert=upsert,
- sort=sort,
- projection=fields,
- **kwargs,
- )
-
- def _find_and_modify(
- self,
- query,
- projection=None,
- update=None,
- upsert=False,
- sort=None,
- return_document=ReturnDocument.BEFORE,
- session=None,
- **kwargs,
- ):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- remove = kwargs.get("remove", False)
- if kwargs.get("new", False) and remove:
- # message from mongodb
- raise OperationFailure("remove and returnNew can't co-exist")
-
- if not (remove or update):
- raise ValueError("Must either update or remove")
-
- if remove and update:
- raise ValueError("Can't do both update and remove")
-
- old = self.find_one(query, projection=projection, sort=sort)
- if not old and not upsert:
- return
-
- if old and "_id" in old:
- query = {"_id": old["_id"]}
-
- if remove:
- self.delete_one(query)
- else:
- updated = self._update(query, update, upsert)
- if updated["upserted"]:
- query = {"_id": updated["upserted"]}
-
- if return_document is ReturnDocument.AFTER or kwargs.get("new"):
- return self.find_one(query, projection)
- return old
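
`_find_and_modify` backs the three `find_one_and_*` helpers; `return_document` (or the legacy `new` flag) decides whether the pre- or post-update document comes back. A hedged sketch, assuming the upstream `mongomock` package and pymongo's `ReturnDocument` enum:

```python
import mongomock
from pymongo.collection import ReturnDocument

coll = mongomock.MongoClient().db.counters
coll.insert_one({"_id": "jobs", "seq": 0})

doc = coll.find_one_and_update(
    {"_id": "jobs"},
    {"$inc": {"seq": 1}},
    return_document=ReturnDocument.AFTER,
)
print(doc["seq"])  # 1
```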
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def save(self, to_save, manipulate=True, check_keys=True, **kwargs):
- warnings.warn(
- "save is deprecated. Use insert_one or replace_one " "instead",
- DeprecationWarning,
- stacklevel=2,
- )
- validate_is_mutable_mapping("to_save", to_save)
- validate_write_concern_params(**kwargs)
-
- if "_id" not in to_save:
- return self.insert(to_save)
- self._update(
- {"_id": to_save["_id"]},
- to_save,
- True,
- manipulate,
- check_keys=True,
- **kwargs,
- )
- return to_save.get("_id", None)
-
- def delete_one(self, filter, collation=None, hint=None, session=None):
- validate_is_mapping("filter", filter)
- return DeleteResult(
- self._delete(filter, collation=collation, hint=hint, session=session), True
- )
-
- def delete_many(self, filter, collation=None, hint=None, session=None):
- validate_is_mapping("filter", filter)
- return DeleteResult(
- self._delete(
- filter, collation=collation, hint=hint, multi=True, session=session
- ),
- True,
- )
-
- def _delete(self, filter, collation=None, hint=None, multi=False, session=None):
- if hint:
- raise NotImplementedError(
- "The hint argument of delete is valid but has not been implemented in "
- "mongomock yet"
- )
- if collation:
- raise_not_implemented(
- "collation",
- "The collation argument of delete is valid but has not been "
- "implemented in mongomock yet",
- )
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- filter = helpers.patch_datetime_awareness_in_document(filter)
- if filter is None:
- filter = {}
- if not isinstance(filter, Mapping):
- filter = {"_id": filter}
- to_delete = list(self.find(filter))
- deleted_count = 0
- for doc in to_delete:
- doc_id = doc["_id"]
- if isinstance(doc_id, dict):
- doc_id = helpers.hashdict(doc_id)
- del self._store[doc_id]
- deleted_count += 1
- if not multi:
- break
-
- return {
- "connectionId": self.database.client._id,
- "n": deleted_count,
- "ok": 1.0,
- "err": None,
- }
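
`_delete` powers both public helpers: `delete_one` stops after the first match, `delete_many` keeps going, and both report `deleted_count`. A minimal sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.logs
coll.insert_many([{"level": "info"}, {"level": "info"}, {"level": "error"}])

print(coll.delete_one({"level": "info"}).deleted_count)   # 1
print(coll.delete_many({"level": "info"}).deleted_count)  # 1
print(coll.count_documents({}))                           # 1
```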
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def remove(self, spec_or_id=None, multi=True, **kwargs):
- warnings.warn(
- "remove is deprecated. Use delete_one or delete_many " "instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- validate_write_concern_params(**kwargs)
- return self._delete(spec_or_id, multi=multi)
-
- def count(self, filter=None, **kwargs):
- warnings.warn(
- "count is deprecated. Use estimated_document_count or "
- "count_documents instead. Please note that $where must be replaced "
- "by $expr, $near must be replaced by $geoWithin with $center, and "
- "$nearSphere must be replaced by $geoWithin with $centerSphere",
- DeprecationWarning,
- stacklevel=2,
- )
- if kwargs.pop("session", None):
- raise_not_implemented(
- "session", "Mongomock does not handle sessions yet"
- )
- if filter is None:
- return len(self._store)
- spec = helpers.patch_datetime_awareness_in_document(filter)
- return len(list(self._iter_documents(spec)))
-
- def count_documents(self, filter, **kwargs):
- if kwargs.pop("collation", None):
- raise_not_implemented(
- "collation",
- "The collation argument of count_documents is valid but has not been "
- "implemented in mongomock yet",
- )
- if kwargs.pop("session", None):
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- skip = kwargs.pop("skip", 0)
- if "limit" in kwargs:
- limit = kwargs.pop("limit")
- if not isinstance(limit, (int, float)):
- raise OperationFailure("the limit must be specified as a number")
- if limit <= 0:
- raise OperationFailure("the limit must be positive")
- limit = math.floor(limit)
- else:
- limit = None
- unknown_kwargs = set(kwargs) - {"maxTimeMS", "hint"}
- if unknown_kwargs:
- raise OperationFailure("unrecognized field '%s'" % unknown_kwargs.pop())
-
- spec = helpers.patch_datetime_awareness_in_document(filter)
- doc_num = len(list(self._iter_documents(spec)))
- count = max(doc_num - skip, 0)
- return count if limit is None else min(count, limit)
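
The arithmetic above applies `skip` before `limit`. A quick sketch, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.samples
coll.insert_many([{"i": i} for i in range(10)])

print(coll.count_documents({}))                   # 10
print(coll.count_documents({}, skip=4))           # 6
print(coll.count_documents({}, skip=4, limit=3))  # 3
```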
-
- def estimated_document_count(self, **kwargs):
- if kwargs.pop("session", None):
- raise ConfigurationError(
- "estimated_document_count does not support sessions"
- )
- unknown_kwargs = set(kwargs) - {"limit", "maxTimeMS", "hint"}
- if self.database.client.server_info()["versionArray"] < [5]:
- unknown_kwargs.discard("skip")
- if unknown_kwargs:
- raise OperationFailure(
- "BSON field 'count.%s' is an unknown field." % list(unknown_kwargs)[0]
- )
- return self.count_documents({}, **kwargs)
-
- def drop(self, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- self.database.drop_collection(self.name)
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def ensure_index(self, key_or_list, cache_for=300, **kwargs):
- return self.create_index(key_or_list, cache_for, **kwargs)
-
- def create_index(self, key_or_list, cache_for=300, session=None, **kwargs):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- index_list = helpers.create_index_list(key_or_list)
- is_unique = kwargs.pop("unique", False)
- is_sparse = kwargs.pop("sparse", False)
-
- index_name = kwargs.pop("name", helpers.gen_index_name(index_list))
- index_dict = {"key": index_list}
- if is_sparse:
- index_dict["sparse"] = True
- if is_unique:
- index_dict["unique"] = True
- if "expireAfterSeconds" in kwargs and kwargs["expireAfterSeconds"] is not None:
- index_dict["expireAfterSeconds"] = kwargs.pop("expireAfterSeconds")
- if (
- "partialFilterExpression" in kwargs
- and kwargs["partialFilterExpression"] is not None
- ):
- index_dict["partialFilterExpression"] = kwargs.pop(
- "partialFilterExpression"
- )
-
- existing_index = self._store.indexes.get(index_name)
- if existing_index and index_dict != existing_index:
- raise OperationFailure(
- "Index with name: %s already exists with different options" % index_name
- )
-
-        # Check that existing documents already satisfy the uniqueness of this new index.
- if is_unique:
- indexed = set()
- indexed_list = []
- documents_gen = self._store.documents
- for doc in documents_gen:
- index = []
- for key, unused_order in index_list:
- try:
- index.append(helpers.get_value_by_dot(doc, key))
- except KeyError:
- if is_sparse:
- continue
- index.append(None)
- if is_sparse and not index:
- continue
- index = tuple(index)
- try:
- if index in indexed:
- # Need to throw this inside the generator so it can clean the locks
- documents_gen.throw(
- DuplicateKeyError("E11000 Duplicate Key Error", 11000),
- None,
- None,
- )
- indexed.add(index)
- except TypeError as err:
- # index is not hashable.
- if index in indexed_list:
- documents_gen.throw(
- DuplicateKeyError("E11000 Duplicate Key Error", 11000),
- None,
- err,
- )
- indexed_list.append(index)
-
- self._store.create_index(index_name, index_dict)
-
- return index_name
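
A unique index is validated against the existing documents at creation time and then enforced by `_ensure_uniques` on later writes. A hedged sketch, assuming the upstream `mongomock` package (which re-exports `DuplicateKeyError`):

```python
import mongomock
from mongomock import DuplicateKeyError

coll = mongomock.MongoClient().db.accounts
coll.create_index([("email", 1)], unique=True)

coll.insert_one({"email": "a@example.com"})
try:
    coll.insert_one({"email": "a@example.com"})
except DuplicateKeyError as exc:
    print("duplicate rejected:", exc)
```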
-
- def create_indexes(self, indexes, session=None):
- for index in indexes:
- if not isinstance(index, IndexModel):
- raise TypeError(
- "%s is not an instance of pymongo.operations.IndexModel" % index
- )
-
- return [
- self.create_index(
- index.document["key"].items(),
- session=session,
- expireAfterSeconds=index.document.get("expireAfterSeconds"),
- unique=index.document.get("unique", False),
- sparse=index.document.get("sparse", False),
- name=index.document.get("name"),
- )
- for index in indexes
- ]
-
- def drop_index(self, index_or_name, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- if isinstance(index_or_name, list):
- name = helpers.gen_index_name(index_or_name)
- else:
- name = index_or_name
- try:
- self._store.drop_index(name)
- except KeyError as err:
- raise OperationFailure("index not found with name [%s]" % name) from err
-
- def drop_indexes(self, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- self._store.indexes = {}
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def reindex(self, session=None):
- if session:
- raise_not_implemented(
- "session", "Mongomock does not handle sessions yet"
- )
-
- def _list_all_indexes(self):
- if not self._store.is_created:
- return
- yield "_id_", {"key": [("_id", 1)]}
- for name, information in self._store.indexes.items():
- yield name, information
-
- def list_indexes(self, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- for name, information in self._list_all_indexes():
- yield dict(information, key=dict(information["key"]), name=name, v=2)
-
- def index_information(self, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- return {name: dict(index, v=2) for name, index in self._list_all_indexes()}
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def map_reduce(
- self,
- map_func,
- reduce_func,
- out,
- full_response=False,
- query=None,
- limit=0,
- session=None,
- ):
- if execjs is None:
- raise NotImplementedError(
- "PyExecJS is required in order to run Map-Reduce. "
- "Use 'pip install pyexecjs pymongo' to support Map-Reduce mock."
- )
- if session:
- raise_not_implemented(
- "session", "Mongomock does not handle sessions yet"
- )
- if limit == 0:
- limit = None
- start_time = time.perf_counter()
- out_collection = None
- reduced_rows = None
- full_dict = {
- "counts": {"input": 0, "reduce": 0, "emit": 0, "output": 0},
- "timeMillis": 0,
- "ok": 1.0,
- "result": None,
- }
- map_ctx = execjs.compile(
- """
- function doMap(fnc, docList) {
- var mappedDict = {};
- function emit(key, val) {
- if (key['$oid']) {
- mapped_key = '$oid' + key['$oid'];
- }
- else {
- mapped_key = key;
- }
- if(!mappedDict[mapped_key]) {
- mappedDict[mapped_key] = [];
- }
- mappedDict[mapped_key].push(val);
- }
- mapper = eval('('+fnc+')');
- var mappedList = new Array();
-                    for(var i=0; i<docList.length; i++) {
-                    if emit_count > 1:
- full_dict["counts"]["reduce"] += 1
- full_dict["counts"]["output"] = len(reduced_rows)
- if isinstance(out, (str, bytes)):
- out_collection = getattr(self.database, out)
- out_collection.drop()
- out_collection.insert(reduced_rows)
- ret_val = out_collection
- full_dict["result"] = out
- elif isinstance(out, SON) and out.get("replace") and out.get("db"):
- # Must be of the format SON([('replace','results'),('db','outdb')])
- out_db = getattr(self.database._client, out["db"])
- out_collection = getattr(out_db, out["replace"])
- out_collection.insert(reduced_rows)
- ret_val = out_collection
- full_dict["result"] = {"db": out["db"], "collection": out["replace"]}
- elif isinstance(out, dict) and out.get("inline"):
- ret_val = reduced_rows
- full_dict["result"] = reduced_rows
- else:
- raise TypeError("'out' must be an instance of string, dict or bson.SON")
- time_millis = (time.perf_counter() - start_time) * 1000
- full_dict["timeMillis"] = int(round(time_millis))
- if full_response:
- ret_val = full_dict
- return ret_val
-
- def inline_map_reduce(
- self,
- map_func,
- reduce_func,
- full_response=False,
- query=None,
- limit=0,
- session=None,
- ):
- return self.map_reduce(
- map_func,
- reduce_func,
- {"inline": 1},
- full_response,
- query,
- limit,
- session=session,
- )
-
- def distinct(self, key, filter=None, session=None):
- if session:
- raise_not_implemented("session", "Mongomock does not handle sessions yet")
- return self.find(filter).distinct(key)
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def group(self, key, condition, initial, reduce, finalize=None):
- if helpers.PYMONGO_VERSION >= version.parse("3.6"):
- raise OperationFailure("no such command: 'group'")
- if execjs is None:
- raise NotImplementedError(
- "PyExecJS is required in order to use group. "
- "Use 'pip install pyexecjs pymongo' to support group mock."
- )
- reduce_ctx = execjs.compile(
- """
- function doReduce(fnc, docList) {
- reducer = eval('('+fnc+')');
-                for(var i=0, l=docList.length; i<l; i++) {
-        if len_diff > 0:
- doc += [None] * len_diff
- doc[field_index] = value
-
-
-def _unset_updater(doc, field_name, value, codec_options=None):
- if isinstance(doc, dict):
- doc.pop(field_name, None)
-
-
-def _inc_updater(doc, field_name, value, codec_options=None):
- if isinstance(doc, dict):
- doc[field_name] = doc.get(field_name, 0) + value
-
- if isinstance(doc, list):
- field_index = int(field_name)
- if field_index < 0:
- raise WriteError("Negative index provided")
- try:
- doc[field_index] += value
- except IndexError:
- len_diff = field_index - (len(doc) - 1)
- doc += [None] * len_diff
- doc[field_index] = value
-
-
-def _max_updater(doc, field_name, value, codec_options=None):
- if isinstance(doc, dict):
- doc[field_name] = max(doc.get(field_name, value), value)
-
-
-def _min_updater(doc, field_name, value, codec_options=None):
- if isinstance(doc, dict):
- doc[field_name] = min(doc.get(field_name, value), value)
-
-
-def _pop_updater(doc, field_name, value, codec_options=None):
- if value not in {1, -1}:
- raise WriteError("$pop expects 1 or -1, found: " + str(value))
-
- if isinstance(doc, dict):
- if isinstance(doc[field_name], (tuple, list)):
- doc[field_name] = list(doc[field_name])
- _pop_from_list(doc[field_name], value)
- return
- raise WriteError("Path contains element of non-array type")
-
- if isinstance(doc, list):
- field_index = int(field_name)
- if field_index < 0:
- raise WriteError("Negative index provided")
- if field_index >= len(doc):
- return
- _pop_from_list(doc[field_index], value)
-
-
-def _pop_from_list(list_instance, mongo_pop_value, codec_options=None):
- if not list_instance:
- return
-
- if mongo_pop_value == 1:
- list_instance.pop()
- elif mongo_pop_value == -1:
- list_instance.pop(0)
-
-
-def _current_date_updater(doc, field_name, value, codec_options=None):
- if isinstance(doc, dict):
- if value == {"$type": "timestamp"}:
- # TODO(juannyg): get_current_timestamp should also be using helpers utcnow,
-            # as it is currently using time.time internally
- doc[field_name] = helpers.get_current_timestamp()
- else:
- doc[field_name] = utcnow()
-
-
-_updaters = {
- "$set": _set_updater,
- "$unset": _unset_updater,
- "$inc": _inc_updater,
- "$max": _max_updater,
- "$min": _min_updater,
- "$pop": _pop_updater,
-}
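
Each operator in `_updaters` maps to one of the small updater functions above. A minimal sketch of their combined effect, assuming the upstream `mongomock` package:

```python
import mongomock

coll = mongomock.MongoClient().db.stats
coll.insert_one({"_id": 1, "hits": 1, "scores": [3, 7], "tmp": "x"})

coll.update_one(
    {"_id": 1},
    {"$inc": {"hits": 2}, "$min": {"low": 5}, "$pop": {"scores": 1}, "$unset": {"tmp": ""}},
)
print(coll.find_one({"_id": 1}))
# {'_id': 1, 'hits': 3, 'scores': [3], 'low': 5}
```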
diff --git a/packages/syft/tests/mongomock/command_cursor.py b/packages/syft/tests/mongomock/command_cursor.py
deleted file mode 100644
index 025bb836e24..00000000000
--- a/packages/syft/tests/mongomock/command_cursor.py
+++ /dev/null
@@ -1,37 +0,0 @@
-class CommandCursor(object):
- def __init__(self, collection, curser_info=None, address=None, retrieved=0):
- self._collection = iter(collection)
- self._id = None
- self._address = address
- self._data = {}
- self._retrieved = retrieved
- self._batch_size = 0
- self._killed = self._id == 0
-
- @property
- def address(self):
- return self._address
-
- def close(self):
- pass
-
- def batch_size(self, batch_size):
- return self
-
- @property
- def alive(self):
- return True
-
- def __iter__(self):
- return self
-
- def next(self):
- return next(self._collection)
-
- __next__ = next
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- return
diff --git a/packages/syft/tests/mongomock/database.py b/packages/syft/tests/mongomock/database.py
deleted file mode 100644
index 3b1a7e59f70..00000000000
--- a/packages/syft/tests/mongomock/database.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# stdlib
-import warnings
-
-# third party
-from packaging import version
-
-# relative
-from . import CollectionInvalid
-from . import InvalidName
-from . import OperationFailure
-from . import codec_options as mongomock_codec_options
-from . import helpers
-from . import read_preferences
-from . import store
-from .collection import Collection
-from .filtering import filter_applies
-
-try:
- # third party
- from pymongo import ReadPreference
-
- _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY
-except ImportError:
- _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY
-
-try:
- # third party
- from pymongo.read_concern import ReadConcern
-except ImportError:
- # relative
- from .read_concern import ReadConcern
-
-_LIST_COLLECTION_FILTER_ALLOWED_OPERATORS = frozenset(["$regex", "$eq", "$ne"])
-
-
-def _verify_list_collection_supported_op(keys):
- if set(keys) - _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS:
- raise NotImplementedError(
- "list collection names filter operator {0} is not implemented yet in mongomock "
- "allowed operators are {1}".format(
- keys, _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS
- )
- )
-
-
-class Database(object):
- def __init__(
- self,
- client,
- name,
- _store,
- read_preference=None,
- codec_options=None,
- read_concern=None,
- ):
- self.name = name
- self._client = client
- self._collection_accesses = {}
- self._store = _store or store.DatabaseStore()
- self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY
- mongomock_codec_options.is_supported(codec_options)
- self._codec_options = codec_options or mongomock_codec_options.CodecOptions()
- if read_concern and not isinstance(read_concern, ReadConcern):
- raise TypeError(
- "read_concern must be an instance of pymongo.read_concern.ReadConcern"
- )
- self._read_concern = read_concern or ReadConcern()
-
- def __getitem__(self, coll_name):
- return self.get_collection(coll_name)
-
- def __getattr__(self, attr):
- if attr.startswith("_"):
- raise AttributeError(
- "%s has no attribute '%s'. To access the %s collection, use database['%s']."
- % (self.__class__.__name__, attr, attr, attr)
- )
- return self[attr]
-
- def __repr__(self):
- return "Database({0}, '{1}')".format(self._client, self.name)
-
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- return self._client == other._client and self.name == other.name
- return NotImplemented
-
- if helpers.PYMONGO_VERSION >= version.parse("3.12"):
-
- def __hash__(self):
- return hash((self._client, self.name))
-
- @property
- def client(self):
- return self._client
-
- @property
- def read_preference(self):
- return self._read_preference
-
- @property
- def codec_options(self):
- return self._codec_options
-
- @property
- def read_concern(self):
- return self._read_concern
-
- def _get_created_collections(self):
- return self._store.list_created_collection_names()
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def collection_names(self, include_system_collections=True, session=None):
- warnings.warn(
- "collection_names is deprecated. Use list_collection_names instead."
- )
- if include_system_collections:
- return list(self._get_created_collections())
- return self.list_collection_names(session=session)
-
- def list_collections(self, filter=None, session=None, nameOnly=False):
- raise NotImplementedError(
- "list_collections is a valid method of Database but has not been implemented in "
- "mongomock yet."
- )
-
- def list_collection_names(self, filter=None, session=None):
-        """filter: only the 'name' field is supported, with the $eq, $ne or $regex operators
-
-        session: not supported
-        for the supported operators, see _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS
-        """
- field_name = "name"
-
- if session:
- raise NotImplementedError("Mongomock does not handle sessions yet")
-
- if filter:
- if not filter.get("name"):
- raise NotImplementedError(
- "list collection {0} might be valid but is not "
- "implemented yet in mongomock".format(filter)
- )
-
- filter = (
- {field_name: {"$eq": filter.get(field_name)}}
- if isinstance(filter.get(field_name), str)
- else filter
- )
-
- _verify_list_collection_supported_op(filter.get(field_name).keys())
-
- return [
- name
- for name in list(self._store._collections)
- if filter_applies(filter, {field_name: name})
- and not name.startswith("system.")
- ]
-
- return [
- name
- for name in self._get_created_collections()
- if not name.startswith("system.")
- ]
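
Only name-based filters using `$eq`, `$ne` or `$regex` are accepted here. A hedged sketch, assuming the upstream `mongomock` package:

```python
import mongomock

db = mongomock.MongoClient().my_db
db.create_collection("users")
db.create_collection("user_events")
db.create_collection("audit")

print(sorted(db.list_collection_names(filter={"name": {"$regex": "^user"}})))
# ['user_events', 'users']
```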
-
- def get_collection(
- self,
- name,
- codec_options=None,
- read_preference=None,
- write_concern=None,
- read_concern=None,
- ):
- if read_preference is not None:
- read_preferences.ensure_read_preference_type(
- "read_preference", read_preference
- )
- mongomock_codec_options.is_supported(codec_options)
- try:
- return self._collection_accesses[name].with_options(
- codec_options=codec_options or self._codec_options,
- read_preference=read_preference or self.read_preference,
- read_concern=read_concern,
- write_concern=write_concern,
- )
- except KeyError:
- self._ensure_valid_collection_name(name)
- collection = self._collection_accesses[name] = Collection(
- self,
- name=name,
- read_concern=read_concern,
- write_concern=write_concern,
- read_preference=read_preference or self.read_preference,
- codec_options=codec_options or self._codec_options,
- _db_store=self._store,
- )
- return collection
-
- def drop_collection(self, name_or_collection, session=None):
- if session:
- raise NotImplementedError("Mongomock does not handle sessions yet")
- if isinstance(name_or_collection, Collection):
- name_or_collection._store.drop()
- else:
- self._store[name_or_collection].drop()
-
- def _ensure_valid_collection_name(self, name):
- # These are the same checks that are done in pymongo.
- if not isinstance(name, str):
- raise TypeError("name must be an instance of str")
- if not name or ".." in name:
- raise InvalidName("collection names cannot be empty")
- if name[0] == "." or name[-1] == ".":
- raise InvalidName("collection names must not start or end with '.'")
- if "$" in name:
- raise InvalidName("collection names must not contain '$'")
- if "\x00" in name:
- raise InvalidName("collection names must not contain the null character")
-
- def create_collection(self, name, **kwargs):
- self._ensure_valid_collection_name(name)
- if name in self.list_collection_names():
- raise CollectionInvalid("collection %s already exists" % name)
-
- if kwargs:
- raise NotImplementedError("Special options not supported")
-
- self._store.create_collection(name)
- return self[name]
-
- def rename_collection(self, name, new_name, dropTarget=False):
- """Changes the name of an existing collection."""
- self._ensure_valid_collection_name(new_name)
-
- # Reference for server implementation:
- # https://docs.mongodb.com/manual/reference/command/renameCollection/
- if not self._store[name].is_created:
- raise OperationFailure(
- 'The collection "{0}" does not exist.'.format(name), 10026
- )
- if new_name in self._store:
- if dropTarget:
- self.drop_collection(new_name)
- else:
- raise OperationFailure(
- 'The target collection "{0}" already exists'.format(new_name), 10027
- )
- self._store.rename(name, new_name)
- return {"ok": 1}
-
- def dereference(self, dbref, session=None):
- if session:
- raise NotImplementedError("Mongomock does not handle sessions yet")
-
- if not hasattr(dbref, "collection") or not hasattr(dbref, "id"):
- raise TypeError("cannot dereference a %s" % type(dbref))
- if dbref.database is not None and dbref.database != self.name:
- raise ValueError(
- "trying to dereference a DBRef that points to "
- "another database (%r not %r)" % (dbref.database, self.name)
- )
- return self[dbref.collection].find_one({"_id": dbref.id})
-
- def command(self, command, **unused_kwargs):
- if isinstance(command, str):
- command = {command: 1}
- if "ping" in command:
- return {"ok": 1.0}
- # TODO(pascal): Differentiate NotImplementedError for valid commands
- # and OperationFailure if the command is not valid.
- raise NotImplementedError(
- "command is a valid Database method but is not implemented in Mongomock yet"
- )
-
- def with_options(
- self,
- codec_options=None,
- read_preference=None,
- write_concern=None,
- read_concern=None,
- ):
- mongomock_codec_options.is_supported(codec_options)
-
- if write_concern:
- raise NotImplementedError(
- "write_concern is a valid parameter for with_options but is not implemented yet in "
- "mongomock"
- )
-
- if read_preference is None or read_preference == self._read_preference:
- return self
-
- return Database(
- self._client,
- self.name,
- self._store,
- read_preference=read_preference or self._read_preference,
- codec_options=codec_options or self._codec_options,
- read_concern=read_concern or self._read_concern,
- )
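
Putting the `Database` helpers together: collections are created lazily on first write and can be renamed, dropped or pinged through the methods above. A minimal sketch, assuming the upstream `mongomock` package:

```python
import mongomock

db = mongomock.MongoClient().app
db.jobs.insert_one({"state": "queued"})

db.rename_collection("jobs", "tasks")
print(db.list_collection_names())  # ['tasks']
print(db.command("ping"))          # {'ok': 1.0}

db.drop_collection("tasks")
print(db.list_collection_names())  # []
```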
diff --git a/packages/syft/tests/mongomock/filtering.py b/packages/syft/tests/mongomock/filtering.py
deleted file mode 100644
index 345b94c7b88..00000000000
--- a/packages/syft/tests/mongomock/filtering.py
+++ /dev/null
@@ -1,601 +0,0 @@
-# stdlib
-from datetime import datetime
-import itertools
-import numbers
-import operator
-import re
-import uuid
-
-# relative
-from . import OperationFailure
-from .helpers import ObjectId
-from .helpers import RE_TYPE
-
-try:
- # stdlib
- from types import NoneType
-except ImportError:
- NoneType = type(None)
-
-try:
- # third party
- from bson import DBRef
- from bson import Regex
-
- _RE_TYPES = (RE_TYPE, Regex)
-except ImportError:
- DBRef = None
- _RE_TYPES = (RE_TYPE,)
-
-try:
- # third party
- from bson.decimal128 import Decimal128
-except ImportError:
- Decimal128 = None
-
-_TOP_LEVEL_OPERATORS = {"$expr", "$text", "$where", "$jsonSchema"}
-
-
-_NOT_IMPLEMENTED_OPERATORS = {
- "$bitsAllClear",
- "$bitsAllSet",
- "$bitsAnyClear",
- "$bitsAnySet",
- "$geoIntersects",
- "$geoWithin",
- "$maxDistance",
- "$minDistance",
- "$near",
- "$nearSphere",
-}
-
-
-def filter_applies(search_filter, document):
- """Applies given filter
-
- This function implements MongoDB's matching strategy over documents in the find() method
- and other related scenarios (like $elemMatch)
- """
- return _filterer_inst.apply(search_filter, document)
-
-
-class _Filterer(object):
- """An object to help applying a filter, using the MongoDB query language."""
-
- # This is populated using register_parse_expression further down.
- parse_expression = []
-
- def __init__(self):
- self._operator_map = dict(
- {
- "$eq": _list_expand(operator_eq),
- "$ne": _list_expand(
- lambda dv, sv: not operator_eq(dv, sv), negative=True
- ),
- "$all": self._all_op,
- "$in": _in_op,
- "$nin": lambda dv, sv: not _in_op(dv, sv),
- "$exists": lambda dv, sv: bool(sv) == (dv is not None),
- "$regex": _not_None_and(_regex),
- "$elemMatch": self._elem_match_op,
- "$size": _size_op,
- "$type": _type_op,
- },
- **{
- key: _not_None_and(_list_expand(_compare_objects(op)))
- for key, op in SORTING_OPERATOR_MAP.items()
- },
- )
-
- def apply(self, search_filter, document):
- if not isinstance(search_filter, dict):
- raise OperationFailure(
- "the match filter must be an expression in an object"
- )
-
- for key, search in search_filter.items():
- # Top level operators.
- if key == "$comment":
- continue
- if key in LOGICAL_OPERATOR_MAP:
- if not search:
- raise OperationFailure(
- "BadValue $and/$or/$nor must be a nonempty array"
- )
- if not LOGICAL_OPERATOR_MAP[key](document, search, self.apply):
- return False
- continue
- if key == "$expr":
- parse_expression = self.parse_expression[0]
- if not parse_expression(search, document, ignore_missing_keys=True):
- return False
- continue
- if key in _TOP_LEVEL_OPERATORS:
- raise NotImplementedError(
- "The {} operator is not implemented in mongomock yet".format(key)
- )
- if key.startswith("$"):
- raise OperationFailure("unknown top level operator: " + key)
-
- is_match = False
-
- is_checking_negative_match = isinstance(search, dict) and {
- "$ne",
- "$nin",
- } & set(search.keys())
- is_checking_positive_match = not isinstance(search, dict) or (
- set(search.keys()) - {"$ne", "$nin"}
- )
- has_candidates = False
-
- if search == {"$exists": False} and not iter_key_candidates(key, document):
- continue
-
- if isinstance(search, dict) and "$all" in search:
- if not self._all_op(iter_key_candidates(key, document), search["$all"]):
- return False
- # if there are no query operators then continue
- if len(search) == 1:
- continue
-
- for doc_val in iter_key_candidates(key, document):
- has_candidates |= doc_val is not None
- is_ops_filter = (
- search
- and isinstance(search, dict)
- and all(key.startswith("$") for key in search.keys())
- )
- if is_ops_filter:
- if "$options" in search and "$regex" in search:
- search = _combine_regex_options(search)
- unknown_operators = set(search) - set(self._operator_map) - {"$not"}
- if unknown_operators:
- not_implemented_operators = (
- unknown_operators & _NOT_IMPLEMENTED_OPERATORS
- )
- if not_implemented_operators:
- raise NotImplementedError(
- "'%s' is a valid operation but it is not supported by Mongomock "
- "yet." % list(not_implemented_operators)[0]
- )
- raise OperationFailure(
- "unknown operator: " + list(unknown_operators)[0]
- )
- is_match = (
- all(
- operator_string in self._operator_map
- and self._operator_map[operator_string](doc_val, search_val)
- or operator_string == "$not"
- and self._not_op(document, key, search_val)
- for operator_string, search_val in search.items()
- )
- and search
- )
- elif isinstance(search, _RE_TYPES) and isinstance(doc_val, (str, list)):
- is_match = _regex(doc_val, search)
- elif key in LOGICAL_OPERATOR_MAP:
- if not search:
- raise OperationFailure(
- "BadValue $and/$or/$nor must be a nonempty array"
- )
- is_match = LOGICAL_OPERATOR_MAP[key](document, search, self.apply)
- elif isinstance(doc_val, (list, tuple)):
- is_match = search in doc_val or search == doc_val
- if isinstance(search, ObjectId):
- is_match |= str(search) in doc_val
- else:
- is_match = (doc_val == search) or (
- search is None and doc_val is None
- )
-
- # When checking negative match, all the elements should match.
- if is_checking_negative_match and not is_match:
- return False
-
-                # If not checking negative matches, the first match is enough for this criterion.
- if is_match and not is_checking_negative_match:
- break
-
- if not is_match and (has_candidates or is_checking_positive_match):
- return False
-
- return True
-
- def _not_op(self, d, k, s):
- if isinstance(s, dict):
- for key in s.keys():
- if key not in self._operator_map and key not in LOGICAL_OPERATOR_MAP:
- raise OperationFailure("unknown operator: %s" % key)
- elif isinstance(s, _RE_TYPES):
- pass
- else:
- raise OperationFailure("$not needs a regex or a document")
- return not self.apply({k: s}, d)
-
- def _elem_match_op(self, doc_val, query):
- if not isinstance(doc_val, list):
- return False
- if not isinstance(query, dict):
- raise OperationFailure("$elemMatch needs an Object")
- for item in doc_val:
- try:
- if self.apply(query, item):
- return True
- except OperationFailure:
- if self.apply({"field": query}, {"field": item}):
- return True
- return False
-
- def _all_op(self, doc_val, search_val):
- if isinstance(doc_val, list) and doc_val and isinstance(doc_val[0], list):
- doc_val = list(itertools.chain.from_iterable(doc_val))
- dv = _force_list(doc_val)
- matches = []
- for x in search_val:
- if isinstance(x, dict) and "$elemMatch" in x:
- matches.append(self._elem_match_op(doc_val, x["$elemMatch"]))
- else:
- matches.append(x in dv)
- return all(matches)
-
-
-def iter_key_candidates(key, doc):
- """Get possible subdocuments or lists that are referred to by the key in question
-
- Returns the appropriate nested value if the key includes dot notation.
- """
- if not key:
- return [doc]
-
- if doc is None:
- return ()
-
- if isinstance(doc, list):
- return _iter_key_candidates_sublist(key, doc)
-
- if not isinstance(doc, dict):
- return ()
-
- key_parts = key.split(".")
- if len(key_parts) == 1:
- return [doc.get(key, None)]
-
- sub_key = ".".join(key_parts[1:])
- sub_doc = doc.get(key_parts[0], {})
- return iter_key_candidates(sub_key, sub_doc)
-
-
-def _iter_key_candidates_sublist(key, doc):
-    """Iterates over candidates
-
- :param doc: a list to be searched for candidates for our key
- :param key: the string key to be matched
- """
- key_parts = key.split(".")
- sub_key = key_parts.pop(0)
- key_remainder = ".".join(key_parts)
- try:
- sub_key_int = int(sub_key)
- except ValueError:
- sub_key_int = None
-
- if sub_key_int is None:
- # subkey is not an integer...
- ret = []
- for sub_doc in doc:
- if isinstance(sub_doc, dict):
- if sub_key in sub_doc:
- ret.extend(iter_key_candidates(key_remainder, sub_doc[sub_key]))
- else:
- ret.append(None)
- return ret
-
- # subkey is an index
- if sub_key_int >= len(doc):
- return () # dead end
- sub_doc = doc[sub_key_int]
- if key_parts:
- return iter_key_candidates(".".join(key_parts), sub_doc)
- return [sub_doc]
-
-
-def _force_list(v):
- return v if isinstance(v, (list, tuple)) else [v]
-
-
-def _in_op(doc_val, search_val):
- if not isinstance(search_val, (list, tuple)):
- raise OperationFailure("$in needs an array")
- if doc_val is None and None in search_val:
- return True
- doc_val = _force_list(doc_val)
- is_regex_list = [isinstance(x, _RE_TYPES) for x in search_val]
- if not any(is_regex_list):
- return any(x in search_val for x in doc_val)
- for x, is_regex in zip(search_val, is_regex_list):
- if (is_regex and _regex(doc_val, x)) or (x in doc_val):
- return True
- return False
-
-
-def _not_None_and(f):
- """wrap an operator to return False if the first arg is None"""
- return lambda v, l: v is not None and f(v, l)
-
-
-def _compare_objects(op):
- """Wrap an operator to also compare objects following BSON comparison.
-
- See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#objects
- """
-
- def _wrapped(a, b):
- # Do not compare uncomparable types, see Type Bracketing:
- # https://docs.mongodb.com/manual/reference/method/db.collection.find/#type-bracketing
- return bson_compare(op, a, b, can_compare_types=False)
-
- return _wrapped
-
-
-def bson_compare(op, a, b, can_compare_types=True):
- """Compare two elements using BSON comparison.
-
- Args:
- op: the basic operation to compare (e.g. operator.lt, operator.ge).
- a: the first operand
- b: the second operand
- can_compare_types: if True, according to BSON's definition order
- between types is used, otherwise always return False when types are
- different.
- """
- a_type = _get_compare_type(a)
- b_type = _get_compare_type(b)
- if a_type != b_type:
- return can_compare_types and op(a_type, b_type)
-
- # Compare DBRefs as dicts
- if type(a).__name__ == "DBRef" and hasattr(a, "as_doc"):
- a = a.as_doc()
- if type(b).__name__ == "DBRef" and hasattr(b, "as_doc"):
- b = b.as_doc()
-
- if isinstance(a, dict):
- # MongoDb server compares the type before comparing the keys
- # https://github.com/mongodb/mongo/blob/f10f214/src/mongo/bson/bsonelement.cpp#L516
- # even though the documentation does not say anything about that.
- a = [(_get_compare_type(v), k, v) for k, v in a.items()]
- b = [(_get_compare_type(v), k, v) for k, v in b.items()]
-
- if isinstance(a, (tuple, list)):
- for item_a, item_b in zip(a, b):
- if item_a != item_b:
- return bson_compare(op, item_a, item_b)
- return bson_compare(op, len(a), len(b))
-
- if isinstance(a, NoneType):
- return op(0, 0)
-
- # bson handles bytes as binary in python3+:
- # https://api.mongodb.com/python/current/api/bson/index.html
- if isinstance(a, bytes):
- # Performs the same operation as described by:
- # https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#bindata
- if len(a) != len(b):
- return op(len(a), len(b))
- # bytes is always treated as subtype 0 by the bson library
- return op(a, b)
-
-
-def _get_compare_type(val):
- """Get a number representing the base type of the value used for comparison.
-
- See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/
- also https://github.com/mongodb/mongo/blob/46b28bb/src/mongo/bson/bsontypes.h#L175
- for canonical values.
- """
- if isinstance(val, NoneType):
- return 5
- if isinstance(val, bool):
- return 40
- if isinstance(val, numbers.Number):
- return 10
- if isinstance(val, str):
- return 15
- if isinstance(val, dict):
- return 20
- if isinstance(val, (tuple, list)):
- return 25
- if isinstance(val, uuid.UUID):
- return 30
- if isinstance(val, bytes):
- return 30
- if isinstance(val, ObjectId):
- return 35
- if isinstance(val, datetime):
- return 45
- if isinstance(val, _RE_TYPES):
- return 50
- if DBRef and isinstance(val, DBRef):
- # According to the C++ code, this should be 55 but apparently sending a DBRef through
- # pymongo is stored as a dict.
- return 20
- return 0
-
-
-def _regex(doc_val, regex):
- if not (isinstance(doc_val, (str, list)) or isinstance(doc_val, RE_TYPE)):
- return False
- if isinstance(regex, str):
- regex = re.compile(regex)
- if not isinstance(regex, RE_TYPE):
- # bson.Regex
- regex = regex.try_compile()
- return any(
- regex.search(item) for item in _force_list(doc_val) if isinstance(item, str)
- )
-
-
-def _size_op(doc_val, search_val):
- if isinstance(doc_val, (list, tuple, dict)):
- return search_val == len(doc_val)
- return search_val == 1 if doc_val and doc_val is not None else 0
-
-
-def _list_expand(f, negative=False):
- def func(doc_val, search_val):
- if isinstance(doc_val, (list, tuple)) and not isinstance(
- search_val, (list, tuple)
- ):
- if negative:
- return all(f(val, search_val) for val in doc_val)
- return any(f(val, search_val) for val in doc_val)
- return f(doc_val, search_val)
-
- return func
-
-
-def _type_op(doc_val, search_val, in_array=False):
- if search_val not in TYPE_MAP:
- raise OperationFailure("%r is not a valid $type" % search_val)
- elif TYPE_MAP[search_val] is None:
- raise NotImplementedError(
- "%s is a valid $type but not implemented" % search_val
- )
- if TYPE_MAP[search_val](doc_val):
- return True
- if isinstance(doc_val, (list, tuple)) and not in_array:
- return any(_type_op(val, search_val, in_array=True) for val in doc_val)
- return False
-
-
-def _combine_regex_options(search):
- if not isinstance(search["$options"], str):
- raise OperationFailure("$options has to be a string")
-
- options = None
- for option in search["$options"]:
- if option not in "imxs":
- continue
- re_option = getattr(re, option.upper())
- if options is None:
- options = re_option
- else:
- options |= re_option
-
- search_copy = dict(search)
- del search_copy["$options"]
-
- if options is None:
- return search_copy
-
- if isinstance(search["$regex"], _RE_TYPES):
- if isinstance(search["$regex"], RE_TYPE):
- search_copy["$regex"] = re.compile(
- search["$regex"].pattern, search["$regex"].flags | options
- )
- else:
- # bson.Regex
- regex = search["$regex"]
- search_copy["$regex"] = regex.__class__(
- regex.pattern, regex.flags | options
- )
- else:
- search_copy["$regex"] = re.compile(search["$regex"], options)
- return search_copy
-
-
-def operator_eq(doc_val, search_val):
- if doc_val is None and search_val is None:
- return True
- return operator.eq(doc_val, search_val)
-
-
-SORTING_OPERATOR_MAP = {
- "$gt": operator.gt,
- "$gte": operator.ge,
- "$lt": operator.lt,
- "$lte": operator.le,
-}
-
-
-LOGICAL_OPERATOR_MAP = {
- "$or": lambda d, subq, filter_func: any(filter_func(q, d) for q in subq),
- "$and": lambda d, subq, filter_func: all(filter_func(q, d) for q in subq),
- "$nor": lambda d, subq, filter_func: all(not filter_func(q, d) for q in subq),
- "$not": lambda d, subq, filter_func: (not filter_func(q, d) for q in subq),
-}
-
-
-TYPE_MAP = {
- "double": lambda v: isinstance(v, float),
- "string": lambda v: isinstance(v, str),
- "object": lambda v: isinstance(v, dict),
- "array": lambda v: isinstance(v, list),
- "binData": lambda v: isinstance(v, bytes),
- "undefined": None,
- "objectId": lambda v: isinstance(v, ObjectId),
- "bool": lambda v: isinstance(v, bool),
- "date": lambda v: isinstance(v, datetime),
- "null": None,
- "regex": None,
- "dbPointer": None,
- "javascript": None,
- "symbol": None,
- "javascriptWithScope": None,
- "int": lambda v: (
- isinstance(v, int) and not isinstance(v, bool) and v.bit_length() <= 32
- ),
- "timestamp": None,
- "long": lambda v: (
- isinstance(v, int) and not isinstance(v, bool) and v.bit_length() > 32
- ),
- "decimal": (lambda v: isinstance(v, Decimal128)) if Decimal128 else None,
- "number": lambda v: (
- # pylint: disable-next=isinstance-second-argument-not-valid-type
- isinstance(v, (int, float) + ((Decimal128,) if Decimal128 else ()))
- and not isinstance(v, bool)
- ),
- "minKey": None,
- "maxKey": None,
-}
-
-
-def resolve_key(key, doc):
- return next(iter(iter_key_candidates(key, doc)), None)
-
-
-def resolve_sort_key(key, doc):
- value = resolve_key(key, doc)
- # see http://docs.mongodb.org/manual/reference/method/cursor.sort/#ascending-descending-sort
- if value is None:
- return 1, BsonComparable(None)
-
- # List or tuples are sorted solely by their first value.
- if isinstance(value, (tuple, list)):
- if not value:
- return 0, BsonComparable(None)
- return 1, BsonComparable(value[0])
-
- return 1, BsonComparable(value)
-
-
-class BsonComparable(object):
-    """Wraps a value in a BSON-like object that can be compared to another."""
-
- def __init__(self, obj):
- self.obj = obj
-
- def __lt__(self, other):
- return bson_compare(operator.lt, self.obj, other.obj)
-
-
-_filterer_inst = _Filterer()
-
-
-# Developer note: to avoid a cross-modules dependency (filtering requires aggregation, that requires
-# filtering), the aggregation module needs to register its parse_expression function here.
-def register_parse_expression(parse_expression):
- """Register the parse_expression function from the aggregate module."""
-
- del _Filterer.parse_expression[:]
- _Filterer.parse_expression.append(parse_expression)
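For reference, the matcher deleted above was reachable through `filter_applies`, which evaluates a MongoDB-style query against a plain Python dict. A minimal sketch of that behaviour, assuming the upstream `mongomock` package from PyPI is used in place of this vendored copy:

```python
# Sketch only: exercises the query-matching behaviour of the removed module
# via the upstream mongomock package (assumed installed from PyPI).
from mongomock.filtering import filter_applies

doc = {"name": "alice", "tags": ["a", "b"], "age": 31}

assert filter_applies({"name": "alice"}, doc)                  # plain equality
assert filter_applies({"tags": {"$in": ["b", "z"]}}, doc)      # $in over an array field
assert filter_applies({"age": {"$gte": 30, "$lt": 40}}, doc)   # comparison operators
assert not filter_applies({"missing": {"$exists": True}}, doc)
```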
diff --git a/packages/syft/tests/mongomock/gridfs.py b/packages/syft/tests/mongomock/gridfs.py
deleted file mode 100644
index 13a59999855..00000000000
--- a/packages/syft/tests/mongomock/gridfs.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# stdlib
-from unittest import mock
-
-# relative
-from . import Collection as MongoMockCollection
-from . import Database as MongoMockDatabase
-from ..collection import Cursor as MongoMockCursor
-
-try:
- # third party
- from gridfs.grid_file import GridOut as PyMongoGridOut
- from gridfs.grid_file import GridOutCursor as PyMongoGridOutCursor
- from pymongo.collection import Collection as PyMongoCollection
- from pymongo.database import Database as PyMongoDatabase
-
- _HAVE_PYMONGO = True
-except ImportError:
- _HAVE_PYMONGO = False
-
-
-# This is a copy of GridOutCursor but with a different base. Note that we
-# need both classes as one might want to access both mongomock and real
-# MongoDB.
-class _MongoMockGridOutCursor(MongoMockCursor):
- def __init__(self, collection, *args, **kwargs):
- self.__root_collection = collection
- super(_MongoMockGridOutCursor, self).__init__(collection.files, *args, **kwargs)
-
- def next(self):
- next_file = super(_MongoMockGridOutCursor, self).next()
- return PyMongoGridOut(
- self.__root_collection, file_document=next_file, session=self.session
- )
-
- __next__ = next
-
- def add_option(self, *args, **kwargs):
- raise NotImplementedError()
-
- def remove_option(self, *args, **kwargs):
- raise NotImplementedError()
-
- def _clone_base(self, session):
- return _MongoMockGridOutCursor(self.__root_collection, session=session)
-
-
-def _create_grid_out_cursor(collection, *args, **kwargs):
- if isinstance(collection, MongoMockCollection):
- return _MongoMockGridOutCursor(collection, *args, **kwargs)
- return PyMongoGridOutCursor(collection, *args, **kwargs)
-
-
-def enable_gridfs_integration():
- """This function enables the use of mongomock Database's and Collection's inside gridfs
-
-    The gridfs library uses `isinstance` to make sure the passed elements
-    are valid `pymongo.Database/Collection` objects, so we monkey-patch those types in the
-    gridfs modules (luckily, in the modules where they are used, they are only used with isinstance).
- """
-
- if not _HAVE_PYMONGO:
- raise NotImplementedError("gridfs mocking requires pymongo to work")
-
- mock.patch("gridfs.Database", (PyMongoDatabase, MongoMockDatabase)).start()
- mock.patch(
- "gridfs.grid_file.Collection", (PyMongoCollection, MongoMockCollection)
- ).start()
- mock.patch("gridfs.GridOutCursor", _create_grid_out_cursor).start()
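The gridfs shim removed above worked by monkey-patching gridfs' `isinstance` checks so that `gridfs.GridFS` accepts mocked databases. A rough usage sketch against the upstream package (assumes `pymongo`/`gridfs` are installed; none of this is part of the diff itself):

```python
# Rough sketch: enable gridfs support for mocked databases, then use GridFS as usual.
import gridfs
import mongomock
from mongomock.gridfs import enable_gridfs_integration

enable_gridfs_integration()  # patches gridfs' isinstance checks once per process

client = mongomock.MongoClient()
fs = gridfs.GridFS(client.test_db)
file_id = fs.put(b"hello world", filename="greeting.txt")
assert fs.get(file_id).read() == b"hello world"
```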
diff --git a/packages/syft/tests/mongomock/helpers.py b/packages/syft/tests/mongomock/helpers.py
deleted file mode 100644
index 13f6892cae5..00000000000
--- a/packages/syft/tests/mongomock/helpers.py
+++ /dev/null
@@ -1,474 +0,0 @@
-# stdlib
-from collections import OrderedDict
-from collections import abc
-from datetime import datetime
-from datetime import timedelta
-from datetime import tzinfo
-import re
-import time
-from urllib.parse import unquote_plus
-import warnings
-
-# third party
-from packaging import version
-
-# relative
-from . import InvalidURI
-
-# Get ObjectId from bson if available or import a crafted one. This is not used
-# in this module but is made available for callers of this module.
-try:
- # third party
- from bson import ObjectId # pylint: disable=unused-import
- from bson import Timestamp
- from pymongo import version as pymongo_version
-
- PYMONGO_VERSION = version.parse(pymongo_version)
- HAVE_PYMONGO = True
-except ImportError:
- from .object_id import ObjectId # noqa
-
- Timestamp = None
- # Default Pymongo version if not present.
- PYMONGO_VERSION = version.parse("4.0")
- HAVE_PYMONGO = False
-
-# Cache the RegExp pattern type.
-RE_TYPE = type(re.compile(""))
-_HOST_MATCH = re.compile(r"^([^@]+@)?([^:]+|\[[^\]]+\])(:([^:]+))?$")
-_SIMPLE_HOST_MATCH = re.compile(r"^([^:]+|\[[^\]]+\])(:([^:]+))?$")
-
-try:
- # third party
- from bson.tz_util import utc
-except ImportError:
-
- class _FixedOffset(tzinfo):
- def __init__(self, offset, name):
- self.__offset = timedelta(minutes=offset)
- self.__name = name
-
- def __getinitargs__(self):
- return self.__offset, self.__name
-
- def utcoffset(self, dt):
- return self.__offset
-
- def tzname(self, dt):
- return self.__name
-
- def dst(self, dt):
- return timedelta(0)
-
- utc = _FixedOffset(0, "UTC")
-
-
-ASCENDING = 1
-DESCENDING = -1
-
-
-def utcnow():
- """Simple wrapper for datetime.utcnow
-
- This provides a centralized definition of "now" in the mongomock realm,
- allowing users to transform the value of "now" to the future or the past,
- based on their testing needs. For example:
-
- ```python
- def test_x(self):
- with mock.patch("mongomock.utcnow") as mm_utc:
-            mm_utc.return_value = datetime.utcnow() + timedelta(hours=100)
- # Test some things "100 hours" in the future
- ```
- """
- return datetime.utcnow()
-
-
-def print_deprecation_warning(old_param_name, new_param_name):
- warnings.warn(
- "'%s' has been deprecated to be in line with pymongo implementation, a new parameter '%s' "
- "should be used instead. the old parameter will be kept for backward compatibility "
- "purposes." % (old_param_name, new_param_name),
- DeprecationWarning,
- )
-
-
-def create_index_list(key_or_list, direction=None):
- """Helper to generate a list of (key, direction) pairs.
-
- It takes such a list, or a single key, or a single key and direction.
- """
- if isinstance(key_or_list, str):
- return [(key_or_list, direction or ASCENDING)]
- if not isinstance(key_or_list, (list, tuple, abc.Iterable)):
- raise TypeError(
- "if no direction is specified, " "key_or_list must be an instance of list"
- )
- return key_or_list
-
-
-def gen_index_name(index_list):
- """Generate an index name based on the list of keys with directions."""
-
- return "_".join(["%s_%s" % item for item in index_list])
-
-
-class hashdict(dict):
- """hashable dict implementation, suitable for use as a key into other dicts.
-
- >>> h1 = hashdict({'apples': 1, 'bananas':2})
- >>> h2 = hashdict({'bananas': 3, 'mangoes': 5})
- >>> h1+h2
- hashdict(apples=1, bananas=3, mangoes=5)
- >>> d1 = {}
- >>> d1[h1] = 'salad'
- >>> d1[h1]
- 'salad'
- >>> d1[h2]
- Traceback (most recent call last):
- ...
- KeyError: hashdict(bananas=3, mangoes=5)
-
- based on answers from
- http://stackoverflow.com/questions/1151658/python-hashable-dicts
- """
-
- def __key(self):
- return frozenset(
- (
- k,
- (
- hashdict(v)
- if isinstance(v, dict)
- else tuple(v)
- if isinstance(v, list)
- else v
- ),
- )
- for k, v in self.items()
- )
-
- def __repr__(self):
- return "{0}({1})".format(
- self.__class__.__name__,
- ", ".join(
- "{0}={1}".format(str(i[0]), repr(i[1])) for i in sorted(self.__key())
- ),
- )
-
- def __hash__(self):
- return hash(self.__key())
-
- def __setitem__(self, key, value):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def __delitem__(self, key):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def clear(self):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def pop(self, *args, **kwargs):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def popitem(self, *args, **kwargs):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def setdefault(self, *args, **kwargs):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def update(self, *args, **kwargs):
- raise TypeError(
- "{0} does not support item assignment".format(self.__class__.__name__)
- )
-
- def __add__(self, right):
- result = hashdict(self)
- dict.update(result, right)
- return result
-
-
-def fields_list_to_dict(fields):
- """Takes a list of field names and returns a matching dictionary.
-
- ['a', 'b'] becomes {'a': 1, 'b': 1}
-
- and
-
- ['a.b.c', 'd', 'a.c'] becomes {'a.b.c': 1, 'd': 1, 'a.c': 1}
- """
- as_dict = {}
- for field in fields:
- if not isinstance(field, str):
- raise TypeError(
- "fields must be a list of key names, each an instance of str"
- )
- as_dict[field] = 1
- return as_dict
-
-
-def parse_uri(uri, default_port=27017, warn=False):
- """A simplified version of pymongo.uri_parser.parse_uri.
-
- Returns a dict with:
- - nodelist, a tuple of (host, port)
- - database the name of the database or None if no database is provided in the URI.
-
- An invalid MongoDB connection URI may raise an InvalidURI exception,
- however, the URI is not fully parsed and some invalid URIs may not result
- in an exception.
-
- 'mongodb://host1/database' becomes 'host1', 27017, 'database'
-
- and
-
- 'mongodb://host1' becomes 'host1', 27017, None
- """
- SCHEME = "mongodb://"
-
- if not uri.startswith(SCHEME):
- raise InvalidURI("Invalid URI scheme: URI " "must begin with '%s'" % (SCHEME,))
-
- scheme_free = uri[len(SCHEME) :]
-
- if not scheme_free:
- raise InvalidURI("Must provide at least one hostname or IP.")
-
- dbase = None
-
- # Check for unix domain sockets in the uri
- if ".sock" in scheme_free:
- host_part, _, path_part = scheme_free.rpartition("/")
- if not host_part:
- host_part = path_part
- path_part = ""
- if "/" in host_part:
- raise InvalidURI(
- "Any '/' in a unix domain socket must be" " URL encoded: %s" % host_part
- )
- path_part = unquote_plus(path_part)
- else:
- host_part, _, path_part = scheme_free.partition("/")
-
- if not path_part and "?" in host_part:
- raise InvalidURI("A '/' is required between " "the host list and any options.")
-
- nodelist = []
- if "," in host_part:
- hosts = host_part.split(",")
- else:
- hosts = [host_part]
- for host in hosts:
- match = _HOST_MATCH.match(host)
- if not match:
- raise ValueError(
- "Reserved characters such as ':' must be escaped according RFC "
- "2396. An IPv6 address literal must be enclosed in '[' and ']' "
- "according to RFC 2732."
- )
- host = match.group(2)
- if host.startswith("[") and host.endswith("]"):
- host = host[1:-1]
-
- port = match.group(4)
- if port:
- try:
- port = int(port)
- if port < 0 or port > 65535:
- raise ValueError()
- except ValueError as err:
- raise ValueError(
- "Port must be an integer between 0 and 65535:", port
- ) from err
- else:
- port = default_port
-
- nodelist.append((host, port))
-
- if path_part and path_part[0] != "?":
- dbase, _, _ = path_part.partition("?")
- if "." in dbase:
- dbase, _ = dbase.split(".", 1)
-
- if dbase is not None:
- dbase = unquote_plus(dbase)
-
- return {"nodelist": tuple(nodelist), "database": dbase}
-
-
-def split_hosts(hosts, default_port=27017):
- """Split the entity into a list of tuples of host and port."""
-
- nodelist = []
- for entity in hosts.split(","):
- port = default_port
- if entity.endswith(".sock"):
- port = None
-
- match = _SIMPLE_HOST_MATCH.match(entity)
- if not match:
- raise ValueError(
- "Reserved characters such as ':' must be escaped according RFC "
- "2396. An IPv6 address literal must be enclosed in '[' and ']' "
- "according to RFC 2732."
- )
- host = match.group(1)
- if host.startswith("[") and host.endswith("]"):
- host = host[1:-1]
-
- if match.group(3):
- try:
- port = int(match.group(3))
- if port < 0 or port > 65535:
- raise ValueError()
- except ValueError as err:
- raise ValueError(
- "Port must be an integer between 0 and 65535:", port
- ) from err
-
- nodelist.append((host, port))
-
- return nodelist
-
-
-_LAST_TIMESTAMP_INC = []
-
-
-def get_current_timestamp():
- """Get the current timestamp as a bson Timestamp object."""
- if not Timestamp:
- raise NotImplementedError(
- "timestamp is not supported. Import pymongo to use it."
- )
- now = int(time.time())
- if _LAST_TIMESTAMP_INC and _LAST_TIMESTAMP_INC[0] == now:
- _LAST_TIMESTAMP_INC[1] += 1
- else:
- del _LAST_TIMESTAMP_INC[:]
- _LAST_TIMESTAMP_INC.extend([now, 1])
- return Timestamp(now, _LAST_TIMESTAMP_INC[1])
-
-
-def patch_datetime_awareness_in_document(value):
-    # MongoDB is supposed to store everything as timezone-naive UTC dates,
-    # hence we have to convert incoming datetimes to avoid errors while
-    # mixing tz-aware and naive values.
-    # On top of that, MongoDB date precision is up to the millisecond, whereas Python
-    # datetime uses microseconds, so we must lower the precision to mimic mongo.
- for best_type in (OrderedDict, dict):
- if isinstance(value, best_type):
- return best_type(
- (k, patch_datetime_awareness_in_document(v)) for k, v in value.items()
- )
- if isinstance(value, (tuple, list)):
- return [patch_datetime_awareness_in_document(item) for item in value]
- if isinstance(value, datetime):
- mongo_us = (value.microsecond // 1000) * 1000
- if value.tzinfo:
- return (value - value.utcoffset()).replace(
- tzinfo=None, microsecond=mongo_us
- )
- return value.replace(microsecond=mongo_us)
- if Timestamp and isinstance(value, Timestamp) and not value.time and not value.inc:
- return get_current_timestamp()
- return value
-
-
-def make_datetime_timezone_aware_in_document(value):
-    # MongoClient supports the tz_aware=True parameter to return timezone-aware
-    # datetime objects. Given that dates are stored internally without timezone
-    # information, all returned datetimes have UTC as their timezone.
- if isinstance(value, dict):
- return {
- k: make_datetime_timezone_aware_in_document(v) for k, v in value.items()
- }
- if isinstance(value, (tuple, list)):
- return [make_datetime_timezone_aware_in_document(item) for item in value]
- if isinstance(value, datetime):
- return value.replace(tzinfo=utc)
- return value
-
-
-def get_value_by_dot(doc, key, can_generate_array=False):
- """Get dictionary value using dotted key"""
- result = doc
- key_items = key.split(".")
- for key_index, key_item in enumerate(key_items):
- if isinstance(result, dict):
- result = result[key_item]
-
- elif isinstance(result, (list, tuple)):
- try:
- int_key = int(key_item)
- except ValueError as err:
- if not can_generate_array:
- raise KeyError(key_index) from err
- remaining_key = ".".join(key_items[key_index:])
- return [get_value_by_dot(subdoc, remaining_key) for subdoc in result]
-
- try:
- result = result[int_key]
- except (ValueError, IndexError) as err:
- raise KeyError(key_index) from err
-
- else:
- raise KeyError(key_index)
-
- return result
-
-
-def set_value_by_dot(doc, key, value):
- """Set dictionary value using dotted key"""
- try:
- parent_key, child_key = key.rsplit(".", 1)
- parent = get_value_by_dot(doc, parent_key)
- except ValueError:
- child_key = key
- parent = doc
-
- if isinstance(parent, dict):
- parent[child_key] = value
- elif isinstance(parent, (list, tuple)):
- try:
- parent[int(child_key)] = value
- except (ValueError, IndexError) as err:
- raise KeyError() from err
- else:
- raise KeyError()
-
- return doc
-
-
-def delete_value_by_dot(doc, key):
- """Delete dictionary value using dotted key.
-
- This function assumes that the value exists.
- """
- try:
- parent_key, child_key = key.rsplit(".", 1)
- parent = get_value_by_dot(doc, parent_key)
- except ValueError:
- child_key = key
- parent = doc
-
- del parent[child_key]
-
- return doc
-
-
-def mongodb_to_bool(value):
- """Converts any value to bool the way MongoDB does it"""
-
- return value not in [False, None, 0]
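Among the helpers removed above, `get_value_by_dot` and `set_value_by_dot` implement MongoDB's dot notation over nested dicts and lists. A small sketch of the behaviour they encode (internal API; assumes the upstream `mongomock.helpers` module keeps the same layout):

```python
# Dotted-key access as implemented by the removed helpers (internal API, sketch only).
from mongomock.helpers import get_value_by_dot
from mongomock.helpers import set_value_by_dot

doc = {"profile": {"tags": ["alpha", "beta"]}}

assert get_value_by_dot(doc, "profile.tags.1") == "beta"  # list indices resolve too
set_value_by_dot(doc, "profile.tags.0", "gamma")
assert doc["profile"]["tags"][0] == "gamma"

try:
    get_value_by_dot(doc, "profile.missing")
except KeyError:
    pass  # missing keys surface as KeyError, mirroring the code above
```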
diff --git a/packages/syft/tests/mongomock/mongo_client.py b/packages/syft/tests/mongomock/mongo_client.py
deleted file mode 100644
index 560a7ce0f11..00000000000
--- a/packages/syft/tests/mongomock/mongo_client.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# stdlib
-import itertools
-import warnings
-
-# third party
-from packaging import version
-
-# relative
-from . import ConfigurationError
-from . import codec_options as mongomock_codec_options
-from . import helpers
-from . import read_preferences
-from .database import Database
-from .store import ServerStore
-
-try:
- # third party
- from pymongo import ReadPreference
- from pymongo.uri_parser import parse_uri
- from pymongo.uri_parser import split_hosts
-
- _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY
-except ImportError:
- # relative
- from .helpers import parse_uri
- from .helpers import split_hosts
-
- _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY
-
-
-def _convert_version_to_list(version_str):
- pieces = [int(part) for part in version_str.split(".")]
- return pieces + [0] * (4 - len(pieces))
-
-
-class MongoClient(object):
- HOST = "localhost"
- PORT = 27017
- _CONNECTION_ID = itertools.count()
-
- def __init__(
- self,
- host=None,
- port=None,
- document_class=dict,
- tz_aware=False,
- connect=True,
- _store=None,
- read_preference=None,
- uuidRepresentation=None,
- type_registry=None,
- **kwargs,
- ):
- if host:
- self.host = host[0] if isinstance(host, (list, tuple)) else host
- else:
- self.host = self.HOST
- self.port = port or self.PORT
-
- self._tz_aware = tz_aware
- self._codec_options = mongomock_codec_options.CodecOptions(
- tz_aware=tz_aware,
- uuid_representation=uuidRepresentation,
- type_registry=type_registry,
- )
- self._database_accesses = {}
- self._store = _store or ServerStore()
- self._id = next(self._CONNECTION_ID)
- self._document_class = document_class
- if read_preference is not None:
- read_preferences.ensure_read_preference_type(
- "read_preference", read_preference
- )
- self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY
-
- dbase = None
-
- if "://" in self.host:
- res = parse_uri(self.host, default_port=self.port, warn=True)
- self.host, self.port = res["nodelist"][0]
- dbase = res["database"]
- else:
- self.host, self.port = split_hosts(self.host, default_port=self.port)[0]
-
- self.__default_database_name = dbase
- # relative
- from . import SERVER_VERSION
-
- self._server_version = SERVER_VERSION
-
- def __getitem__(self, db_name):
- return self.get_database(db_name)
-
- def __getattr__(self, attr):
- return self[attr]
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
- def __repr__(self):
- return "mongomock.MongoClient('{0}', {1})".format(self.host, self.port)
-
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- return self.address == other.address
- return NotImplemented
-
- if helpers.PYMONGO_VERSION >= version.parse("3.12"):
-
- def __hash__(self):
- return hash(self.address)
-
- def close(self):
- pass
-
- @property
- def is_mongos(self):
- return True
-
- @property
- def is_primary(self):
- return True
-
- @property
- def address(self):
- return self.host, self.port
-
- @property
- def read_preference(self):
- return self._read_preference
-
- @property
- def codec_options(self):
- return self._codec_options
-
- def server_info(self):
- return {
- "version": self._server_version,
- "sysInfo": "Mock",
- "versionArray": _convert_version_to_list(self._server_version),
- "bits": 64,
- "debug": False,
- "maxBsonObjectSize": 16777216,
- "ok": 1,
- }
-
- if helpers.PYMONGO_VERSION < version.parse("4.0"):
-
- def database_names(self):
- warnings.warn(
- "database_names is deprecated. Use list_database_names instead."
- )
- return self.list_database_names()
-
- def list_database_names(self):
- return self._store.list_created_database_names()
-
- def drop_database(self, name_or_db):
- def drop_collections_for_db(_db):
- db_store = self._store[_db.name]
- for col_name in db_store.list_created_collection_names():
- _db.drop_collection(col_name)
-
- if isinstance(name_or_db, Database):
- db = next(db for db in self._database_accesses.values() if db is name_or_db)
- if db:
- drop_collections_for_db(db)
-
- elif name_or_db in self._store:
- db = self.get_database(name_or_db)
- drop_collections_for_db(db)
-
- def get_database(
- self,
- name=None,
- codec_options=None,
- read_preference=None,
- write_concern=None,
- read_concern=None,
- ):
- if name is None:
- db = self.get_default_database(
- codec_options=codec_options,
- read_preference=read_preference,
- write_concern=write_concern,
- read_concern=read_concern,
- )
- else:
- db = self._database_accesses.get(name)
- if db is None:
- db_store = self._store[name]
- db = self._database_accesses[name] = Database(
- self,
- name,
- read_preference=read_preference or self.read_preference,
- codec_options=codec_options or self._codec_options,
- _store=db_store,
- read_concern=read_concern,
- )
- return db
-
- def get_default_database(self, default=None, **kwargs):
- name = self.__default_database_name
- name = name if name is not None else default
- if name is None:
- raise ConfigurationError("No default database name defined or provided.")
-
- return self.get_database(name=name, **kwargs)
-
- def alive(self):
- """The original MongoConnection.alive method checks the status of the server.
-
- In our case as we mock the actual server, we should always return True.
- """
- return True
-
- def start_session(self, causal_consistency=True, default_transaction_options=None):
- """Start a logical session."""
- raise NotImplementedError("Mongomock does not support sessions yet")
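The `MongoClient` removed above was the vendored entry point for an in-memory, pymongo-compatible client. Equivalent usage against the upstream package looks roughly like this (a sketch, not taken from the test suite):

```python
# In-memory client sketch: no real MongoDB server is required.
import mongomock

client = mongomock.MongoClient()
collection = client["testdb"]["users"]
collection.insert_one({"name": "Pascal", "role": "admin"})

assert collection.count_documents({"role": "admin"}) == 1
assert client.list_database_names() == ["testdb"]
```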
diff --git a/packages/syft/tests/mongomock/not_implemented.py b/packages/syft/tests/mongomock/not_implemented.py
deleted file mode 100644
index 990b89a411e..00000000000
--- a/packages/syft/tests/mongomock/not_implemented.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Module to handle features that are not implemented yet."""
-
-_IGNORED_FEATURES = {
- "array_filters": False,
- "collation": False,
- "let": False,
- "session": False,
-}
-
-
-def _ensure_ignorable_feature(feature):
- if feature not in _IGNORED_FEATURES:
- raise KeyError(
- "%s is not an error that can be ignored: maybe it has been implemented in Mongomock. "
- "Here is the list of features that can be ignored: %s"
- % (feature, _IGNORED_FEATURES.keys())
- )
-
-
-def ignore_feature(feature):
- """Ignore a feature instead of raising a NotImplementedError."""
- _ensure_ignorable_feature(feature)
- _IGNORED_FEATURES[feature] = True
-
-
-def warn_on_feature(feature):
-    """Raise a NotImplementedError the next time the feature is used."""
- _ensure_ignorable_feature(feature)
- _IGNORED_FEATURES[feature] = False
-
-
-def raise_for_feature(feature, reason):
- _ensure_ignorable_feature(feature)
- if _IGNORED_FEATURES[feature]:
- return False
- raise NotImplementedError(reason)
diff --git a/packages/syft/tests/mongomock/object_id.py b/packages/syft/tests/mongomock/object_id.py
deleted file mode 100644
index 281e8b02663..00000000000
--- a/packages/syft/tests/mongomock/object_id.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# stdlib
-import uuid
-
-
-class ObjectId(object):
- def __init__(self, id=None):
- super(ObjectId, self).__init__()
- if id is None:
- self._id = uuid.uuid1()
- else:
- self._id = uuid.UUID(id)
-
- def __eq__(self, other):
- return isinstance(other, ObjectId) and other._id == self._id
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- return hash(self._id)
-
- def __repr__(self):
- return "ObjectId({0})".format(self._id)
-
- def __str__(self):
- return str(self._id)
diff --git a/packages/syft/tests/mongomock/patch.py b/packages/syft/tests/mongomock/patch.py
deleted file mode 100644
index 8db36497e75..00000000000
--- a/packages/syft/tests/mongomock/patch.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# stdlib
-import time
-
-# relative
-from .mongo_client import MongoClient
-
-try:
- # stdlib
- from unittest import mock
-
- _IMPORT_MOCK_ERROR = None
-except ImportError:
- try:
- # third party
- import mock
-
- _IMPORT_MOCK_ERROR = None
- except ImportError as error:
- _IMPORT_MOCK_ERROR = error
-
-try:
- # third party
- import pymongo
- from pymongo.uri_parser import parse_uri
- from pymongo.uri_parser import split_hosts
-
- _IMPORT_PYMONGO_ERROR = None
-except ImportError as error:
- # relative
- from .helpers import parse_uri
- from .helpers import split_hosts
-
- _IMPORT_PYMONGO_ERROR = error
-
-
-def _parse_any_host(host, default_port=27017):
- if isinstance(host, tuple):
- return _parse_any_host(host[0], host[1])
- if "://" in host:
- return parse_uri(host, warn=True)["nodelist"]
- return split_hosts(host, default_port=default_port)
-
-
-def patch(servers="localhost", on_new="error"):
- """Patch pymongo.MongoClient.
-
- This will patch the class MongoClient and use mongomock to mock MongoDB
-    servers. It keeps a consistent state of servers across multiple clients so
- you can do:
-
- ```
- client = pymongo.MongoClient(host='localhost', port=27017)
- client.db.coll.insert_one({'name': 'Pascal'})
-
- other_client = pymongo.MongoClient('mongodb://localhost:27017')
- client.db.coll.find_one()
- ```
-
- The data is persisted as long as the patch lives.
-
- Args:
- on_new: Behavior when accessing a new server (not in servers):
- 'create': mock a new empty server, accept any client connection.
- 'error': raise a ValueError immediately when trying to access.
- 'timeout': behave as pymongo when a server does not exist, raise an
- error after a timeout.
- 'pymongo': use an actual pymongo client.
-        servers: a list of servers that are available.
- """
-
- if _IMPORT_MOCK_ERROR:
- raise _IMPORT_MOCK_ERROR # pylint: disable=raising-bad-type
-
- if _IMPORT_PYMONGO_ERROR:
- PyMongoClient = None
- else:
- PyMongoClient = pymongo.MongoClient
-
- persisted_clients = {}
- parsed_servers = set()
- for server in servers if isinstance(servers, (list, tuple)) else [servers]:
- parsed_servers.update(_parse_any_host(server))
-
- def _create_persistent_client(*args, **kwargs):
- if _IMPORT_PYMONGO_ERROR:
- raise _IMPORT_PYMONGO_ERROR # pylint: disable=raising-bad-type
-
- client = MongoClient(*args, **kwargs)
-
- try:
- persisted_client = persisted_clients[client.address]
- client._store = persisted_client._store
- return client
- except KeyError:
- pass
-
- if client.address in parsed_servers or on_new == "create":
- persisted_clients[client.address] = client
- return client
-
- if on_new == "timeout":
- # TODO(pcorpet): Only wait when trying to access the server's data.
- time.sleep(kwargs.get("serverSelectionTimeoutMS", 30000))
- raise pymongo.errors.ServerSelectionTimeoutError(
- "%s:%d: [Errno 111] Connection refused" % client.address
- )
-
- if on_new == "pymongo":
- return PyMongoClient(*args, **kwargs)
-
- raise ValueError(
- "MongoDB server %s:%d does not exist.\n" % client.address
- + "%s" % parsed_servers
- )
-
- class _PersistentClient:
- def __new__(cls, *args, **kwargs):
- return _create_persistent_client(*args, **kwargs)
-
- return mock.patch("pymongo.MongoClient", _PersistentClient)
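`patch()` above mirrors mongomock's documented way of swapping `pymongo.MongoClient` for in-memory servers during a test. A short sketch of that documented usage (assumes upstream mongomock and pymongo are installed):

```python
# Sketch of mongomock.patch as a test decorator; data persists per mocked server
# for the lifetime of the patch.
import mongomock
import pymongo


@mongomock.patch(servers=(("server.example.com", 27017),))
def test_insert_and_read():
    client = pymongo.MongoClient("server.example.com")
    client.db.coll.insert_one({"name": "Pascal"})
    assert client.db.coll.find_one()["name"] == "Pascal"
```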
diff --git a/packages/syft/tests/mongomock/read_concern.py b/packages/syft/tests/mongomock/read_concern.py
deleted file mode 100644
index 229e0f78bb4..00000000000
--- a/packages/syft/tests/mongomock/read_concern.py
+++ /dev/null
@@ -1,21 +0,0 @@
-class ReadConcern(object):
- def __init__(self, level=None):
- self._document = {}
-
- if level is not None:
- self._document["level"] = level
-
- @property
- def level(self):
- return self._document.get("level")
-
- @property
- def ok_for_legacy(self):
- return True
-
- @property
- def document(self):
- return self._document.copy()
-
- def __eq__(self, other):
- return other.document == self.document
diff --git a/packages/syft/tests/mongomock/read_preferences.py b/packages/syft/tests/mongomock/read_preferences.py
deleted file mode 100644
index a9349e6c576..00000000000
--- a/packages/syft/tests/mongomock/read_preferences.py
+++ /dev/null
@@ -1,42 +0,0 @@
-class _Primary(object):
- @property
- def mongos_mode(self):
- return "primary"
-
- @property
- def mode(self):
- return 0
-
- @property
- def name(self):
- return "Primary"
-
- @property
- def document(self):
- return {"mode": "primary"}
-
- @property
- def tag_sets(self):
- return [{}]
-
- @property
- def max_staleness(self):
- return -1
-
- @property
- def min_wire_version(self):
- return 0
-
-
-def ensure_read_preference_type(key, value):
- """Raise a TypeError if the value is not a type compatible for ReadPreference."""
- for attr in ("document", "mode", "mongos_mode", "max_staleness"):
- if not hasattr(value, attr):
- raise TypeError(
- "{} must be an instance of {}".format(
- key, "pymongo.read_preference.ReadPreference"
- )
- )
-
-
-PRIMARY = _Primary()
diff --git a/packages/syft/tests/mongomock/results.py b/packages/syft/tests/mongomock/results.py
deleted file mode 100644
index 07633a6e82e..00000000000
--- a/packages/syft/tests/mongomock/results.py
+++ /dev/null
@@ -1,117 +0,0 @@
-try:
- # third party
- from pymongo.results import BulkWriteResult
- from pymongo.results import DeleteResult
- from pymongo.results import InsertManyResult
- from pymongo.results import InsertOneResult
- from pymongo.results import UpdateResult
-except ImportError:
-
- class _WriteResult(object):
- def __init__(self, acknowledged=True):
- self.__acknowledged = acknowledged
-
- @property
- def acknowledged(self):
- return self.__acknowledged
-
- class InsertOneResult(_WriteResult):
- __slots__ = ("__inserted_id", "__acknowledged")
-
- def __init__(self, inserted_id, acknowledged=True):
- self.__inserted_id = inserted_id
- super(InsertOneResult, self).__init__(acknowledged)
-
- @property
- def inserted_id(self):
- return self.__inserted_id
-
- class InsertManyResult(_WriteResult):
- __slots__ = ("__inserted_ids", "__acknowledged")
-
- def __init__(self, inserted_ids, acknowledged=True):
- self.__inserted_ids = inserted_ids
- super(InsertManyResult, self).__init__(acknowledged)
-
- @property
- def inserted_ids(self):
- return self.__inserted_ids
-
- class UpdateResult(_WriteResult):
- __slots__ = ("__raw_result", "__acknowledged")
-
- def __init__(self, raw_result, acknowledged=True):
- self.__raw_result = raw_result
- super(UpdateResult, self).__init__(acknowledged)
-
- @property
- def raw_result(self):
- return self.__raw_result
-
- @property
- def matched_count(self):
- if self.upserted_id is not None:
- return 0
- return self.__raw_result.get("n", 0)
-
- @property
- def modified_count(self):
- return self.__raw_result.get("nModified")
-
- @property
- def upserted_id(self):
- return self.__raw_result.get("upserted")
-
- class DeleteResult(_WriteResult):
- __slots__ = ("__raw_result", "__acknowledged")
-
- def __init__(self, raw_result, acknowledged=True):
- self.__raw_result = raw_result
- super(DeleteResult, self).__init__(acknowledged)
-
- @property
- def raw_result(self):
- return self.__raw_result
-
- @property
- def deleted_count(self):
- return self.__raw_result.get("n", 0)
-
- class BulkWriteResult(_WriteResult):
- __slots__ = ("__bulk_api_result", "__acknowledged")
-
- def __init__(self, bulk_api_result, acknowledged):
- self.__bulk_api_result = bulk_api_result
- super(BulkWriteResult, self).__init__(acknowledged)
-
- @property
- def bulk_api_result(self):
- return self.__bulk_api_result
-
- @property
- def inserted_count(self):
- return self.__bulk_api_result.get("nInserted")
-
- @property
- def matched_count(self):
- return self.__bulk_api_result.get("nMatched")
-
- @property
- def modified_count(self):
- return self.__bulk_api_result.get("nModified")
-
- @property
- def deleted_count(self):
- return self.__bulk_api_result.get("nRemoved")
-
- @property
- def upserted_count(self):
- return self.__bulk_api_result.get("nUpserted")
-
- @property
- def upserted_ids(self):
- if self.__bulk_api_result:
- return dict(
- (upsert["index"], upsert["_id"])
- for upsert in self.bulk_api_result["upserted"]
- )
diff --git a/packages/syft/tests/mongomock/store.py b/packages/syft/tests/mongomock/store.py
deleted file mode 100644
index 9cef7206329..00000000000
--- a/packages/syft/tests/mongomock/store.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# stdlib
-import collections
-import datetime
-import functools
-
-# relative
-from .helpers import utcnow
-from .thread import RWLock
-
-
-class ServerStore(object):
- """Object holding the data for a whole server (many databases)."""
-
- def __init__(self):
- self._databases = {}
-
- def __getitem__(self, db_name):
- try:
- return self._databases[db_name]
- except KeyError:
- db = self._databases[db_name] = DatabaseStore()
- return db
-
- def __contains__(self, db_name):
- return self[db_name].is_created
-
- def list_created_database_names(self):
- return [name for name, db in self._databases.items() if db.is_created]
-
-
-class DatabaseStore(object):
- """Object holding the data for a database (many collections)."""
-
- def __init__(self):
- self._collections = {}
-
- def __getitem__(self, col_name):
- try:
- return self._collections[col_name]
- except KeyError:
- col = self._collections[col_name] = CollectionStore(col_name)
- return col
-
- def __contains__(self, col_name):
- return self[col_name].is_created
-
- def list_created_collection_names(self):
- return [name for name, col in self._collections.items() if col.is_created]
-
- def create_collection(self, name):
- col = self[name]
- col.create()
- return col
-
- def rename(self, name, new_name):
- col = self._collections.pop(name, CollectionStore(new_name))
- col.name = new_name
- self._collections[new_name] = col
-
- @property
- def is_created(self):
- return any(col.is_created for col in self._collections.values())
-
-
-class CollectionStore(object):
- """Object holding the data for a collection."""
-
- def __init__(self, name):
- self._documents = collections.OrderedDict()
- self.indexes = {}
- self._is_force_created = False
- self.name = name
- self._ttl_indexes = {}
-
- # 694 - Lock for safely iterating and mutating OrderedDicts
- self._rwlock = RWLock()
-
- def create(self):
- self._is_force_created = True
-
- @property
- def is_created(self):
- return self._documents or self.indexes or self._is_force_created
-
- def drop(self):
- self._documents = collections.OrderedDict()
- self.indexes = {}
- self._ttl_indexes = {}
- self._is_force_created = False
-
- def create_index(self, index_name, index_dict):
- self.indexes[index_name] = index_dict
- if index_dict.get("expireAfterSeconds") is not None:
- self._ttl_indexes[index_name] = index_dict
-
- def drop_index(self, index_name):
- self._remove_expired_documents()
-
- # The main index object should raise a KeyError, but the
- # TTL indexes have no meaning to the outside.
- del self.indexes[index_name]
- self._ttl_indexes.pop(index_name, None)
-
- @property
- def is_empty(self):
- self._remove_expired_documents()
- return not self._documents
-
- def __contains__(self, key):
- self._remove_expired_documents()
- with self._rwlock.reader():
- return key in self._documents
-
- def __getitem__(self, key):
- self._remove_expired_documents()
- with self._rwlock.reader():
- return self._documents[key]
-
- def __setitem__(self, key, val):
- with self._rwlock.writer():
- self._documents[key] = val
-
- def __delitem__(self, key):
- with self._rwlock.writer():
- del self._documents[key]
-
- def __len__(self):
- self._remove_expired_documents()
- with self._rwlock.reader():
- return len(self._documents)
-
- @property
- def documents(self):
- self._remove_expired_documents()
- with self._rwlock.reader():
- for doc in self._documents.values():
- yield doc
-
- def _remove_expired_documents(self):
- for index in self._ttl_indexes.values():
- self._expire_documents(index)
-
- def _expire_documents(self, index):
- # TODO(juannyg): use a caching mechanism to avoid re-expiring the documents if
- # we just did and no document was added / updated
-
- # Ignore non-integer values
- try:
- expiry = int(index["expireAfterSeconds"])
- except ValueError:
- return
-
-        # Ignore compound keys
- if len(index["key"]) > 1:
- return
-
- # "key" structure = list of (field name, direction) tuples
- ttl_field_name = next(iter(index["key"]))[0]
- ttl_now = utcnow()
-
- with self._rwlock.reader():
- expired_ids = [
- doc["_id"]
- for doc in self._documents.values()
- if self._value_meets_expiry(doc.get(ttl_field_name), expiry, ttl_now)
- ]
-
- for exp_id in expired_ids:
- del self[exp_id]
-
- def _value_meets_expiry(self, val, expiry, ttl_now):
- val_to_compare = _get_min_datetime_from_value(val)
- try:
- return (ttl_now - val_to_compare).total_seconds() >= expiry
- except TypeError:
- return False
-
-
-def _get_min_datetime_from_value(val):
- if not val:
- return datetime.datetime.max
- if isinstance(val, list):
- return functools.reduce(_min_dt, [datetime.datetime.max] + val)
- return val
-
-
-def _min_dt(dt1, dt2):
- try:
- return dt1 if dt1 < dt2 else dt2
- except TypeError:
- return dt1
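The `_ttl_indexes` / `_expire_documents` machinery removed above is what made `expireAfterSeconds` indexes work in the mock: expired documents are purged lazily whenever the store is read. A hypothetical sketch of that behaviour through the public collection API (assumes upstream mongomock accepts `expireAfterSeconds`, as this vendored copy did):

```python
# Hypothetical TTL sketch: documents older than the index's expireAfterSeconds
# disappear on the next read of the collection.
import datetime

import mongomock

client = mongomock.MongoClient()
events = client.db.events
events.create_index([("created_at", 1)], expireAfterSeconds=3600)

now = datetime.datetime.utcnow()
events.insert_one({"created_at": now - datetime.timedelta(days=1)})  # already expired
events.insert_one({"created_at": now})                               # still fresh

assert events.count_documents({}) == 1
```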
diff --git a/packages/syft/tests/mongomock/thread.py b/packages/syft/tests/mongomock/thread.py
deleted file mode 100644
index ff673e44309..00000000000
--- a/packages/syft/tests/mongomock/thread.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# stdlib
-from contextlib import contextmanager
-import threading
-
-
-class RWLock:
- """Lock enabling multiple readers but only 1 exclusive writer
-
- Source: https://cutt.ly/Ij70qaq
- """
-
- def __init__(self):
- self._read_switch = _LightSwitch()
- self._write_switch = _LightSwitch()
- self._no_readers = threading.Lock()
- self._no_writers = threading.Lock()
- self._readers_queue = threading.RLock()
-
- @contextmanager
- def reader(self):
- self._reader_acquire()
- try:
- yield
- except Exception: # pylint: disable=W0706
- raise
- finally:
- self._reader_release()
-
- @contextmanager
- def writer(self):
- self._writer_acquire()
- try:
- yield
- except Exception: # pylint: disable=W0706
- raise
- finally:
- self._writer_release()
-
- def _reader_acquire(self):
- """Readers should block whenever a writer has acquired"""
- self._readers_queue.acquire()
- self._no_readers.acquire()
- self._read_switch.acquire(self._no_writers)
- self._no_readers.release()
- self._readers_queue.release()
-
- def _reader_release(self):
- self._read_switch.release(self._no_writers)
-
- def _writer_acquire(self):
- """Acquire the writer lock.
-
- Only the first writer will lock the readtry and then
- all subsequent writers can simply use the resource as
- it gets freed by the previous writer. The very last writer must
- release the readtry semaphore, thus opening the gate for readers
- to try reading.
-
- No reader can engage in the entry section if the readtry semaphore
- has been set by a writer previously
- """
- self._write_switch.acquire(self._no_readers)
- self._no_writers.acquire()
-
- def _writer_release(self):
- self._no_writers.release()
- self._write_switch.release(self._no_readers)
-
-
-class _LightSwitch:
- """An auxiliary "light switch"-like object
-
- The first thread turns on the "switch", the last one turns it off.
-
- Source: https://cutt.ly/Ij70qaq
- """
-
- def __init__(self):
- self._counter = 0
- self._mutex = threading.RLock()
-
- def acquire(self, lock):
- self._mutex.acquire()
- self._counter += 1
- if self._counter == 1:
- lock.acquire()
- self._mutex.release()
-
- def release(self, lock):
- self._mutex.acquire()
- self._counter -= 1
- if self._counter == 0:
- lock.release()
- self._mutex.release()
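`RWLock` above is a plain readers-writer lock guarding the collection's `OrderedDict`: any number of readers may hold it concurrently, while a writer gets exclusive access. Its two context managers are used roughly like this (sketch based on the code above; assumes the upstream `mongomock.thread` module keeps the same layout):

```python
# Readers-writer lock sketch: concurrent readers, exclusive writers.
import threading

from mongomock.thread import RWLock

lock = RWLock()
shared = {}


def write(key, value):
    with lock.writer():  # exclusive access while mutating
        shared[key] = value


def read_all():
    with lock.reader():  # many readers may hold this simultaneously
        return dict(shared)


write("a", 1)
threads = [threading.Thread(target=read_all) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```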
diff --git a/packages/syft/tests/mongomock/write_concern.py b/packages/syft/tests/mongomock/write_concern.py
deleted file mode 100644
index 93760445647..00000000000
--- a/packages/syft/tests/mongomock/write_concern.py
+++ /dev/null
@@ -1,45 +0,0 @@
-def _with_default_values(document):
- if "w" in document:
- return document
- return dict(document, w=1)
-
-
-class WriteConcern(object):
- def __init__(self, w=None, wtimeout=None, j=None, fsync=None):
- self._document = {}
- if w is not None:
- self._document["w"] = w
- if wtimeout is not None:
- self._document["wtimeout"] = wtimeout
- if j is not None:
- self._document["j"] = j
- if fsync is not None:
- self._document["fsync"] = fsync
-
- def __eq__(self, other):
- try:
- return _with_default_values(other.document) == _with_default_values(
- self.document
- )
- except AttributeError:
- return NotImplemented
-
- def __ne__(self, other):
- try:
- return _with_default_values(other.document) != _with_default_values(
- self.document
- )
- except AttributeError:
- return NotImplemented
-
- @property
- def acknowledged(self):
- return True
-
- @property
- def document(self):
- return self._document.copy()
-
- @property
- def is_server_default(self):
- return not self._document
diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py
index 39d6b1871bc..851a83cb7c2 100644
--- a/packages/syft/tests/syft/action_test.py
+++ b/packages/syft/tests/syft/action_test.py
@@ -23,20 +23,18 @@
def test_actionobject_method(worker):
root_datasite_client = worker.root_client
assert root_datasite_client.settings.enable_eager_execution(enable=True)
- action_store = worker.services.action.store
+ action_store = worker.services.action.stash
obj = ActionObject.from_obj("abc")
pointer = obj.send(root_datasite_client)
- assert len(action_store.data) == 1
+ assert len(action_store._data) == 1
res = pointer.capitalize()
- assert len(action_store.data) == 2
+ assert len(action_store._data) == 2
assert res[0] == "A"
-@pytest.mark.parametrize("delete_original_admin", [False, True])
def test_new_admin_has_action_object_permission(
worker: Worker,
faker: Faker,
- delete_original_admin: bool,
) -> None:
root_client = worker.root_client
@@ -60,10 +58,6 @@ def test_new_admin_has_action_object_permission(
root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN)
- if delete_original_admin:
- res = root_client.api.services.user.delete(root_client.account.id)
- assert not isinstance(res, SyftError)
-
assert admin.api.services.action.get(obj.id) == obj
@@ -75,7 +69,7 @@ def test_lib_function_action(worker):
assert isinstance(res, ActionObject)
assert all(res == np.array([0, 0, 0]))
- assert len(worker.services.action.store.data) > 0
+ assert len(worker.services.action.stash._data) > 0
def test_call_lib_function_action2(worker):
@@ -90,7 +84,7 @@ def test_lib_class_init_action(worker):
assert isinstance(res, ActionObject)
assert res == np.float32(4.0)
- assert len(worker.services.action.store.data) > 0
+ assert len(worker.services.action.stash._data) > 0
def test_call_lib_wo_permission(worker):
diff --git a/packages/syft/tests/syft/api_test.py b/packages/syft/tests/syft/api_test.py
index 6c511b45d48..ca2f61ac147 100644
--- a/packages/syft/tests/syft/api_test.py
+++ b/packages/syft/tests/syft/api_test.py
@@ -45,7 +45,7 @@ def test_api_cache_invalidation_login(root_verify_key, worker):
name="q", email="a@b.org", password="aaa", password_verify="aaa"
)
guest_client = guest_client.login(email="a@b.org", password="aaa")
- user_id = worker.document_store.partitions["User"].all(root_verify_key).value[-1].id
+ user_id = worker.root_client.users[-1].id
def get_role(verify_key):
users = worker.services.user.stash.get_all(root_verify_key).ok()
diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py
index 8b8613498fb..47e33f7926d 100644
--- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py
+++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py
@@ -1,6 +1,5 @@
# stdlib
import io
-import random
# third party
import numpy as np
@@ -42,10 +41,7 @@ def test_blob_storage_allocate(authed_context, blob_storage):
assert isinstance(blob_deposit, BlobDeposit)
-def test_blob_storage_write():
- random.seed()
- name = "".join(str(random.randint(0, 9)) for i in range(8))
- worker = sy.Worker.named(name=name)
+def test_blob_storage_write(worker):
blob_storage = worker.services.blob_storage
authed_context = AuthedServiceContext(
server=worker, credentials=worker.signing_key.verify_key
@@ -60,10 +56,7 @@ def test_blob_storage_write():
worker.cleanup()
-def test_blob_storage_write_syft_object():
- random.seed()
- name = "".join(str(random.randint(0, 9)) for i in range(8))
- worker = sy.Worker.named(name=name)
+def test_blob_storage_write_syft_object(worker):
blob_storage = worker.services.blob_storage
authed_context = AuthedServiceContext(
server=worker, credentials=worker.signing_key.verify_key
@@ -78,10 +71,7 @@ def test_blob_storage_write_syft_object():
worker.cleanup()
-def test_blob_storage_read():
- random.seed()
- name = "".join(str(random.randint(0, 9)) for i in range(8))
- worker = sy.Worker.named(name=name)
+def test_blob_storage_read(worker):
blob_storage = worker.services.blob_storage
authed_context = AuthedServiceContext(
server=worker, credentials=worker.signing_key.verify_key
diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py
index d177aaa508e..bfec6e00895 100644
--- a/packages/syft/tests/syft/dataset/dataset_stash_test.py
+++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py
@@ -1,55 +1,16 @@
# third party
import pytest
-from typeguard import TypeCheckError
# syft absolute
from syft.service.dataset.dataset import Dataset
-from syft.service.dataset.dataset_stash import ActionIDsPartitionKey
-from syft.service.dataset.dataset_stash import NamePartitionKey
-from syft.store.document_store import QueryKey
+from syft.service.dataset.dataset_stash import DatasetStash
from syft.store.document_store_errors import NotFoundException
from syft.types.uid import UID
-def test_dataset_namepartitionkey() -> None:
- mock_obj = "dummy_name_key"
-
- assert NamePartitionKey.key == "name"
- assert NamePartitionKey.type_ == str
-
- name_partition_key = NamePartitionKey.with_obj(obj=mock_obj)
-
- assert isinstance(name_partition_key, QueryKey)
- assert name_partition_key.key == "name"
- assert name_partition_key.type_ == str
- assert name_partition_key.value == mock_obj
-
- with pytest.raises(AttributeError):
- NamePartitionKey.with_obj(obj=[UID()])
-
-
-def test_dataset_actionidpartitionkey() -> None:
- mock_obj = [UID() for _ in range(3)]
-
- assert ActionIDsPartitionKey.key == "action_ids"
- assert ActionIDsPartitionKey.type_ == list[UID]
-
- action_ids_partition_key = ActionIDsPartitionKey.with_obj(obj=mock_obj)
-
- assert isinstance(action_ids_partition_key, QueryKey)
- assert action_ids_partition_key.key == "action_ids"
- assert action_ids_partition_key.type_ == list[UID]
- assert action_ids_partition_key.value == mock_obj
-
- with pytest.raises(AttributeError):
- ActionIDsPartitionKey.with_obj(obj="dummy_str")
-
- # Not sure what Exception should be raised here, Type or Attibute
- with pytest.raises(TypeCheckError):
- ActionIDsPartitionKey.with_obj(obj=["first_str", "second_str"])
-
-
-def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset) -> None:
+def test_dataset_get_by_name(
+ root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset: Dataset
+) -> None:
# retrieving existing dataset
result = mock_dataset_stash.get_by_name(root_verify_key, mock_dataset.name)
assert result.is_ok()
@@ -63,7 +24,9 @@ def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset)
assert type(result.err()) is NotFoundException
-def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dataset):
+def test_dataset_search_action_ids(
+ root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset
+):
action_id = mock_dataset.assets[0].action_id
result = mock_dataset_stash.search_action_ids(root_verify_key, uid=action_id)
@@ -72,12 +35,6 @@ def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dat
assert isinstance(result.ok()[0], Dataset)
assert result.ok()[0].id == mock_dataset.id
- # retrieving dataset by list of action_ids
- result = mock_dataset_stash.search_action_ids(root_verify_key, uid=[action_id])
- assert result.is_ok()
- assert isinstance(result.ok()[0], Dataset)
- assert result.ok()[0].id == mock_dataset.id
-
# retrieving dataset by non-existing action_id
other_action_id = UID()
result = mock_dataset_stash.search_action_ids(root_verify_key, uid=other_action_id)
@@ -86,5 +43,5 @@ def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dat
# passing random object
random_obj = object()
- with pytest.raises(AttributeError):
+ with pytest.raises(ValueError):
result = mock_dataset_stash.search_action_ids(root_verify_key, uid=random_obj)
diff --git a/packages/syft/tests/syft/dataset/fixtures.py b/packages/syft/tests/syft/dataset/fixtures.py
index 9c062e756bc..bcb26bff262 100644
--- a/packages/syft/tests/syft/dataset/fixtures.py
+++ b/packages/syft/tests/syft/dataset/fixtures.py
@@ -60,7 +60,9 @@ def mock_asset(worker, root_datasite_client) -> Asset:
@pytest.fixture
-def mock_dataset(root_verify_key, mock_dataset_stash, mock_asset) -> Dataset:
+def mock_dataset(
+ root_verify_key, mock_dataset_stash: DatasetStash, mock_asset
+) -> Dataset:
uploader = Contributor(
role=str(Roles.UPLOADER),
name="test",
@@ -70,7 +72,7 @@ def mock_dataset(root_verify_key, mock_dataset_stash, mock_asset) -> Dataset:
id=UID(), name="test_dataset", uploader=uploader, contributors=[uploader]
)
mock_dataset.asset_list.append(mock_asset)
- result = mock_dataset_stash.partition.set(root_verify_key, mock_dataset)
+ result = mock_dataset_stash.set(root_verify_key, mock_dataset)
mock_dataset = result.ok()
yield mock_dataset
diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py
index 18fd4b85394..243a18130a2 100644
--- a/packages/syft/tests/syft/eager_test.py
+++ b/packages/syft/tests/syft/eager_test.py
@@ -174,7 +174,7 @@ def test_setattribute(worker, guest_client):
obj_pointer.dtype = np.int32
# local object is updated
- assert obj_pointer.id.id in worker.action_store.data
+ assert obj_pointer.id.id in worker.action_store._data
assert obj_pointer.id != original_id
res = root_datasite_client.api.services.action.get(obj_pointer.id)
@@ -206,7 +206,7 @@ def test_getattribute(worker, guest_client):
size_pointer = obj_pointer.size
# check result
- assert size_pointer.id.id in worker.action_store.data
+ assert size_pointer.id.id in worker.action_store._data
assert root_datasite_client.api.services.action.get(size_pointer.id) == 6
@@ -226,7 +226,7 @@ def test_eager_method(worker, guest_client):
flat_pointer = obj_pointer.flatten()
- assert flat_pointer.id.id in worker.action_store.data
+ assert flat_pointer.id.id in worker.action_store._data
# check result
assert all(
root_datasite_client.api.services.action.get(flat_pointer.id)
@@ -250,7 +250,7 @@ def test_eager_dunder_method(worker, guest_client):
first_row_pointer = obj_pointer[0]
- assert first_row_pointer.id.id in worker.action_store.data
+ assert first_row_pointer.id.id in worker.action_store._data
# check result
assert all(
root_datasite_client.api.services.action.get(first_row_pointer.id)
diff --git a/packages/syft/tests/syft/migrations/data_migration_test.py b/packages/syft/tests/syft/migrations/data_migration_test.py
index 708c56ac75a..a5203e2a0f8 100644
--- a/packages/syft/tests/syft/migrations/data_migration_test.py
+++ b/packages/syft/tests/syft/migrations/data_migration_test.py
@@ -115,7 +115,7 @@ def test_get_migration_data(worker, tmp_path):
@contextmanager
def named_worker_context(name):
# required to launch worker with same name twice within the same test + ensure cleanup
- worker = sy.Worker.named(name=name)
+ worker = sy.Worker.named(name=name, db_url="sqlite://")
try:
yield worker
finally:
diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py
index 059c729d921..64c670d5ea3 100644
--- a/packages/syft/tests/syft/migrations/protocol_communication_test.py
+++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py
@@ -20,6 +20,7 @@
from syft.service.service import ServiceConfigRegistry
from syft.service.service import service_method
from syft.service.user.user_roles import GUEST_ROLE_LEVEL
+from syft.store.db.db import DBManager
from syft.store.document_store import DocumentStore
from syft.store.document_store import NewBaseStash
from syft.store.document_store import PartitionSettings
@@ -85,7 +86,7 @@ class SyftMockObjectStash(NewBaseStash):
object_type=syft_object,
)
- def __init__(self, store: DocumentStore) -> None:
+ def __init__(self, store: DBManager) -> None:
super().__init__(store=store)
return SyftMockObjectStash
@@ -103,8 +104,7 @@ class SyftMockObjectService(AbstractService):
stash: stash_klass # type: ignore
__module__: str = "syft.test"
- def __init__(self, store: DocumentStore) -> None:
- self.store = store
+ def __init__(self, store: DBManager) -> None:
self.stash = stash_klass(store=store)
@service_method(
diff --git a/packages/syft/tests/syft/network_test.py b/packages/syft/tests/syft/network_test.py
new file mode 100644
index 00000000000..3bb4b5e84e3
--- /dev/null
+++ b/packages/syft/tests/syft/network_test.py
@@ -0,0 +1,31 @@
+# syft absolute
+from syft.abstract_server import ServerType
+from syft.server.credentials import SyftSigningKey
+from syft.service.network.network_service import NetworkStash
+from syft.service.network.server_peer import ServerPeer
+from syft.service.network.server_peer import ServerPeerUpdate
+from syft.types.uid import UID
+
+
+def test_add_route() -> None:
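+    # Create a ServerPeer, store it in a fresh NetworkStash, then rename it via ServerPeerUpdate and check the update took effect.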
+ uid = UID()
+ peer = ServerPeer(
+ id=uid,
+ name="test",
+ verify_key=SyftSigningKey.generate().verify_key,
+ server_type=ServerType.DATASITE,
+ admin_email="info@openmined.org",
+ )
+ network_stash = NetworkStash.random()
+
+ network_stash.set(
+ credentials=network_stash.db.root_verify_key,
+ obj=peer,
+ ).unwrap()
+ peer_update = ServerPeerUpdate(id=uid, name="new name")
+ peer = network_stash.update(
+ credentials=network_stash.db.root_verify_key,
+ obj=peer_update,
+ ).unwrap()
+
+ assert peer.name == "new name"
diff --git a/packages/syft/tests/syft/notifications/notification_service_test.py b/packages/syft/tests/syft/notifications/notification_service_test.py
index f48e77ab97d..a8319d32a80 100644
--- a/packages/syft/tests/syft/notifications/notification_service_test.py
+++ b/packages/syft/tests/syft/notifications/notification_service_test.py
@@ -144,20 +144,12 @@ def test_get_all_success(
NotificationStatus.UNREAD,
)
- @as_result(StashException)
- def mock_get_all_inbox_for_verify_key(*args, **kwargs) -> list[Notification]:
- return [expected_message]
-
- monkeypatch.setattr(
- notification_service.stash,
- "get_all_inbox_for_verify_key",
- mock_get_all_inbox_for_verify_key,
- )
-
response = test_notification_service.get_all(authed_context)
assert len(response) == 1
assert isinstance(response[0], Notification)
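+    # Clear the server-assigned location/verify-key fields so the comparison with expected_message only checks notification content.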
+ response[0].syft_client_verify_key = None
+ response[0].syft_server_location = None
assert response[0] == expected_message
@@ -188,9 +180,6 @@ def mock_get_all_inbox_for_verify_key(
def test_get_sent_success(
- root_verify_key,
- monkeypatch: MonkeyPatch,
- notification_service: NotificationService,
authed_context: AuthedServiceContext,
document_store: DocumentStore,
) -> None:
@@ -207,20 +196,12 @@ def test_get_sent_success(
NotificationStatus.UNREAD,
)
- @as_result(StashException)
- def mock_get_all_sent_for_verify_key(credentials, verify_key) -> list[Notification]:
- return [expected_message]
-
- monkeypatch.setattr(
- notification_service.stash,
- "get_all_sent_for_verify_key",
- mock_get_all_sent_for_verify_key,
- )
-
response = test_notification_service.get_all_sent(authed_context)
assert len(response) == 1
assert isinstance(response[0], Notification)
+ response[0].syft_server_location = None
+ response[0].syft_client_verify_key = None
assert response[0] == expected_message
@@ -340,19 +321,12 @@ def test_get_all_read_success(
NotificationStatus.READ,
)
- def mock_get_all_by_verify_key_for_status() -> list[Notification]:
- return [expected_message]
-
- monkeypatch.setattr(
- notification_service.stash,
- "get_all_by_verify_key_for_status",
- mock_get_all_by_verify_key_for_status,
- )
-
response = test_notification_service.get_all_read(authed_context)
assert len(response) == 1
assert isinstance(response[0], Notification)
+ response[0].syft_server_location = None
+ response[0].syft_client_verify_key = None
assert response[0] == expected_message
@@ -404,19 +378,11 @@ def test_get_all_unread_success(
NotificationStatus.UNREAD,
)
- @as_result(StashException)
- def mock_get_all_by_verify_key_for_status() -> list[Notification]:
- return [expected_message]
-
- monkeypatch.setattr(
- notification_service.stash,
- "get_all_by_verify_key_for_status",
- mock_get_all_by_verify_key_for_status,
- )
-
response = test_notification_service.get_all_unread(authed_context)
assert len(response) == 1
assert isinstance(response[0], Notification)
+ response[0].syft_server_location = None
+ response[0].syft_client_verify_key = None
assert response[0] == expected_message
diff --git a/packages/syft/tests/syft/notifications/notification_stash_test.py b/packages/syft/tests/syft/notifications/notification_stash_test.py
index b848324a2b7..7864fb10d19 100644
--- a/packages/syft/tests/syft/notifications/notification_stash_test.py
+++ b/packages/syft/tests/syft/notifications/notification_stash_test.py
@@ -8,13 +8,7 @@
# syft absolute
from syft.server.credentials import SyftSigningKey
from syft.server.credentials import SyftVerifyKey
-from syft.service.notification.notification_stash import (
- OrderByCreatedAtTimeStampPartitionKey,
-)
-from syft.service.notification.notification_stash import FromUserVerifyKeyPartitionKey
from syft.service.notification.notification_stash import NotificationStash
-from syft.service.notification.notification_stash import StatusPartitionKey
-from syft.service.notification.notification_stash import ToUserVerifyKeyPartitionKey
from syft.service.notification.notifications import Notification
from syft.service.notification.notifications import NotificationExpiryStatus
from syft.service.notification.notifications import NotificationStatus
@@ -60,86 +54,14 @@ def add_mock_notification(
return mock_notification
-def test_fromuserverifykey_partitionkey() -> None:
- random_verify_key = SyftSigningKey.generate().verify_key
-
- assert FromUserVerifyKeyPartitionKey.type_ == SyftVerifyKey
- assert FromUserVerifyKeyPartitionKey.key == "from_user_verify_key"
-
- result = FromUserVerifyKeyPartitionKey.with_obj(random_verify_key)
-
- assert result.type_ == SyftVerifyKey
- assert result.key == "from_user_verify_key"
-
- assert result.value == random_verify_key
-
- signing_key = SyftSigningKey.from_string(test_signing_key_string)
- with pytest.raises(AttributeError):
- FromUserVerifyKeyPartitionKey.with_obj(signing_key)
-
-
-def test_touserverifykey_partitionkey() -> None:
- random_verify_key = SyftSigningKey.generate().verify_key
-
- assert ToUserVerifyKeyPartitionKey.type_ == SyftVerifyKey
- assert ToUserVerifyKeyPartitionKey.key == "to_user_verify_key"
-
- result = ToUserVerifyKeyPartitionKey.with_obj(random_verify_key)
-
- assert result.type_ == SyftVerifyKey
- assert result.key == "to_user_verify_key"
- assert result.value == random_verify_key
-
- signing_key = SyftSigningKey.from_string(test_signing_key_string)
- with pytest.raises(AttributeError):
- ToUserVerifyKeyPartitionKey.with_obj(signing_key)
-
-
-def test_status_partitionkey() -> None:
- assert StatusPartitionKey.key == "status"
- assert StatusPartitionKey.type_ == NotificationStatus
-
- result1 = StatusPartitionKey.with_obj(NotificationStatus.UNREAD)
- result2 = StatusPartitionKey.with_obj(NotificationStatus.READ)
-
- assert result1.type_ == NotificationStatus
- assert result1.key == "status"
- assert result1.value == NotificationStatus.UNREAD
- assert result2.type_ == NotificationStatus
- assert result2.key == "status"
- assert result2.value == NotificationStatus.READ
-
- notification_expiry_status_auto = NotificationExpiryStatus(0)
-
- with pytest.raises(AttributeError):
- StatusPartitionKey.with_obj(notification_expiry_status_auto)
-
-
-def test_orderbycreatedattimestamp_partitionkey() -> None:
- random_datetime = DateTime.now()
-
- assert OrderByCreatedAtTimeStampPartitionKey.key == "created_at"
- assert OrderByCreatedAtTimeStampPartitionKey.type_ == DateTime
-
- result = OrderByCreatedAtTimeStampPartitionKey.with_obj(random_datetime)
-
- assert result.type_ == DateTime
- assert result.key == "created_at"
- assert result.value == random_datetime
-
-
def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None:
random_signing_key = SyftSigningKey.generate()
random_verify_key = random_signing_key.verify_key
test_stash = NotificationStash(store=document_store)
- response = test_stash.get_all_inbox_for_verify_key(
+ result = test_stash.get_all_inbox_for_verify_key(
root_verify_key, random_verify_key
- )
-
- assert response.is_ok()
-
- result = response.ok()
+ ).unwrap()
assert len(result) == 0
# list of mock notifications
@@ -152,14 +74,11 @@ def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None:
notification_list.append(mock_notification)
# returned list of notifications from stash that's sorted by created_at
- response2 = test_stash.get_all_inbox_for_verify_key(
+ result = test_stash.get_all_inbox_for_verify_key(
root_verify_key, random_verify_key
- )
+ ).unwrap()
- assert response2.is_ok()
-
- result = response2.ok()
- assert len(response2.value) == 5
+ assert len(result) == 5
for notification in notification_list:
# check if all notifications are present in the result
@@ -205,12 +124,9 @@ def test_get_all_sent_for_verify_key(root_verify_key, document_store) -> None:
def test_get_all_for_verify_key(root_verify_key, document_store) -> None:
random_signing_key = SyftSigningKey.generate()
random_verify_key = random_signing_key.verify_key
- query_key = FromUserVerifyKeyPartitionKey.with_obj(test_verify_key)
test_stash = NotificationStash(store=document_store)
- response = test_stash.get_all_for_verify_key(
- root_verify_key, random_verify_key, query_key
- )
+ response = test_stash.get_all_for_verify_key(root_verify_key, random_verify_key)
assert response.is_ok()
@@ -221,11 +137,8 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None:
root_verify_key, test_stash, test_verify_key, random_verify_key
)
- query_key2 = FromUserVerifyKeyPartitionKey.with_obj(
- mock_notification.from_user_verify_key
- )
response_from_verify_key = test_stash.get_all_for_verify_key(
- root_verify_key, mock_notification.from_user_verify_key, query_key2
+ root_verify_key, mock_notification.from_user_verify_key
)
assert response_from_verify_key.is_ok()
@@ -235,7 +148,7 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None:
assert result[0] == mock_notification
response_from_verify_key_string = test_stash.get_all_for_verify_key(
- root_verify_key, test_verify_key_string, query_key2
+ root_verify_key, test_verify_key_string
)
assert response_from_verify_key_string.is_ok()
@@ -249,28 +162,21 @@ def test_get_all_by_verify_key_for_status(root_verify_key, document_store) -> No
random_verify_key = random_signing_key.verify_key
test_stash = NotificationStash(store=document_store)
- response = test_stash.get_all_by_verify_key_for_status(
+ result = test_stash.get_all_by_verify_key_for_status(
root_verify_key, random_verify_key, NotificationStatus.READ
- )
-
- assert response.is_ok()
-
- result = response.ok()
+ ).unwrap()
assert len(result) == 0
mock_notification = add_mock_notification(
root_verify_key, test_stash, test_verify_key, random_verify_key
)
- response2 = test_stash.get_all_by_verify_key_for_status(
+ result2 = test_stash.get_all_by_verify_key_for_status(
root_verify_key, mock_notification.to_user_verify_key, NotificationStatus.UNREAD
- )
- assert response2.is_ok()
+ ).unwrap()
+ assert len(result2) == 1
- result = response2.ok()
- assert len(result) == 1
-
- assert result[0] == mock_notification
+ assert result2[0] == mock_notification
with pytest.raises(AttributeError):
test_stash.get_all_by_verify_key_for_status(
@@ -288,7 +194,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None:
root_verify_key, uid=random_uid, status=NotificationStatus.READ
).unwrap()
- assert exc.type is SyftException
+ assert issubclass(exc.type, SyftException)
assert exc.value.public_message
mock_notification = add_mock_notification(
@@ -314,7 +220,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None:
status=notification_expiry_status_auto,
).unwrap()
- assert exc.type is SyftException
+ assert issubclass(exc.type, SyftException)
assert exc.value.public_message
@@ -326,6 +232,10 @@ def test_update_notification_status_error_on_get_by_uid(
test_stash = NotificationStash(store=document_store)
expected_error_msg = f"No notification exists for id: {random_verify_key}"
+ add_mock_notification(
+ root_verify_key, test_stash, test_verify_key, random_verify_key
+ )
+
@as_result(StashException)
def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn:
raise StashException(public_message=f"No notification exists for id: {uid}")
@@ -335,11 +245,6 @@ def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn:
"get_by_uid",
mock_get_by_uid,
)
-
- add_mock_notification(
- root_verify_key, test_stash, test_verify_key, random_verify_key
- )
-
with pytest.raises(StashException) as exc:
test_stash.update_notification_status(
root_verify_key, random_verify_key, NotificationStatus.READ
@@ -354,11 +259,9 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None:
random_verify_key = random_signing_key.verify_key
test_stash = NotificationStash(store=document_store)
- response = test_stash.delete_all_for_verify_key(root_verify_key, test_verify_key)
-
- assert response.is_ok()
-
- result = response.ok()
+ result = test_stash.delete_all_for_verify_key(
+ root_verify_key, test_verify_key
+ ).unwrap()
assert result is True
add_mock_notification(
@@ -367,23 +270,23 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None:
inbox_before = test_stash.get_all_inbox_for_verify_key(
root_verify_key, random_verify_key
- ).value
+ ).unwrap()
assert len(inbox_before) == 1
- response2 = test_stash.delete_all_for_verify_key(root_verify_key, random_verify_key)
-
- assert response2.is_ok()
-
- result = response2.ok()
- assert result is True
+ result2 = test_stash.delete_all_for_verify_key(
+ root_verify_key, random_verify_key
+ ).unwrap()
+ assert result2 is True
inbox_after = test_stash.get_all_inbox_for_verify_key(
root_verify_key, random_verify_key
- ).value
+ ).unwrap()
assert len(inbox_after) == 0
with pytest.raises(AttributeError):
- test_stash.delete_all_for_verify_key(root_verify_key, random_signing_key)
+ test_stash.delete_all_for_verify_key(
+ root_verify_key, random_signing_key
+ ).unwrap()
def test_delete_all_for_verify_key_error_on_get_all_inbox_for_verify_key(
diff --git a/packages/syft/tests/syft/request/request_stash_test.py b/packages/syft/tests/syft/request/request_stash_test.py
index a492c2f6b9f..a9115d5c934 100644
--- a/packages/syft/tests/syft/request/request_stash_test.py
+++ b/packages/syft/tests/syft/request/request_stash_test.py
@@ -1,9 +1,4 @@
-# stdlib
-from typing import NoReturn
-
# third party
-import pytest
-from pytest import MonkeyPatch
# syft absolute
from syft.client.client import SyftClient
@@ -12,10 +7,6 @@
from syft.service.request.request import Request
from syft.service.request.request import SubmitRequest
from syft.service.request.request_stash import RequestStash
-from syft.service.request.request_stash import RequestingUserVerifyKeyPartitionKey
-from syft.store.document_store import PartitionKey
-from syft.store.document_store import QueryKeys
-from syft.types.errors import SyftException
def test_requeststash_get_all_for_verify_key_no_requests(
@@ -33,7 +24,6 @@ def test_requeststash_get_all_for_verify_key_no_requests(
assert len(requests.ok()) == 0
-@pytest.mark.xfail
def test_requeststash_get_all_for_verify_key_success(
root_verify_key,
request_stash: RequestStash,
@@ -77,60 +67,3 @@ def test_requeststash_get_all_for_verify_key_success(
requests.ok()[1] == stash_set_result_2.ok()
or requests.ok()[0] == stash_set_result_2.ok()
)
-
-
-def test_requeststash_get_all_for_verify_key_fail(
- root_verify_key,
- request_stash: RequestStash,
- monkeypatch: MonkeyPatch,
- guest_datasite_client: SyftClient,
-) -> None:
- verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key
- mock_error_message = (
- "verify key not in the document store's unique or searchable keys"
- )
-
- def mock_query_all_error(
- credentials: SyftVerifyKey, qks: QueryKeys, order_by: PartitionKey | None
- ) -> NoReturn:
- raise SyftException(public_message=mock_error_message)
-
- monkeypatch.setattr(request_stash, "query_all", mock_query_all_error)
-
- with pytest.raises(SyftException) as exc:
- request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap()
-
- assert exc.type is SyftException
- assert exc.value.public_message == mock_error_message
-
-
-def test_requeststash_get_all_for_verify_key_find_index_fail(
- root_verify_key,
- request_stash: RequestStash,
- monkeypatch: MonkeyPatch,
- guest_datasite_client: SyftClient,
-) -> None:
- verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key
- qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)])
-
- mock_error_message = f"Failed to query index or search with {qks.all[0]}"
-
- def mock_find_index_or_search_keys_error(
- credentials: SyftVerifyKey,
- index_qks: QueryKeys,
- search_qks: QueryKeys,
- order_by: PartitionKey | None,
- ) -> NoReturn:
- raise SyftException(public_message=mock_error_message)
-
- monkeypatch.setattr(
- request_stash.partition,
- "find_index_or_search_keys",
- mock_find_index_or_search_keys_error,
- )
-
- with pytest.raises(SyftException) as exc:
- request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap()
-
- assert exc.type == SyftException
- assert exc.value.public_message == mock_error_message
diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py
index 76fd4d82685..a57a61e1509 100644
--- a/packages/syft/tests/syft/service/action/action_object_test.py
+++ b/packages/syft/tests/syft/service/action/action_object_test.py
@@ -506,15 +506,15 @@ def test_actionobject_syft_get_path(testcase):
def test_actionobject_syft_send_get(worker, testcase):
root_datasite_client = worker.root_client
root_datasite_client._fetch_api(root_datasite_client.credentials)
- action_store = worker.services.action.store
+ action_store = worker.services.action.stash
orig_obj = testcase
obj = helper_make_action_obj(orig_obj)
- assert len(action_store.data) == 0
+ assert len(action_store._data) == 0
ptr = obj.send(root_datasite_client)
- assert len(action_store.data) == 1
+ assert len(action_store._data) == 1
retrieved = ptr.get()
assert obj.syft_action_data == retrieved
@@ -1001,7 +1001,7 @@ def test_actionobject_syft_getattr_float_history():
@pytest.mark.skipif(
- sys.platform != "linux",
+ sys.platform == "win32",
reason="This is a hackish way to test attribute set/get, and it might fail on Windows or OSX",
)
def test_actionobject_syft_getattr_np(worker):
diff --git a/packages/syft/tests/syft/service/action/action_service_test.py b/packages/syft/tests/syft/service/action/action_service_test.py
index 5a6b5561d99..e97362a6340 100644
--- a/packages/syft/tests/syft/service/action/action_service_test.py
+++ b/packages/syft/tests/syft/service/action/action_service_test.py
@@ -22,6 +22,6 @@ def test_action_service_sanity(worker):
obj = ActionObject.from_obj("abc")
pointer = obj.send(root_datasite_client)
- assert len(service.store.data) == 1
+ assert len(service.stash._data) == 1
res = pointer.capitalize()
assert res[0] == "A"
diff --git a/packages/syft/tests/syft/service_permission_test.py b/packages/syft/tests/syft/service_permission_test.py
index afc266005f3..ceb6d63923c 100644
--- a/packages/syft/tests/syft/service_permission_test.py
+++ b/packages/syft/tests/syft/service_permission_test.py
@@ -9,12 +9,8 @@
@pytest.fixture
def guest_mock_user(root_verify_key, user_stash, guest_user):
- result = user_stash.partition.set(root_verify_key, guest_user)
- assert result.is_ok()
-
- user = result.ok()
+ user = user_stash.set(root_verify_key, guest_user).unwrap()
assert user is not None
-
yield user
diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py
index aaa7b0460fc..7555aadd91e 100644
--- a/packages/syft/tests/syft/settings/settings_service_test.py
+++ b/packages/syft/tests/syft/settings/settings_service_test.py
@@ -34,7 +34,6 @@
from syft.store.document_store_errors import NotFoundException
from syft.store.document_store_errors import StashException
from syft.types.errors import SyftException
-from syft.types.result import Ok
from syft.types.result import as_result
@@ -100,37 +99,22 @@ def test_settingsservice_set_success(
) -> None:
response = settings_service.set(authed_context, settings)
assert isinstance(response, ServerSettings)
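+    # The returned settings and its nested objects carry server-assigned metadata; clear it before comparing against the input settings.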
+ response.syft_client_verify_key = None
+ response.syft_server_location = None
+ response.pwd_token_config.syft_client_verify_key = None
+ response.pwd_token_config.syft_server_location = None
+ response.welcome_markdown.syft_client_verify_key = None
+ response.welcome_markdown.syft_server_location = None
assert response == settings
-def test_settingsservice_set_fail(
- monkeypatch: MonkeyPatch,
- settings_service: SettingsService,
- settings: ServerSettings,
- authed_context: AuthedServiceContext,
-) -> None:
- mock_error_message = "database failure"
-
- @as_result(StashException)
- def mock_stash_set_error(credentials, settings: ServerSettings) -> NoReturn:
- raise StashException(public_message=mock_error_message)
-
- monkeypatch.setattr(settings_service.stash, "set", mock_stash_set_error)
-
- with pytest.raises(StashException) as exc:
- settings_service.set(authed_context, settings)
-
- assert exc.type == StashException
- assert exc.value.public_message == mock_error_message
-
-
def add_mock_settings(
root_verify_key: SyftVerifyKey,
settings_stash: SettingsStash,
settings: ServerSettings,
) -> ServerSettings:
# create a mock settings in the stash so that we can update it
- result = settings_stash.partition.set(root_verify_key, settings)
+ result = settings_stash.set(root_verify_key, settings)
assert result.is_ok()
created_settings = result.ok()
@@ -150,9 +134,7 @@ def test_settingsservice_update_success(
notifier_stash: NotifierStash,
) -> None:
# add a mock settings to the stash
- mock_settings = add_mock_settings(
- authed_context.credentials, settings_stash, settings
- )
+ mock_settings = settings_stash.set(authed_context.credentials, settings).unwrap()
# get a new settings according to update_settings
new_settings = deepcopy(settings)
@@ -164,14 +146,6 @@ def test_settingsservice_update_success(
assert new_settings != mock_settings
assert mock_settings == settings
- mock_stash_get_all_output = [mock_settings, mock_settings]
-
- def mock_stash_get_all(root_verify_key) -> Ok:
- return Ok(mock_stash_get_all_output)
-
- monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all)
-
- # Mock the get_service method to return a mocked notifier_service with the notifier_stash
class MockNotifierService:
def __init__(self, stash):
self.stash = stash
@@ -194,33 +168,7 @@ def mock_get_service(service_name: str):
# update the settings in the settings stash using settings_service
response = settings_service.update(context=authed_context, settings=update_settings)
- # not_updated_settings = response.ok()[1]
-
assert isinstance(response, SyftSuccess)
- # assert (
- # not_updated_settings.to_dict() == settings.to_dict()
- # ) # the second settings is not updated
-
-
-def test_settingsservice_update_stash_get_all_fail(
- monkeypatch: MonkeyPatch,
- settings_service: SettingsService,
- update_settings: ServerSettingsUpdate,
- authed_context: AuthedServiceContext,
-) -> None:
- mock_error_message = "database failure"
-
- @as_result(StashException)
- def mock_stash_get_all_error(credentials) -> NoReturn:
- raise StashException(public_message=mock_error_message)
-
- monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all_error)
-
- with pytest.raises(StashException) as exc:
- settings_service.update(context=authed_context, settings=update_settings)
-
- assert exc.type == StashException
- assert exc.value.public_message == mock_error_message
def test_settingsservice_update_stash_empty(
@@ -230,9 +178,7 @@ def test_settingsservice_update_stash_empty(
) -> None:
with pytest.raises(NotFoundException) as exc:
settings_service.update(context=authed_context, settings=update_settings)
-
- assert exc.type == NotFoundException
- assert exc.value.public_message == "Server settings not found"
+    assert exc.value.public_message == "Server settings not found"
def test_settingsservice_update_fail(
@@ -248,7 +194,7 @@ def test_settingsservice_update_fail(
mock_stash_get_all_output = [settings, settings]
@as_result(StashException)
- def mock_stash_get_all(credentials) -> list[ServerSettings]:
+ def mock_stash_get_all(credentials, **kwargs) -> list[ServerSettings]:
return mock_stash_get_all_output
monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all)
@@ -256,7 +202,7 @@ def mock_stash_get_all(credentials) -> list[ServerSettings]:
mock_update_error_message = "Failed to update obj ServerMetadata"
@as_result(StashException)
- def mock_stash_update_error(credentials, settings: ServerSettings) -> NoReturn:
+ def mock_stash_update_error(credentials, obj: ServerSettings) -> NoReturn:
raise StashException(public_message=mock_update_error_message)
monkeypatch.setattr(settings_service.stash, "update", mock_stash_update_error)
@@ -309,7 +255,7 @@ def test_settings_allow_guest_registration(
new_callable=mock.PropertyMock,
return_value=mock_server_settings,
):
- worker = syft.Worker.named(name=faker.name(), reset=True)
+ worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://")
guest_datasite_client = worker.guest_client
root_datasite_client = worker.root_client
@@ -343,7 +289,7 @@ def test_settings_allow_guest_registration(
new_callable=mock.PropertyMock,
return_value=mock_server_settings,
):
- worker = syft.Worker.named(name=faker.name(), reset=True)
+ worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://")
guest_datasite_client = worker.guest_client
root_datasite_client = worker.root_client
@@ -402,7 +348,7 @@ def get_mock_client(faker, root_client, role):
new_callable=mock.PropertyMock,
return_value=mock_server_settings,
):
- worker = syft.Worker.named(name=faker.name(), reset=True)
+ worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://")
root_client = worker.root_client
emails_added = []
diff --git a/packages/syft/tests/syft/settings/settings_stash_test.py b/packages/syft/tests/syft/settings/settings_stash_test.py
index 3fbbd28f9e9..2d976b52108 100644
--- a/packages/syft/tests/syft/settings/settings_stash_test.py
+++ b/packages/syft/tests/syft/settings/settings_stash_test.py
@@ -1,54 +1,26 @@
-# third party
-
# syft absolute
from syft.service.settings.settings import ServerSettings
from syft.service.settings.settings import ServerSettingsUpdate
from syft.service.settings.settings_stash import SettingsStash
-def add_mock_settings(
- root_verify_key, settings_stash: SettingsStash, settings: ServerSettings
-) -> ServerSettings:
- # prepare: add mock settings
- result = settings_stash.partition.set(root_verify_key, settings)
- assert result.is_ok()
-
- created_settings = result.ok()
- assert created_settings is not None
-
- return created_settings
-
-
def test_settingsstash_set(
- root_verify_key, settings_stash: SettingsStash, settings: ServerSettings
-) -> None:
- result = settings_stash.set(root_verify_key, settings)
- assert result.is_ok()
-
- created_settings = result.ok()
- assert isinstance(created_settings, ServerSettings)
- assert created_settings == settings
- assert settings.id in settings_stash.partition.data
-
-
-def test_settingsstash_update(
root_verify_key,
settings_stash: SettingsStash,
settings: ServerSettings,
update_settings: ServerSettingsUpdate,
) -> None:
- # prepare: add a mock settings
- mock_settings = add_mock_settings(root_verify_key, settings_stash, settings)
+ created_settings = settings_stash.set(root_verify_key, settings).unwrap()
+ assert isinstance(created_settings, ServerSettings)
+ assert created_settings == settings
+ assert settings_stash.exists(root_verify_key, settings.id)
# update mock_settings according to update_settings
update_kwargs = update_settings.to_dict(exclude_empty=True).items()
for field_name, value in update_kwargs:
- setattr(mock_settings, field_name, value)
+ setattr(settings, field_name, value)
# update the settings in the stash
- result = settings_stash.update(root_verify_key, settings=mock_settings)
-
- assert result.is_ok()
- updated_settings = result.ok()
+ updated_settings = settings_stash.update(root_verify_key, obj=settings).unwrap()
assert isinstance(updated_settings, ServerSettings)
- assert mock_settings == updated_settings
+ assert settings == updated_settings
diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py
index 375204908c1..5c2fe63be0d 100644
--- a/packages/syft/tests/syft/stores/action_store_test.py
+++ b/packages/syft/tests/syft/stores/action_store_test.py
@@ -1,24 +1,26 @@
# stdlib
-import sys
-from typing import Any
# third party
import pytest
# syft absolute
+from syft.server.credentials import SyftSigningKey
from syft.server.credentials import SyftVerifyKey
+from syft.service.action.action_object import ActionObject
+from syft.service.action.action_permissions import ActionObjectOWNER
+from syft.service.action.action_permissions import ActionObjectPermission
from syft.service.action.action_store import ActionObjectEXECUTE
-from syft.service.action.action_store import ActionObjectOWNER
from syft.service.action.action_store import ActionObjectREAD
+from syft.service.action.action_store import ActionObjectStash
from syft.service.action.action_store import ActionObjectWRITE
+from syft.service.user.user import User
+from syft.service.user.user_roles import ServiceRole
+from syft.service.user.user_stash import UserStash
+from syft.store.db.db import DBManager
from syft.types.uid import UID
# relative
-from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN
-from .store_constants_test import TEST_VERIFY_KEY_STRING_CLIENT
-from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER
-from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT
-from .store_mocks_test import MockSyftObject
+from ..worker_test import action_object_stash # noqa: F401
permissions = [
ActionObjectOWNER,
@@ -28,134 +30,119 @@
]
-@pytest.mark.parametrize(
- "store",
- [
- pytest.lazy_fixture("dict_action_store"),
- pytest.lazy_fixture("sqlite_action_store"),
- pytest.lazy_fixture("mongo_action_store"),
- ],
-)
-def test_action_store_sanity(store: Any):
- assert hasattr(store, "store_config")
- assert hasattr(store, "settings")
- assert hasattr(store, "data")
- assert hasattr(store, "permissions")
- assert hasattr(store, "root_verify_key")
- assert store.root_verify_key.verify == TEST_VERIFY_KEY_STRING_ROOT
+def add_user(db_manager: DBManager, role: ServiceRole) -> SyftVerifyKey:
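+    # Create a user with the given role directly in the user stash and return its verify key.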
+ user_stash = UserStash(store=db_manager)
+ verify_key = SyftSigningKey.generate().verify_key
+ user_stash.set(
+ credentials=db_manager.root_verify_key,
+ obj=User(verify_key=verify_key, role=role, id=UID()),
+ ).unwrap()
+ return verify_key
+
+
+def add_test_object(
+ stash: ActionObjectStash, verify_key: SyftVerifyKey
+) -> UID:
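+    # Store a small ActionObject in the stash and return its UID.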
+ test_object = ActionObject.from_obj([1, 2, 3])
+ uid = test_object.id
+ stash.set_or_update(
+ uid=uid,
+ credentials=verify_key,
+ syft_object=test_object,
+ has_result_read_permission=True,
+ ).unwrap()
+ return uid
@pytest.mark.parametrize(
- "store",
+ "stash",
[
- pytest.lazy_fixture("dict_action_store"),
- pytest.lazy_fixture("sqlite_action_store"),
- pytest.lazy_fixture("mongo_action_store"),
+ pytest.lazy_fixture("action_object_stash"),
],
)
@pytest.mark.parametrize("permission", permissions)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-@pytest.mark.skipif(sys.platform == "darwin", reason="skip on mac")
-def test_action_store_test_permissions(store: Any, permission: Any):
- client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT)
- root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
- hacker_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
- new_admin_key = TEST_VERIFY_KEY_NEW_ADMIN
-
- access = permission(uid=UID(), credentials=client_key)
- access_root = permission(uid=UID(), credentials=root_key)
- access_hacker = permission(uid=UID(), credentials=hacker_key)
- access_new_admin = permission(uid=UID(), credentials=new_admin_key)
-
- # add permission
- store.add_permission(access)
-
- assert store.has_permission(access)
- assert store.has_permission(access_root)
- assert store.has_permission(access_new_admin)
- assert not store.has_permission(access_hacker)
+def test_action_store_test_permissions(
+ stash: ActionObjectStash, permission: ActionObjectPermission
+) -> None:
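+    # Create stored users (data scientists and admins) so permission checks run against real credentials.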
+ client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+ root_key = add_user(stash.db, ServiceRole.ADMIN)
+ hacker_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+ new_admin_key = add_user(stash.db, ServiceRole.ADMIN)
+
+ test_item_id = add_test_object(stash, client_key)
+
+ access = permission(uid=test_item_id, credentials=client_key)
+ access_root = permission(uid=test_item_id, credentials=root_key)
+ access_hacker = permission(uid=test_item_id, credentials=hacker_key)
+ access_new_admin = permission(uid=test_item_id, credentials=new_admin_key)
+
+ stash.add_permission(access)
+ assert stash.has_permission(access)
+ assert stash.has_permission(access_root)
+ assert stash.has_permission(access_new_admin)
+ assert not stash.has_permission(access_hacker)
# remove permission
- store.remove_permission(access)
+ stash.remove_permission(access)
- assert not store.has_permission(access)
- assert store.has_permission(access_root)
- assert store.has_permission(access_new_admin)
- assert not store.has_permission(access_hacker)
+ assert not stash.has_permission(access)
+ assert stash.has_permission(access_root)
+ assert stash.has_permission(access_new_admin)
+ assert not stash.has_permission(access_hacker)
# take ownership with new UID
- client_uid2 = UID()
- access = permission(uid=client_uid2, credentials=client_key)
+ item2_id = add_test_object(stash, client_key)
+ access = permission(uid=item2_id, credentials=client_key)
- store.take_ownership(client_uid2, client_key)
- assert store.has_permission(access)
- assert store.has_permission(access_root)
- assert store.has_permission(access_new_admin)
- assert not store.has_permission(access_hacker)
+ stash.add_permission(ActionObjectREAD(uid=item2_id, credentials=client_key))
+ assert stash.has_permission(access)
+ assert stash.has_permission(access_root)
+ assert stash.has_permission(access_new_admin)
+ assert not stash.has_permission(access_hacker)
# delete UID as hacker
- access_hacker_ro = ActionObjectREAD(uid=UID(), credentials=hacker_key)
- store.add_permission(access_hacker_ro)
- res = store.delete(client_uid2, hacker_key)
+ res = stash.delete_by_uid(hacker_key, item2_id)
assert res.is_err()
- assert store.has_permission(access)
- assert store.has_permission(access_new_admin)
- assert store.has_permission(access_hacker_ro)
+ assert stash.has_permission(access)
+ assert stash.has_permission(access_root)
+ assert stash.has_permission(access_new_admin)
+ assert not stash.has_permission(access_hacker)
# delete UID as owner
- res = store.delete(client_uid2, client_key)
+ res = stash.delete_by_uid(client_key, item2_id)
assert res.is_ok()
- assert not store.has_permission(access)
- assert store.has_permission(access_new_admin)
- assert not store.has_permission(access_hacker)
+ assert not stash.has_permission(access)
+ assert stash.has_permission(access_new_admin)
+ assert not stash.has_permission(access_hacker)
@pytest.mark.parametrize(
- "store",
+ "stash",
[
- pytest.lazy_fixture("dict_action_store"),
- pytest.lazy_fixture("sqlite_action_store"),
- pytest.lazy_fixture("mongo_action_store"),
+ pytest.lazy_fixture("action_object_stash"),
],
)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_action_store_test_dataset_get(store: Any):
- client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT)
- root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
- SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
+def test_action_store_test_dataset_get(stash: ActionObjectStash) -> None:
+ client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+ root_key = add_user(stash.db, ServiceRole.ADMIN)
- permission_only_uid = UID()
- access = ActionObjectWRITE(uid=permission_only_uid, credentials=client_key)
- access_root = ActionObjectWRITE(uid=permission_only_uid, credentials=root_key)
- read_permission = ActionObjectREAD(uid=permission_only_uid, credentials=client_key)
+ data_uid = add_test_object(stash, client_key)
+ access = ActionObjectWRITE(uid=data_uid, credentials=client_key)
+ access_root = ActionObjectWRITE(uid=data_uid, credentials=root_key)
+ read_permission = ActionObjectREAD(uid=data_uid, credentials=client_key)
# add permission
- store.add_permission(access)
+ stash.add_permission(access)
- assert store.has_permission(access)
- assert store.has_permission(access_root)
+ assert stash.has_permission(access)
+ assert stash.has_permission(access_root)
- store.add_permission(read_permission)
- assert store.has_permission(read_permission)
+ stash.add_permission(read_permission)
+ assert stash.has_permission(read_permission)
# check that trying to get action data that doesn't exist returns an error, even with permissions
- res = store.get(permission_only_uid, client_key)
- assert res.is_err()
-
- # add data
- data_uid = UID()
- obj = MockSyftObject(data=1)
-
- res = store.set(data_uid, client_key, obj, has_result_read_permission=True)
- assert res.is_ok()
- res = store.get(data_uid, client_key)
- assert res.is_ok()
- assert res.ok() == obj
-
- assert store.exists(data_uid)
- res = store.delete(data_uid, client_key)
- assert res.is_ok()
- res = store.delete(data_uid, client_key)
+ stash.delete_by_uid(client_key, data_uid)
+ res = stash.get(data_uid, client_key)
assert res.is_err()
diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py
index c61806fb31b..8ed9a312b86 100644
--- a/packages/syft/tests/syft/stores/base_stash_test.py
+++ b/packages/syft/tests/syft/stores/base_stash_test.py
@@ -12,15 +12,15 @@
# syft absolute
from syft.serde.serializable import serializable
-from syft.store.dict_document_store import DictDocumentStore
-from syft.store.document_store import NewBaseUIDStoreStash
+from syft.service.queue.queue_stash import Status
+from syft.service.request.request_service import RequestService
+from syft.store.db.sqlite import SQLiteDBConfig
+from syft.store.db.sqlite import SQLiteDBManager
+from syft.store.db.stash import ObjectStash
from syft.store.document_store import PartitionKey
-from syft.store.document_store import PartitionSettings
-from syft.store.document_store import QueryKey
-from syft.store.document_store import QueryKeys
-from syft.store.document_store import UIDPartitionKey
from syft.store.document_store_errors import NotFoundException
from syft.store.document_store_errors import StashException
+from syft.store.linked_obj import LinkedObject
from syft.types.errors import SyftException
from syft.types.syft_object import SyftObject
from syft.types.uid import UID
@@ -35,6 +35,8 @@ class MockObject(SyftObject):
desc: str
importance: int
value: int
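+    # linked_obj and status are exercised by the nested-attribute and enum filter tests below.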
+ linked_obj: LinkedObject | None = None
+ status: Status = Status.CREATED
__attr_searchable__ = ["id", "name", "desc", "importance"]
__attr_unique__ = ["id", "name"]
@@ -45,24 +47,14 @@ class MockObject(SyftObject):
ImportancePartitionKey = PartitionKey(key="importance", type_=int)
-class MockStash(NewBaseUIDStoreStash):
- object_type = MockObject
- settings = PartitionSettings(
- name=MockObject.__canonical_name__, object_type=MockObject
- )
+class MockStash(ObjectStash[MockObject]):
+ pass
def get_object_values(obj: SyftObject) -> tuple[Any]:
return tuple(obj.to_dict().values())
-def add_mock_object(root_verify_key, stash: MockStash, obj: MockObject) -> MockObject:
- result = stash.set(root_verify_key, obj)
- assert result.is_ok()
-
- return result.ok()
-
-
T = TypeVar("T")
P = ParamSpec("P")
@@ -80,7 +72,11 @@ def create_unique(
@pytest.fixture
def base_stash(root_verify_key) -> MockStash:
- yield MockStash(store=DictDocumentStore(UID(), root_verify_key))
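+    # Back the MockStash with a SQLiteDBManager and initialize its tables before handing it to the tests.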
+ config = SQLiteDBConfig()
+ db_manager = SQLiteDBManager(config, UID(), root_verify_key)
+ mock_stash = MockStash(store=db_manager)
+ db_manager.init_tables()
+ yield mock_stash
def random_sentence(faker: Faker) -> str:
@@ -119,8 +115,7 @@ def mock_objects(faker: Faker) -> list[MockObject]:
def test_basestash_set(
root_verify_key, base_stash: MockStash, mock_object: MockObject
) -> None:
- result = add_mock_object(root_verify_key, base_stash, mock_object)
-
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
assert result is not None
assert result == mock_object
@@ -132,11 +127,10 @@ def test_basestash_set_duplicate(
MockObject(**kwargs) for kwargs in multiple_object_kwargs(faker, n=2, same=True)
)
- result = base_stash.set(root_verify_key, original)
- assert result.is_ok()
+ base_stash.set(root_verify_key, original).unwrap()
- result = base_stash.set(root_verify_key, duplicate)
- assert result.is_err()
+ with pytest.raises(StashException):
+ base_stash.set(root_verify_key, duplicate).unwrap()
def test_basestash_set_duplicate_unique_key(
@@ -157,28 +151,19 @@ def test_basestash_set_duplicate_unique_key(
def test_basestash_delete(
root_verify_key, base_stash: MockStash, mock_object: MockObject
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
-
- result = base_stash.delete(
- root_verify_key, UIDPartitionKey.with_obj(mock_object.id)
- )
- assert result.is_ok()
-
- assert len(base_stash.get_all(root_verify_key).ok()) == 0
+ base_stash.set(root_verify_key, mock_object).unwrap()
+ base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap()
+ assert len(base_stash.get_all(root_verify_key).unwrap()) == 0
def test_basestash_cannot_delete_non_existent(
root_verify_key, base_stash: MockStash, mock_object: MockObject
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
random_uid = create_unique(UID, [mock_object.id])
- for result in [
- base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(random_uid)),
- base_stash.delete_by_uid(root_verify_key, random_uid),
- ]:
- result = base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(UID()))
- assert result.is_err()
+ result = base_stash.delete_by_uid(root_verify_key, random_uid)
+ assert result.is_err()
assert (
len(
@@ -193,7 +178,7 @@ def test_basestash_cannot_delete_non_existent(
def test_basestash_update(
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
updated_obj = mock_object.copy()
updated_obj.name = faker.name()
@@ -208,7 +193,7 @@ def test_basestash_update(
def test_basestash_cannot_update_non_existent(
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
updated_obj = mock_object.copy()
updated_obj.id = create_unique(UID, [mock_object.id])
@@ -227,10 +212,7 @@ def test_basestash_set_get_all(
stored_objects = base_stash.get_all(
root_verify_key,
- )
- assert stored_objects.is_ok()
-
- stored_objects = stored_objects.ok()
+ ).unwrap()
assert len(stored_objects) == len(mock_objects)
stored_objects_values = {get_object_values(obj) for obj in stored_objects}
@@ -241,11 +223,10 @@ def test_basestash_set_get_all(
def test_basestash_get_by_uid(
root_verify_key, base_stash: MockStash, mock_object: MockObject
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
- result = base_stash.get_by_uid(root_verify_key, mock_object.id)
- assert result.is_ok()
- assert result.ok() == mock_object
+ result = base_stash.get_by_uid(root_verify_key, mock_object.id).unwrap()
+ assert result == mock_object
random_uid = create_unique(UID, [mock_object.id])
bad_uid = base_stash.get_by_uid(root_verify_key, random_uid)
@@ -262,12 +243,9 @@ def test_basestash_get_by_uid(
def test_basestash_delete_by_uid(
root_verify_key, base_stash: MockStash, mock_object: MockObject
) -> None:
- add_mock_object(root_verify_key, base_stash, mock_object)
+ result = base_stash.set(root_verify_key, mock_object).unwrap()
- result = base_stash.delete_by_uid(root_verify_key, mock_object.id)
- assert result.is_ok()
-
- response = result.ok()
+ response = base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap()
assert isinstance(response, UID)
result = base_stash.get_by_uid(root_verify_key, mock_object.id)
@@ -288,43 +266,73 @@ def test_basestash_query_one(
base_stash.set(root_verify_key, obj)
obj = random.choice(mock_objects)
+ result = base_stash.get_one(
+ root_verify_key,
+ filters={"name": obj.name},
+ ).unwrap()
- for result in (
- base_stash.query_one_kwargs(root_verify_key, name=obj.name),
- base_stash.query_one(
- root_verify_key, QueryKey.from_obj(NamePartitionKey, obj.name)
- ),
- ):
- assert result.is_ok()
- assert result.ok() == obj
+ assert result == obj
existing_names = {obj.name for obj in mock_objects}
random_name = create_unique(faker.name, existing_names)
- for result in (
- base_stash.query_one_kwargs(root_verify_key, name=random_name),
- base_stash.query_one(
- root_verify_key, QueryKey.from_obj(NamePartitionKey, random_name)
- ),
- ):
- assert result.is_err()
- assert isinstance(result.err(), NotFoundException)
+ with pytest.raises(NotFoundException):
+ result = base_stash.get_one(
+ root_verify_key,
+ filters={"name": random_name},
+ ).unwrap()
params = {"name": obj.name, "desc": obj.desc}
- for result in [
- base_stash.query_one_kwargs(root_verify_key, **params),
- base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_ok()
- assert result.ok() == obj
+ result = base_stash.get_one(
+ root_verify_key,
+ filters=params,
+ ).unwrap()
+ assert result == obj
params = {"name": random_name, "desc": random_sentence(faker)}
- for result in [
- base_stash.query_one_kwargs(root_verify_key, **params),
- base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_err()
- assert isinstance(result.err(), NotFoundException)
+ with pytest.raises(NotFoundException):
+ result = base_stash.get_one(
+ root_verify_key,
+ filters=params,
+ ).unwrap()
+
+
+def test_basestash_query_enum(
+ root_verify_key, base_stash: MockStash, mock_object: MockObject
+) -> None:
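+    # mock_object defaults to Status.CREATED, so filtering on that enum value matches while PROCESSING does not.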
+ base_stash.set(root_verify_key, mock_object).unwrap()
+ result = base_stash.get_one(
+ root_verify_key,
+ filters={"status": Status.CREATED},
+ ).unwrap()
+
+ assert result == mock_object
+ with pytest.raises(NotFoundException):
+ result = base_stash.get_one(
+ root_verify_key,
+ filters={"status": Status.PROCESSING},
+ ).unwrap()
+
+
+def test_basestash_query_linked_obj(
+ root_verify_key, base_stash: MockStash, mock_object: MockObject
+) -> None:
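+    # Filters can reach into nested objects via dotted keys such as "linked_obj.id".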
+ mock_object.linked_obj = LinkedObject(
+ object_type=MockObject,
+ object_uid=UID(),
+ id=UID(),
+ tags=["tag1", "tag2"],
+ server_uid=UID(),
+ service_type=RequestService,
+ )
+ base_stash.set(root_verify_key, mock_object).unwrap()
+
+ result = base_stash.get_one(
+ root_verify_key,
+ filters={"linked_obj.id": mock_object.linked_obj.id},
+ ).unwrap()
+
+ assert result == mock_object
def test_basestash_query_all(
@@ -339,46 +347,30 @@ def test_basestash_query_all(
for obj in all_objects:
base_stash.set(root_verify_key, obj)
- for result in [
- base_stash.query_all_kwargs(root_verify_key, desc=desc),
- base_stash.query_all(
- root_verify_key, QueryKey.from_obj(DescPartitionKey, desc)
- ),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == n_same
- assert all(obj.desc == desc for obj in objects)
- original_object_values = {get_object_values(obj) for obj in similar_objects}
- retrived_objects_values = {get_object_values(obj) for obj in objects}
- assert original_object_values == retrived_objects_values
+ objects = base_stash.get_all(root_verify_key, filters={"desc": desc}).unwrap()
+ assert len(objects) == n_same
+ assert all(obj.desc == desc for obj in objects)
+ original_object_values = {get_object_values(obj) for obj in similar_objects}
+    retrieved_objects_values = {get_object_values(obj) for obj in objects}
+    assert original_object_values == retrieved_objects_values
random_desc = create_unique(
random_sentence, [obj.desc for obj in all_objects], faker
)
- for result in [
- base_stash.query_all_kwargs(root_verify_key, desc=random_desc),
- base_stash.query_all(
- root_verify_key, QueryKey.from_obj(DescPartitionKey, random_desc)
- ),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == 0
+
+ objects = base_stash.get_all(
+ root_verify_key, filters={"desc": random_desc}
+ ).unwrap()
+ assert len(objects) == 0
obj = random.choice(similar_objects)
params = {"name": obj.name, "desc": obj.desc}
- for result in [
- base_stash.query_all_kwargs(root_verify_key, **params),
- base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == sum(
- 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc)
- )
- assert objects[0] == obj
+ objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
+ assert len(objects) == sum(
+ 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc)
+ )
+ assert objects[0] == obj
def test_basestash_query_all_kwargs_multiple_params(
@@ -397,66 +389,23 @@ def test_basestash_query_all_kwargs_multiple_params(
base_stash.set(root_verify_key, obj)
params = {"importance": importance, "desc": desc}
- for result in [
- base_stash.query_all_kwargs(root_verify_key, **params),
- base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == n_same
- assert all(obj.desc == desc for obj in objects)
- original_object_values = {get_object_values(obj) for obj in similar_objects}
- retrived_objects_values = {get_object_values(obj) for obj in objects}
- assert original_object_values == retrived_objects_values
+ objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
+ assert len(objects) == n_same
+ assert all(obj.desc == desc for obj in objects)
+ original_object_values = {get_object_values(obj) for obj in similar_objects}
+ retrived_objects_values = {get_object_values(obj) for obj in objects}
+ assert original_object_values == retrived_objects_values
params = {
"name": create_unique(faker.name, [obj.name for obj in all_objects]),
"desc": random_sentence(faker),
}
- for result in [
- base_stash.query_all_kwargs(root_verify_key, **params),
- base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == 0
+ objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
+ assert len(objects) == 0
obj = random.choice(similar_objects)
params = {"id": obj.id, "name": obj.name, "desc": obj.desc}
- for result in [
- base_stash.query_all_kwargs(root_verify_key, **params),
- base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)),
- ]:
- assert result.is_ok()
- objects = result.ok()
- assert len(objects) == 1
- assert objects[0] == obj
-
-
-def test_basestash_cannot_query_non_searchable(
- root_verify_key, base_stash: MockStash, mock_objects: list[MockObject]
-) -> None:
- for obj in mock_objects:
- base_stash.set(root_verify_key, obj)
-
- obj = random.choice(mock_objects)
-
- assert base_stash.query_one_kwargs(root_verify_key, value=10).is_err()
- assert base_stash.query_all_kwargs(root_verify_key, value=10).is_err()
- assert base_stash.query_one_kwargs(
- root_verify_key, value=10, name=obj.name
- ).is_err()
- assert base_stash.query_all_kwargs(
- root_verify_key, value=10, name=obj.name
- ).is_err()
-
- ValuePartitionKey = PartitionKey(key="value", type_=int)
- qk = ValuePartitionKey.with_obj(10)
-
- assert base_stash.query_one(root_verify_key, qk).is_err()
- assert base_stash.query_all(root_verify_key, qk).is_err()
- assert base_stash.query_all(root_verify_key, QueryKeys(qks=[qk])).is_err()
- assert base_stash.query_all(
- root_verify_key, QueryKeys(qks=[qk, UIDPartitionKey.with_obj(obj.id)])
- ).is_err()
+ objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
+ assert len(objects) == 1
+ assert objects[0] == obj
diff --git a/packages/syft/tests/syft/stores/dict_document_store_test.py b/packages/syft/tests/syft/stores/dict_document_store_test.py
deleted file mode 100644
index e04414d666c..00000000000
--- a/packages/syft/tests/syft/stores/dict_document_store_test.py
+++ /dev/null
@@ -1,358 +0,0 @@
-# stdlib
-from threading import Thread
-
-# syft absolute
-from syft.store.dict_document_store import DictStorePartition
-from syft.store.document_store import QueryKeys
-from syft.types.uid import UID
-
-# relative
-from .store_mocks_test import MockObjectType
-from .store_mocks_test import MockSyftObject
-
-
-def test_dict_store_partition_sanity(dict_store_partition: DictStorePartition) -> None:
- res = dict_store_partition.init_store()
- assert res.is_ok()
-
- assert hasattr(dict_store_partition, "data")
- assert hasattr(dict_store_partition, "unique_keys")
- assert hasattr(dict_store_partition, "searchable_keys")
-
-
-def test_dict_store_partition_set(
- root_verify_key, dict_store_partition: DictStorePartition
-) -> None:
- res = dict_store_partition.init_store()
- assert res.is_ok()
-
- obj = MockSyftObject(id=UID(), data=1)
- res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-
- assert res.is_ok()
- assert res.ok() == obj
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_err()
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=True)
- assert res.is_ok()
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- obj2 = MockSyftObject(data=2)
- res = dict_store_partition.set(root_verify_key, obj2, ignore_duplicates=False)
- assert res.is_ok()
- assert res.ok() == obj2
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 2
- )
-
- repeats = 5
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
- res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 3 + idx
- )
-
-
-def test_dict_store_partition_delete(
- root_verify_key, dict_store_partition: DictStorePartition
-) -> None:
- res = dict_store_partition.init_store()
- assert res.is_ok()
-
- objs = []
- repeats = 5
- for v in range(repeats):
- obj = MockSyftObject(data=v)
- dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- objs.append(obj)
-
- assert len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # random object
- obj = MockSyftObject(data="bogus")
- key = dict_store_partition.settings.store_key.with_obj(obj)
- res = dict_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # cleanup store
- for idx, v in enumerate(objs):
- key = dict_store_partition.settings.store_key.with_obj(v)
- res = dict_store_partition.delete(root_verify_key, key)
- assert res.is_ok()
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- res = dict_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 0
- )
-
-
-def test_dict_store_partition_update(
- root_verify_key, dict_store_partition: DictStorePartition
-) -> None:
- dict_store_partition.init_store()
-
- # add item
- obj = MockSyftObject(data=1)
- dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert len(dict_store_partition.all(root_verify_key).ok()) == 1
-
- # fail to update missing keys
- rand_obj = MockSyftObject(data="bogus")
- key = dict_store_partition.settings.store_key.with_obj(rand_obj)
- res = dict_store_partition.update(root_verify_key, key, obj)
- assert res.is_err()
-
- # update the key multiple times
- repeats = 5
- for v in range(repeats):
- key = dict_store_partition.settings.store_key.with_obj(obj)
- obj_new = MockSyftObject(data=v)
-
- res = dict_store_partition.update(root_verify_key, key, obj_new)
- assert res.is_ok()
-
- # The ID should stay the same on update, unly the values are updated.
- assert (
- len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
- assert (
- dict_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- == obj.id
- )
- assert (
- dict_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- != obj_new.id
- )
- assert (
- dict_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .data
- == v
- )
-
- stored = dict_store_partition.get_all_from_store(
- root_verify_key, QueryKeys(qks=[key])
- )
- assert stored.ok()[0].data == v
-
-
-def test_dict_store_partition_set_multithreaded(
- root_verify_key,
- dict_store_partition: DictStorePartition,
-) -> None:
- thread_cnt = 3
- repeats = 5
-
- dict_store_partition.init_store()
-
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for idx in range(repeats):
- obj = MockObjectType(data=idx)
-
- for _ in range(10):
- res = dict_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
- stored_cnt = len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == repeats * thread_cnt
-
-
-def test_dict_store_partition_update_multithreaded(
- root_verify_key,
- dict_store_partition: DictStorePartition,
-) -> None:
- thread_cnt = 3
- repeats = 5
- dict_store_partition.init_store()
-
- obj = MockSyftObject(data=0)
- key = dict_store_partition.settings.store_key.with_obj(obj)
- dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for repeat in range(repeats):
- obj = MockSyftObject(data=repeat)
-
- for _ in range(10):
- res = dict_store_partition.update(root_verify_key, key, obj)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
-
-def test_dict_store_partition_set_delete_multithreaded(
- root_verify_key,
- dict_store_partition: DictStorePartition,
-) -> None:
- dict_store_partition.init_store()
-
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
-
- for _ in range(10):
- res = dict_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- key = dict_store_partition.settings.store_key.with_obj(obj)
-
- res = dict_store_partition.delete(root_verify_key, key)
- if res.is_err():
- execution_err = res
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
- stored_cnt = len(
- dict_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == 0
diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py
deleted file mode 100644
index 95df806c189..00000000000
--- a/packages/syft/tests/syft/stores/mongo_document_store_test.py
+++ /dev/null
@@ -1,1045 +0,0 @@
-# stdlib
-from secrets import token_hex
-from threading import Thread
-
-# third party
-import pytest
-
-# syft absolute
-from syft.server.credentials import SyftVerifyKey
-from syft.service.action.action_permissions import ActionObjectPermission
-from syft.service.action.action_permissions import ActionPermission
-from syft.service.action.action_permissions import StoragePermission
-from syft.service.action.action_store import ActionObjectEXECUTE
-from syft.service.action.action_store import ActionObjectOWNER
-from syft.service.action.action_store import ActionObjectREAD
-from syft.service.action.action_store import ActionObjectWRITE
-from syft.store.document_store import PartitionSettings
-from syft.store.document_store import QueryKey
-from syft.store.document_store import QueryKeys
-from syft.store.mongo_client import MongoStoreClientConfig
-from syft.store.mongo_document_store import MongoStoreConfig
-from syft.store.mongo_document_store import MongoStorePartition
-from syft.types.errors import SyftException
-from syft.types.uid import UID
-
-# relative
-from ...mongomock.collection import Collection as MongoCollection
-from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER
-from .store_fixtures_test import mongo_store_partition_fn
-from .store_mocks_test import MockObjectType
-from .store_mocks_test import MockSyftObject
-
-PERMISSIONS = [
- ActionObjectOWNER,
- ActionObjectREAD,
- ActionObjectWRITE,
- ActionObjectEXECUTE,
-]
-
-
-def test_mongo_store_partition_sanity(
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- assert hasattr(mongo_store_partition, "_collection")
- assert hasattr(mongo_store_partition, "_permissions")
-
-
-@pytest.mark.skip(reason="Test gets stuck at store.init_store()")
-def test_mongo_store_partition_init_failed(root_verify_key) -> None:
- # won't connect
- mongo_config = MongoStoreClientConfig(
- connectTimeoutMS=1,
- timeoutMS=1,
- )
-
- store_config = MongoStoreConfig(client_config=mongo_config)
- settings = PartitionSettings(name="test", object_type=MockObjectType)
-
- store = MongoStorePartition(
- UID(), root_verify_key, settings=settings, store_config=store_config
- )
-
- res = store.init_store()
- assert res.is_err()
-
-
-def test_mongo_store_partition_set(
- root_verify_key, mongo_store_partition: MongoStorePartition
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- obj = MockSyftObject(data=1)
-
- res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-
- assert res.is_ok()
- assert res.ok() == obj
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_err()
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=True)
- assert res.is_ok()
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- obj2 = MockSyftObject(data=2)
- res = mongo_store_partition.set(root_verify_key, obj2, ignore_duplicates=False)
- assert res.is_ok()
- assert res.ok() == obj2
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 2
- )
-
- repeats = 5
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
- res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 3 + idx
- )
-
-
-def test_mongo_store_partition_delete(
- root_verify_key,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- repeats = 5
-
- objs = []
- for v in range(repeats):
- obj = MockSyftObject(data=v)
- mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- objs.append(obj)
-
- assert len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # random object
- obj = MockSyftObject(data="bogus")
- key = mongo_store_partition.settings.store_key.with_obj(obj)
- res = mongo_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # cleanup store
- for idx, v in enumerate(objs):
- key = mongo_store_partition.settings.store_key.with_obj(v)
- res = mongo_store_partition.delete(root_verify_key, key)
- assert res.is_ok()
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- res = mongo_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 0
- )
-
-
-def test_mongo_store_partition_update(
- root_verify_key,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- mongo_store_partition.init_store()
-
- # add item
- obj = MockSyftObject(data=1)
- mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- # fail to update missing keys
- rand_obj = MockSyftObject(data="bogus")
- key = mongo_store_partition.settings.store_key.with_obj(rand_obj)
- res = mongo_store_partition.update(root_verify_key, key, obj)
- assert res.is_err()
-
- # update the key multiple times
- repeats = 5
- for v in range(repeats):
- key = mongo_store_partition.settings.store_key.with_obj(obj)
- obj_new = MockSyftObject(data=v)
-
- res = mongo_store_partition.update(root_verify_key, key, obj_new)
- assert res.is_ok()
-
- # The ID should stay the same on update, only the values are updated.
- assert (
- len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
- assert (
- mongo_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- == obj.id
- )
- assert (
- mongo_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- != obj_new.id
- )
- assert (
- mongo_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .data
- == v
- )
-
- stored = mongo_store_partition.get_all_from_store(
- root_verify_key, QueryKeys(qks=[key])
- )
- assert stored.ok()[0].data == v
-
-
-def test_mongo_store_partition_set_threading(root_verify_key, mongo_client) -> None:
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
- mongo_db_name = token_hex(8)
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
-
- mongo_store_partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
- for idx in range(repeats):
- obj = MockObjectType(data=idx)
-
- for _ in range(10):
- res = mongo_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- return execution_err
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
- mongo_store_partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
- stored_cnt = len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == thread_cnt * repeats
-
-
-# @pytest.mark.skip(
-# reason="PicklingError: Could not pickle the task to send it to the workers."
-# )
-# def test_mongo_store_partition_set_joblib(
-# root_verify_key,
-# mongo_client,
-# ) -> None:
-# thread_cnt = 3
-# repeats = 5
-# mongo_db_name = token_hex(8)
-
-# def _kv_cbk(tid: int) -> None:
-# for idx in range(repeats):
-# mongo_store_partition = mongo_store_partition_fn(
-# mongo_client,
-# root_verify_key,
-# mongo_db_name=mongo_db_name,
-# )
-# obj = MockObjectType(data=idx)
-
-# for _ in range(10):
-# res = mongo_store_partition.set(
-# root_verify_key, obj, ignore_duplicates=False
-# )
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-
-# for execution_err in errs:
-# assert execution_err is None
-
-# mongo_store_partition = mongo_store_partition_fn(
-# mongo_client,
-# root_verify_key,
-# mongo_db_name=mongo_db_name,
-# )
-# stored_cnt = len(
-# mongo_store_partition.all(
-# root_verify_key,
-# ).ok()
-# )
-# assert stored_cnt == thread_cnt * repeats
-
-
-def test_mongo_store_partition_update_threading(
- root_verify_key,
- mongo_client,
-) -> None:
- thread_cnt = 3
- repeats = 5
-
- mongo_db_name = token_hex(8)
- mongo_store_partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
-
- obj = MockSyftObject(data=0)
- key = mongo_store_partition.settings.store_key.with_obj(obj)
- mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
-
- mongo_store_partition_local = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
- for repeat in range(repeats):
- obj = MockSyftObject(data=repeat)
-
- for _ in range(10):
- res = mongo_store_partition_local.update(root_verify_key, key, obj)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
-
-# @pytest.mark.skip(
-# reason="PicklingError: Could not pickle the task to send it to the workers."
-# )
-# def test_mongo_store_partition_update_joblib(root_verify_key, mongo_client) -> None:
-# thread_cnt = 3
-# repeats = 5
-
-# mongo_db_name = token_hex(8)
-
-# mongo_store_partition = mongo_store_partition_fn(
-# mongo_client,
-# root_verify_key,
-# mongo_db_name=mongo_db_name,
-# )
-# obj = MockSyftObject(data=0)
-# key = mongo_store_partition.settings.store_key.with_obj(obj)
-# mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-
-# def _kv_cbk(tid: int) -> None:
-# mongo_store_partition_local = mongo_store_partition_fn(
-# mongo_client,
-# root_verify_key,
-# mongo_db_name=mongo_db_name,
-# )
-# for repeat in range(repeats):
-# obj = MockSyftObject(data=repeat)
-
-# for _ in range(10):
-# res = mongo_store_partition_local.update(root_verify_key, key, obj)
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-
-# for execution_err in errs:
-# assert execution_err is None
-
-
-def test_mongo_store_partition_set_delete_threading(
- root_verify_key,
- mongo_client,
-) -> None:
- thread_cnt = 3
- repeats = 5
- execution_err = None
- mongo_db_name = token_hex(8)
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- mongo_store_partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
-
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
-
- for _ in range(10):
- res = mongo_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- key = mongo_store_partition.settings.store_key.with_obj(obj)
-
- res = mongo_store_partition.delete(root_verify_key, key)
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
- mongo_store_partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- )
- stored_cnt = len(
- mongo_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == 0
-
-
-# @pytest.mark.skip(
-# reason="PicklingError: Could not pickle the task to send it to the workers."
-# )
-# def test_mongo_store_partition_set_delete_joblib(root_verify_key, mongo_client) -> None:
-# thread_cnt = 3
-# repeats = 5
-# mongo_db_name = token_hex(8)
-
-# def _kv_cbk(tid: int) -> None:
-# mongo_store_partition = mongo_store_partition_fn(
-# mongo_client, root_verify_key, mongo_db_name=mongo_db_name
-# )
-
-# for idx in range(repeats):
-# obj = MockSyftObject(data=idx)
-
-# for _ in range(10):
-# res = mongo_store_partition.set(
-# root_verify_key, obj, ignore_duplicates=False
-# )
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-
-# key = mongo_store_partition.settings.store_key.with_obj(obj)
-
-# res = mongo_store_partition.delete(root_verify_key, key)
-# if res.is_err():
-# return res
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-# for execution_err in errs:
-# assert execution_err is None
-
-# mongo_store_partition = mongo_store_partition_fn(
-# mongo_client,
-# root_verify_key,
-# mongo_db_name=mongo_db_name,
-# )
-# stored_cnt = len(
-# mongo_store_partition.all(
-# root_verify_key,
-# ).ok()
-# )
-# assert stored_cnt == 0
-
-
-def test_mongo_store_partition_permissions_collection(
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- collection_permissions_status = mongo_store_partition.permissions
- assert not collection_permissions_status.is_err()
- collection_permissions = collection_permissions_status.ok()
- assert isinstance(collection_permissions, MongoCollection)
-
-
-def test_mongo_store_partition_add_remove_permission(
- root_verify_key: SyftVerifyKey, mongo_store_partition: MongoStorePartition
-) -> None:
- """
- Test the add_permission and remove_permission functions of MongoStorePartition
- """
- # setting up
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- permissions_collection: MongoCollection = mongo_store_partition.permissions.ok()
- obj = MockSyftObject(data=1)
-
- # add the first permission
- obj_read_permission = ActionObjectPermission(
- uid=obj.id, permission=ActionPermission.READ, credentials=root_verify_key
- )
- mongo_store_partition.add_permission(obj_read_permission)
- find_res_1 = permissions_collection.find_one({"_id": obj_read_permission.uid})
- assert find_res_1 is not None
- assert len(find_res_1["permissions"]) == 1
- assert find_res_1["permissions"] == {
- obj_read_permission.permission_string,
- }
-
- # add the second permission
- obj_write_permission = ActionObjectPermission(
- uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key
- )
- mongo_store_partition.add_permission(obj_write_permission)
-
- find_res_2 = permissions_collection.find_one({"_id": obj.id})
- assert find_res_2 is not None
- assert len(find_res_2["permissions"]) == 2
- assert find_res_2["permissions"] == {
- obj_read_permission.permission_string,
- obj_write_permission.permission_string,
- }
-
- # add duplicated permission
- mongo_store_partition.add_permission(obj_write_permission)
- find_res_3 = permissions_collection.find_one({"_id": obj.id})
- assert len(find_res_3["permissions"]) == 2
- assert find_res_3["permissions"] == find_res_2["permissions"]
-
- # remove the write permission
- mongo_store_partition.remove_permission(obj_write_permission)
- find_res_4 = permissions_collection.find_one({"_id": obj.id})
- assert len(find_res_4["permissions"]) == 1
- assert find_res_1["permissions"] == {
- obj_read_permission.permission_string,
- }
-
- # remove a non-existent permission
- with pytest.raises(SyftException):
- mongo_store_partition.remove_permission(
- ActionObjectPermission(
- uid=obj.id,
- permission=ActionPermission.OWNER,
- credentials=root_verify_key,
- )
- )
- find_res_5 = permissions_collection.find_one({"_id": obj.id})
- assert len(find_res_5["permissions"]) == 1
- assert find_res_1["permissions"] == {
- obj_read_permission.permission_string,
- }
-
- # there is only one permission object
- assert permissions_collection.count_documents({}) == 1
-
- # add permissions in a loop
- new_permissions = []
- repeats = 5
- for idx in range(1, repeats + 1):
- new_obj = MockSyftObject(data=idx)
- new_obj_read_permission = ActionObjectPermission(
- uid=new_obj.id,
- permission=ActionPermission.READ,
- credentials=root_verify_key,
- )
- new_permissions.append(new_obj_read_permission)
- mongo_store_partition.add_permission(new_obj_read_permission)
- assert permissions_collection.count_documents({}) == 1 + idx
-
- # remove all the permissions added in the loop
- for permission in new_permissions:
- mongo_store_partition.remove_permission(permission)
-
- assert permissions_collection.count_documents({}) == 1
-
-
-def test_mongo_store_partition_add_remove_storage_permission(
- root_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- """
- Test the add_storage_permission and remove_storage_permission functions of MongoStorePartition
- """
-
- obj = MockSyftObject(data=1)
-
- storage_permission = StoragePermission(
- uid=obj.id,
- server_uid=UID(),
- )
- assert not mongo_store_partition.has_storage_permission(storage_permission)
- mongo_store_partition.add_storage_permission(storage_permission)
- assert mongo_store_partition.has_storage_permission(storage_permission)
- mongo_store_partition.remove_storage_permission(storage_permission)
- assert not mongo_store_partition.has_storage_permission(storage_permission)
-
- obj2 = MockSyftObject(data=1)
- mongo_store_partition.set(root_verify_key, obj2, add_storage_permission=False)
- storage_permission3 = StoragePermission(
- uid=obj2.id, server_uid=mongo_store_partition.server_uid
- )
- assert not mongo_store_partition.has_storage_permission(storage_permission3)
-
- obj3 = MockSyftObject(data=1)
- mongo_store_partition.set(root_verify_key, obj3, add_storage_permission=True)
- storage_permission4 = StoragePermission(
- uid=obj3.id, server_uid=mongo_store_partition.server_uid
- )
- assert mongo_store_partition.has_storage_permission(storage_permission4)
-
-
-def test_mongo_store_partition_add_permissions(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- permissions_collection: MongoCollection = mongo_store_partition.permissions.ok()
- obj = MockSyftObject(data=1)
-
- # add multiple permissions for the first object
- permission_1 = ActionObjectPermission(
- uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key
- )
- permission_2 = ActionObjectPermission(
- uid=obj.id, permission=ActionPermission.OWNER, credentials=root_verify_key
- )
- permission_3 = ActionObjectPermission(
- uid=obj.id, permission=ActionPermission.READ, credentials=guest_verify_key
- )
- permissions: list[ActionObjectPermission] = [
- permission_1,
- permission_2,
- permission_3,
- ]
- mongo_store_partition.add_permissions(permissions)
-
- # check if the permissions have been added properly
- assert permissions_collection.count_documents({}) == 1
- find_res = permissions_collection.find_one({"_id": obj.id})
- assert find_res is not None
- assert len(find_res["permissions"]) == 3
-
- # add permissions for the second object
- obj_2 = MockSyftObject(data=2)
- permission_4 = ActionObjectPermission(
- uid=obj_2.id, permission=ActionPermission.READ, credentials=root_verify_key
- )
- permission_5 = ActionObjectPermission(
- uid=obj_2.id, permission=ActionPermission.WRITE, credentials=root_verify_key
- )
- mongo_store_partition.add_permissions([permission_4, permission_5])
-
- assert permissions_collection.count_documents({}) == 2
- find_res_2 = permissions_collection.find_one({"_id": obj_2.id})
- assert find_res_2 is not None
- assert len(find_res_2["permissions"]) == 2
-
-
-@pytest.mark.parametrize("permission", PERMISSIONS)
-def test_mongo_store_partition_has_permission(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
- permission: ActionObjectPermission,
-) -> None:
- hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
-
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- # root permission
- obj = MockSyftObject(data=1)
- permission_root = permission(uid=obj.id, credentials=root_verify_key)
- permission_client = permission(uid=obj.id, credentials=guest_verify_key)
- permission_hacker = permission(uid=obj.id, credentials=hacker_verify_key)
- mongo_store_partition.add_permission(permission_root)
- # only the root user has access to this permission
- assert mongo_store_partition.has_permission(permission_root)
- assert not mongo_store_partition.has_permission(permission_client)
- assert not mongo_store_partition.has_permission(permission_hacker)
-
- # client permission for another object
- obj_2 = MockSyftObject(data=2)
- permission_client_2 = permission(uid=obj_2.id, credentials=guest_verify_key)
- permission_root_2 = permission(uid=obj_2.id, credentials=root_verify_key)
- permisson_hacker_2 = permission(uid=obj_2.id, credentials=hacker_verify_key)
- mongo_store_partition.add_permission(permission_client_2)
- # the root (admin) and guest client should have this permission
- assert mongo_store_partition.has_permission(permission_root_2)
- assert mongo_store_partition.has_permission(permission_client_2)
- assert not mongo_store_partition.has_permission(permisson_hacker_2)
-
- # remove permissions
- mongo_store_partition.remove_permission(permission_root)
- assert not mongo_store_partition.has_permission(permission_root)
- assert not mongo_store_partition.has_permission(permission_client)
- assert not mongo_store_partition.has_permission(permission_hacker)
-
- mongo_store_partition.remove_permission(permission_client_2)
- assert not mongo_store_partition.has_permission(permission_root_2)
- assert not mongo_store_partition.has_permission(permission_client_2)
- assert not mongo_store_partition.has_permission(permisson_hacker_2)
-
-
-@pytest.mark.parametrize("permission", PERMISSIONS)
-def test_mongo_store_partition_take_ownership(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
- permission: ActionObjectPermission,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
- obj = MockSyftObject(data=1)
-
- # the guest client takes ownership of obj
- mongo_store_partition.take_ownership(
- uid=obj.id, credentials=guest_verify_key
- ).unwrap()
- assert mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=guest_verify_key)
- )
- # the root client will also has the permission
- assert mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=root_verify_key)
- )
- assert not mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=hacker_verify_key)
- )
-
- # hacker or root try to take ownership of the obj and will fail
- res = mongo_store_partition.take_ownership(
- uid=obj.id, credentials=hacker_verify_key
- )
- res_2 = mongo_store_partition.take_ownership(
- uid=obj.id, credentials=root_verify_key
- )
- assert res.is_err()
- assert res_2.is_err()
- assert (
- res.value.public_message
- == res_2.value.public_message
- == f"UID: {obj.id} already owned."
- )
-
- # another object
- obj_2 = MockSyftObject(data=2)
- # root client takes ownership
- mongo_store_partition.take_ownership(uid=obj_2.id, credentials=root_verify_key)
- assert mongo_store_partition.has_permission(
- permission(uid=obj_2.id, credentials=root_verify_key)
- )
- assert not mongo_store_partition.has_permission(
- permission(uid=obj_2.id, credentials=guest_verify_key)
- )
- assert not mongo_store_partition.has_permission(
- permission(uid=obj_2.id, credentials=hacker_verify_key)
- )
-
-
-def test_mongo_store_partition_permissions_set(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- """
- Test the permissions functionalities when using MongoStorePartition._set function
- """
- hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
- res = mongo_store_partition.init_store()
- assert res.is_ok()
-
- # set the object to mongo_store_partition.collection
- obj = MockSyftObject(data=1)
- res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
- assert res.ok() == obj
-
- # check if the corresponding permissions has been added to the permissions
- # collection after the root client claim it
- pemissions_collection = mongo_store_partition.permissions.ok()
- assert isinstance(pemissions_collection, MongoCollection)
- permissions = pemissions_collection.find_one({"_id": obj.id})
- assert permissions is not None
- assert isinstance(permissions["permissions"], set)
- assert len(permissions["permissions"]) == 4
- for permission in PERMISSIONS:
- assert mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=root_verify_key)
- )
-
- # the hacker tries to set duplicated object but should not be able to claim it
- res_2 = mongo_store_partition.set(guest_verify_key, obj, ignore_duplicates=True)
- assert res_2.is_ok()
- for permission in PERMISSIONS:
- assert not mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=hacker_verify_key)
- )
- assert mongo_store_partition.has_permission(
- permission(uid=obj.id, credentials=root_verify_key)
- )
-
-
-def test_mongo_store_partition_permissions_get_all(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
- # set several objects for the root and guest client
- num_root_objects: int = 5
- num_guest_objects: int = 3
- for i in range(num_root_objects):
- obj = MockSyftObject(data=i)
- mongo_store_partition.set(
- credentials=root_verify_key, obj=obj, ignore_duplicates=False
- )
- for i in range(num_guest_objects):
- obj = MockSyftObject(data=i)
- mongo_store_partition.set(
- credentials=guest_verify_key, obj=obj, ignore_duplicates=False
- )
-
- assert (
- len(mongo_store_partition.all(root_verify_key).ok())
- == num_root_objects + num_guest_objects
- )
- assert len(mongo_store_partition.all(guest_verify_key).ok()) == num_guest_objects
- assert len(mongo_store_partition.all(hacker_verify_key).ok()) == 0
-
-
-def test_mongo_store_partition_permissions_delete(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- collection: MongoCollection = mongo_store_partition.collection.ok()
- pemissions_collection: MongoCollection = mongo_store_partition.permissions.ok()
- hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
-
- # the root client set an object
- obj = MockSyftObject(data=1)
- mongo_store_partition.set(
- credentials=root_verify_key, obj=obj, ignore_duplicates=False
- )
- qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj)
- # guest or hacker can't delete it
- assert not mongo_store_partition.delete(guest_verify_key, qk).is_ok()
- assert not mongo_store_partition.delete(hacker_verify_key, qk).is_ok()
- # only the root client can delete it
- assert mongo_store_partition.delete(root_verify_key, qk).is_ok()
- # check if the object and its permission have been deleted
- assert collection.count_documents({}) == 0
- assert pemissions_collection.count_documents({}) == 0
-
- # the guest client set an object
- obj_2 = MockSyftObject(data=2)
- mongo_store_partition.set(
- credentials=guest_verify_key, obj=obj_2, ignore_duplicates=False
- )
- qk_2: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_2)
- # the hacker can't delete it
- assert not mongo_store_partition.delete(hacker_verify_key, qk_2).is_ok()
- # the guest client can delete it
- assert mongo_store_partition.delete(guest_verify_key, qk_2).is_ok()
- assert collection.count_documents({}) == 0
- assert pemissions_collection.count_documents({}) == 0
-
- # the guest client set another object
- obj_3 = MockSyftObject(data=3)
- mongo_store_partition.set(
- credentials=guest_verify_key, obj=obj_3, ignore_duplicates=False
- )
- qk_3: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_3)
- # the root client also has the permission to delete it
- assert mongo_store_partition.delete(root_verify_key, qk_3).is_ok()
- assert collection.count_documents({}) == 0
- assert pemissions_collection.count_documents({}) == 0
-
-
-def test_mongo_store_partition_permissions_update(
- root_verify_key: SyftVerifyKey,
- guest_verify_key: SyftVerifyKey,
- mongo_store_partition: MongoStorePartition,
-) -> None:
- res = mongo_store_partition.init_store()
- assert res.is_ok()
- # the root client set an object
- obj = MockSyftObject(data=1)
- mongo_store_partition.set(
- credentials=root_verify_key, obj=obj, ignore_duplicates=False
- )
- assert len(mongo_store_partition.all(credentials=root_verify_key).ok()) == 1
-
- qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj)
- permsissions: MongoCollection = mongo_store_partition.permissions.ok()
- repeats = 5
-
- for v in range(repeats):
- # the guest client should not have permission to update obj
- obj_new = MockSyftObject(data=v)
- res = mongo_store_partition.update(
- credentials=guest_verify_key, qk=qk, obj=obj_new
- )
- assert res.is_err()
- # the root client has the permission to update obj
- res = mongo_store_partition.update(
- credentials=root_verify_key, qk=qk, obj=obj_new
- )
- assert res.is_ok()
- # the id of the object in the permission collection should not be changed
- assert permsissions.find_one(qk.as_dict_mongo)["_id"] == obj.id
diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py
index 312766e7c4e..d4e2cd25747 100644
--- a/packages/syft/tests/syft/stores/queue_stash_test.py
+++ b/packages/syft/tests/syft/stores/queue_stash_test.py
@@ -1,26 +1,20 @@
# stdlib
-import threading
-from threading import Thread
-import time
-from typing import Any
+from concurrent.futures import ThreadPoolExecutor
# third party
import pytest
# syft absolute
from syft.service.queue.queue_stash import QueueItem
+from syft.service.queue.queue_stash import QueueStash
from syft.service.worker.worker_pool import WorkerPool
from syft.service.worker.worker_pool_service import SyftWorkerPoolService
from syft.store.linked_obj import LinkedObject
from syft.types.errors import SyftException
from syft.types.uid import UID
-# relative
-from .store_fixtures_test import mongo_queue_stash_fn
-from .store_fixtures_test import sqlite_queue_stash_fn
-
-def mock_queue_object():
+def mock_queue_object() -> QueueItem:
worker_pool_obj = WorkerPool(
name="mypool",
image_id=UID(),
@@ -47,405 +41,246 @@ def mock_queue_object():
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-def test_queue_stash_sanity(queue: Any) -> None:
+def test_queue_stash_sanity(queue: QueueStash) -> None:
assert len(queue) == 0
- assert hasattr(queue, "store")
- assert hasattr(queue, "partition")
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-# @pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_stash_set_get(root_verify_key, queue: Any) -> None:
- objs = []
+def test_queue_stash_set_get(root_verify_key, queue: QueueStash) -> None:
+ objs: list[QueueItem] = []
repeats = 5
for idx in range(repeats):
obj = mock_queue_object()
objs.append(obj)
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
+ queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap()
assert len(queue) == idx + 1
with pytest.raises(SyftException):
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
+ queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap()
assert len(queue) == idx + 1
assert len(queue.get_all(root_verify_key).ok()) == idx + 1
- item = queue.find_one(root_verify_key, id=obj.id)
- assert item.is_ok()
- assert item.ok() == obj
+ item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap()
+ assert item == obj
cnt = len(objs)
for obj in objs:
- res = queue.find_and_delete(root_verify_key, id=obj.id)
- assert res.is_ok()
-
+ queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap()
cnt -= 1
assert len(queue) == cnt
- item = queue.find_one(root_verify_key, id=obj.id)
+ item = queue.get_by_uid(root_verify_key, uid=obj.id)
assert item.is_err()
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_stash_update(root_verify_key, queue: Any) -> None:
+def test_queue_stash_update(queue: QueueStash) -> None:
+ root_verify_key = queue.db.root_verify_key
obj = mock_queue_object()
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
+ queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap()
repeats = 5
for idx in range(repeats):
obj.args = [idx]
- res = queue.update(root_verify_key, obj)
- assert res.is_ok()
+ queue.update(root_verify_key, obj).unwrap()
assert len(queue) == 1
- item = queue.find_one(root_verify_key, id=obj.id)
- assert item.is_ok()
- assert item.ok().args == [idx]
+ item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap()
+ assert item.args == [idx]
- res = queue.find_and_delete(root_verify_key, id=obj.id)
- assert res.is_ok()
+ queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap()
assert len(queue) == 0
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_set_existing_queue_threading(root_verify_key, queue: Any) -> None:
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for _ in range(repeats):
- obj = mock_queue_object()
-
- for _ in range(10):
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
- assert len(queue) == thread_cnt * repeats
+def test_queue_set_existing_queue_threading(root_verify_key, queue: QueueStash) -> None:
+ root_verify_key = queue.db.root_verify_key
+ items_to_create = 100
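+ # write to the same stash from several threads; every set() is expected to succeed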
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ results = list(
+ executor.map(
+ lambda obj: queue.set(
+ root_verify_key,
+ mock_queue_object(),
+ ),
+ range(items_to_create),
+ )
+ )
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
+ assert len(queue) == items_to_create
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_update_existing_queue_threading(root_verify_key, queue: Any) -> None:
- thread_cnt = 3
- repeats = 5
-
+def test_queue_update_existing_queue_threading(queue: QueueStash) -> None:
+ root_verify_key = queue.db.root_verify_key
obj = mock_queue_object()
- queue.set(root_verify_key, obj, ignore_duplicates=False)
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for repeat in range(repeats):
- obj.args = [repeat]
- for _ in range(10):
- res = queue.update(root_verify_key, obj)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
+ def update_queue():
+ obj.args = [UID()]
+ res = queue.update(root_verify_key, obj)
+ return res
- tids.append(thread)
+ queue.set(root_verify_key, obj, ignore_duplicates=False)
- for thread in tids:
- thread.join()
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ # Run the update_queue function in multiple threads
+ results = list(
+ executor.map(
+ lambda _: update_queue(),
+ range(5),
+ )
+ )
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
- assert execution_err is None
+ assert len(queue) == 1
+ item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap()
+ assert item.args != []
@pytest.mark.parametrize(
"queue",
[
- pytest.lazy_fixture("dict_queue_stash"),
- pytest.lazy_fixture("sqlite_queue_stash"),
- pytest.lazy_fixture("mongo_queue_stash"),
+ pytest.lazy_fixture("queue_stash"),
],
)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_queue_set_delete_existing_queue_threading(
- root_verify_key,
- queue: Any,
+ queue: QueueStash,
) -> None:
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
- objs = []
-
- for _ in range(repeats * thread_cnt):
- obj = mock_queue_object()
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- objs.append(obj)
-
- assert res.is_ok()
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- for idx in range(repeats):
- item_idx = tid * repeats + idx
-
- for _ in range(10):
- res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
- assert len(queue) == 0
-
-
-def helper_queue_set_threading(root_verify_key, create_queue_cbk) -> None:
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
- lock = threading.Lock()
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- with lock:
- queue = create_queue_cbk()
-
- for _ in range(repeats):
- obj = mock_queue_object()
-
- for _ in range(10):
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- queue = create_queue_cbk()
-
- assert execution_err is None
- assert len(queue) == thread_cnt * repeats
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_set_sqlite(root_verify_key, sqlite_workspace):
- def create_queue_cbk():
- return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
-
- helper_queue_set_threading(root_verify_key, create_queue_cbk)
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_set_threading_mongo(root_verify_key, mongo_document_store):
- def create_queue_cbk():
- return mongo_queue_stash_fn(mongo_document_store)
-
- helper_queue_set_threading(root_verify_key, create_queue_cbk)
-
-
-def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None:
- thread_cnt = 3
- repeats = 5
-
- queue = create_queue_cbk()
- time.sleep(1)
-
+ root_verify_key = queue.db.root_verify_key
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ results = list(
+ executor.map(
+ lambda obj: queue.set(
+ root_verify_key,
+ mock_queue_object(),
+ ),
+ range(15),
+ )
+ )
+ objs = [item.unwrap() for item in results]
+
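+ # now delete the freshly created items, again concurrently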
+ results = list(
+ executor.map(
+ lambda obj: queue.delete_by_uid(root_verify_key, uid=obj.id),
+ objs,
+ )
+ )
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
+
+
+def test_queue_set(queue_stash: QueueStash):
+ root_verify_key = queue_stash.db.root_verify_key
+ config = queue_stash.db.config
+ server_uid = queue_stash.db.server_uid
+
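+ # each thread builds its own QueueStash over the same config, so all writes land in one shared store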
+ def set_in_new_thread(_):
+ queue_stash = QueueStash.random(
+ root_verify_key=root_verify_key,
+ config=config,
+ server_uid=server_uid,
+ )
+ return queue_stash.set(root_verify_key, mock_queue_object())
+
+ total_repeats = 50
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ results = list(
+ executor.map(
+ set_in_new_thread,
+ range(total_repeats),
+ )
+ )
+
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
+ assert len(queue_stash) == total_repeats
+
+
+def test_queue_update_threading(queue_stash: QueueStash):
+ root_verify_key = queue_stash.db.root_verify_key
+ config = queue_stash.db.config
+ server_uid = queue_stash.db.server_uid
obj = mock_queue_object()
- queue.set(root_verify_key, obj, ignore_duplicates=False)
- execution_err = None
- lock = threading.Lock()
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- with lock:
- queue_local = create_queue_cbk()
-
- for repeat in range(repeats):
- obj.args = [repeat]
-
- for _ in range(10):
- res = queue_local.update(root_verify_key, obj)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace):
- def create_queue_cbk():
- return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
-
- helper_queue_update_threading(root_verify_key, create_queue_cbk)
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_update_threading_mongo(root_verify_key, mongo_document_store):
- def create_queue_cbk():
- return mongo_queue_stash_fn(mongo_document_store)
-
- helper_queue_update_threading(root_verify_key, create_queue_cbk)
-
-
-def helper_queue_set_delete_threading(
- root_verify_key,
- create_queue_cbk,
-) -> None:
- thread_cnt = 3
- repeats = 5
-
- queue = create_queue_cbk()
- execution_err = None
- objs = []
-
- for _ in range(repeats * thread_cnt):
- obj = mock_queue_object()
- res = queue.set(root_verify_key, obj, ignore_duplicates=False)
- objs.append(obj)
-
- assert res.is_ok()
-
- lock = threading.Lock()
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- with lock:
- queue = create_queue_cbk()
- for idx in range(repeats):
- item_idx = tid * repeats + idx
-
- for _ in range(10):
- res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
- assert len(queue) == 0
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace):
- def create_queue_cbk():
- return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
-
- helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store):
- def create_queue_cbk():
- return mongo_queue_stash_fn(mongo_document_store)
-
- helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)
+ queue_stash.set(root_verify_key, obj).unwrap()
+
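+ # concurrent updates all target the same object, so the stash should end up with a single item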
+ def update_in_new_thread(_):
+ queue_stash = QueueStash.random(
+ root_verify_key=root_verify_key,
+ config=config,
+ server_uid=server_uid,
+ )
+ obj.args = [UID()]
+ return queue_stash.update(root_verify_key, obj)
+
+ total_repeats = 50
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ results = list(
+ executor.map(
+ update_in_new_thread,
+ range(total_repeats),
+ )
+ )
+
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
+ assert len(queue_stash) == 1
+
+
+def test_queue_delete_threading(queue_stash: QueueStash):
+ root_verify_key = queue_stash.db.root_verify_key
+ config = queue_stash.db.config
+ server_uid = queue_stash.db.server_uid
+
+ def delete_in_new_thread(obj: QueueItem):
+ queue_stash = QueueStash.random(
+ root_verify_key=root_verify_key,
+ config=config,
+ server_uid=server_uid,
+ )
+ return queue_stash.delete_by_uid(root_verify_key, uid=obj.id)
+
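+    # insert 50 items concurrently, then delete them all from worker threads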
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ results = list(
+ executor.map(
+                lambda _: queue_stash.set(
+ root_verify_key,
+ mock_queue_object(),
+ ),
+ range(50),
+ )
+ )
+ objs = [item.unwrap() for item in results]
+
+ results = list(
+ executor.map(
+ delete_in_new_thread,
+ objs,
+ )
+ )
+ assert all(res.is_ok() for res in results), "Error occurred during execution"
+
+ assert len(queue_stash) == 0
diff --git a/packages/syft/tests/syft/stores/sqlite_document_store_test.py b/packages/syft/tests/syft/stores/sqlite_document_store_test.py
deleted file mode 100644
index 46ee540aa9c..00000000000
--- a/packages/syft/tests/syft/stores/sqlite_document_store_test.py
+++ /dev/null
@@ -1,520 +0,0 @@
-# stdlib
-from threading import Thread
-
-# third party
-import pytest
-
-# syft absolute
-from syft.store.document_store import QueryKeys
-from syft.store.sqlite_document_store import SQLiteStorePartition
-
-# relative
-from .store_fixtures_test import sqlite_store_partition_fn
-from .store_mocks_test import MockObjectType
-from .store_mocks_test import MockSyftObject
-
-
-def test_sqlite_store_partition_sanity(
- sqlite_store_partition: SQLiteStorePartition,
-) -> None:
- assert hasattr(sqlite_store_partition, "data")
- assert hasattr(sqlite_store_partition, "unique_keys")
- assert hasattr(sqlite_store_partition, "searchable_keys")
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_set(
- root_verify_key,
- sqlite_store_partition: SQLiteStorePartition,
-) -> None:
- obj = MockSyftObject(data=1)
- res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-
- assert res.is_ok()
- assert res.ok() == obj
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_err()
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=True)
- assert res.is_ok()
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- obj2 = MockSyftObject(data=2)
- res = sqlite_store_partition.set(root_verify_key, obj2, ignore_duplicates=False)
- assert res.is_ok()
- assert res.ok() == obj2
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 2
- )
- repeats = 5
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
- res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert res.is_ok()
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 3 + idx
- )
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_delete(
- root_verify_key,
- sqlite_store_partition: SQLiteStorePartition,
-) -> None:
- objs = []
- repeats = 5
- for v in range(repeats):
- obj = MockSyftObject(data=v)
- sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- objs.append(obj)
-
- assert len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # random object
- obj = MockSyftObject(data="bogus")
- key = sqlite_store_partition.settings.store_key.with_obj(obj)
- res = sqlite_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- ) == len(objs)
-
- # cleanup store
- for idx, v in enumerate(objs):
- key = sqlite_store_partition.settings.store_key.with_obj(v)
- res = sqlite_store_partition.delete(root_verify_key, key)
- assert res.is_ok()
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- res = sqlite_store_partition.delete(root_verify_key, key)
- assert res.is_err()
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == len(objs) - idx - 1
- )
-
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 0
- )
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_update(
- root_verify_key,
- sqlite_store_partition: SQLiteStorePartition,
-) -> None:
- # add item
- obj = MockSyftObject(data=1)
- sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
-
- # fail to update missing keys
- rand_obj = MockSyftObject(data="bogus")
- key = sqlite_store_partition.settings.store_key.with_obj(rand_obj)
- res = sqlite_store_partition.update(root_verify_key, key, obj)
- assert res.is_err()
-
- # update the key multiple times
- repeats = 5
- for v in range(repeats):
- key = sqlite_store_partition.settings.store_key.with_obj(obj)
- obj_new = MockSyftObject(data=v)
-
- res = sqlite_store_partition.update(root_verify_key, key, obj_new)
- assert res.is_ok()
-
- # The ID should stay the same on update, unly the values are updated.
- assert (
- len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- == 1
- )
- assert (
- sqlite_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- == obj.id
- )
- assert (
- sqlite_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .id
- != obj_new.id
- )
- assert (
- sqlite_store_partition.all(
- root_verify_key,
- )
- .ok()[0]
- .data
- == v
- )
-
- stored = sqlite_store_partition.get_all_from_store(
- root_verify_key, QueryKeys(qks=[key])
- )
- assert stored.ok()[0].data == v
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_set_threading(
- sqlite_workspace: tuple,
- root_verify_key,
-) -> None:
- thread_cnt = 3
- repeats = 5
-
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
-
- sqlite_store_partition = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
- for idx in range(repeats):
- for _ in range(10):
- obj = MockObjectType(data=idx)
- res = sqlite_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- return execution_err
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
- sqlite_store_partition = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
- stored_cnt = len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == thread_cnt * repeats
-
-
-# @pytest.mark.skip(reason="Joblib is flaky")
-# def test_sqlite_store_partition_set_joblib(
-# root_verify_key,
-# sqlite_workspace: Tuple,
-# ) -> None:
-# thread_cnt = 3
-# repeats = 5
-
-# def _kv_cbk(tid: int) -> None:
-# for idx in range(repeats):
-# sqlite_store_partition = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-# obj = MockObjectType(data=idx)
-
-# for _ in range(10):
-# res = sqlite_store_partition.set(
-# root_verify_key, obj, ignore_duplicates=False
-# )
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-
-# for execution_err in errs:
-# assert execution_err is None
-
-# sqlite_store_partition = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-# stored_cnt = len(
-# sqlite_store_partition.all(
-# root_verify_key,
-# ).ok()
-# )
-# assert stored_cnt == thread_cnt * repeats
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_update_threading(
- root_verify_key,
- sqlite_workspace: tuple,
-) -> None:
- thread_cnt = 3
- repeats = 5
-
- sqlite_store_partition = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
- obj = MockSyftObject(data=0)
- key = sqlite_store_partition.settings.store_key.with_obj(obj)
- sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
-
- sqlite_store_partition_local = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
- for repeat in range(repeats):
- obj = MockSyftObject(data=repeat)
-
- for _ in range(10):
- res = sqlite_store_partition_local.update(root_verify_key, key, obj)
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
-
-# @pytest.mark.skip(reason="Joblib is flaky")
-# def test_sqlite_store_partition_update_joblib(
-# root_verify_key,
-# sqlite_workspace: Tuple,
-# ) -> None:
-# thread_cnt = 3
-# repeats = 5
-
-# sqlite_store_partition = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-# obj = MockSyftObject(data=0)
-# key = sqlite_store_partition.settings.store_key.with_obj(obj)
-# sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False)
-
-# def _kv_cbk(tid: int) -> None:
-# sqlite_store_partition_local = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-# for repeat in range(repeats):
-# obj = MockSyftObject(data=repeat)
-
-# for _ in range(10):
-# res = sqlite_store_partition_local.update(root_verify_key, key, obj)
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-
-# for execution_err in errs:
-# assert execution_err is None
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_sqlite_store_partition_set_delete_threading(
- root_verify_key,
- sqlite_workspace: tuple,
-) -> None:
- thread_cnt = 3
- repeats = 5
- execution_err = None
-
- def _kv_cbk(tid: int) -> None:
- nonlocal execution_err
- sqlite_store_partition = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
-
- for idx in range(repeats):
- obj = MockSyftObject(data=idx)
-
- for _ in range(10):
- res = sqlite_store_partition.set(
- root_verify_key, obj, ignore_duplicates=False
- )
- if res.is_ok():
- break
-
- if res.is_err():
- execution_err = res
- assert res.is_ok()
-
- key = sqlite_store_partition.settings.store_key.with_obj(obj)
-
- res = sqlite_store_partition.delete(root_verify_key, key)
- if res.is_err():
- execution_err = res
- assert res.is_ok(), res
-
- tids = []
- for tid in range(thread_cnt):
- thread = Thread(target=_kv_cbk, args=(tid,))
- thread.start()
-
- tids.append(thread)
-
- for thread in tids:
- thread.join()
-
- assert execution_err is None
-
- sqlite_store_partition = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace
- )
- stored_cnt = len(
- sqlite_store_partition.all(
- root_verify_key,
- ).ok()
- )
- assert stored_cnt == 0
-
-
-# @pytest.mark.skip(reason="Joblib is flaky")
-# def test_sqlite_store_partition_set_delete_joblib(
-# root_verify_key,
-# sqlite_workspace: Tuple,
-# ) -> None:
-# thread_cnt = 3
-# repeats = 5
-
-# def _kv_cbk(tid: int) -> None:
-# sqlite_store_partition = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-
-# for idx in range(repeats):
-# obj = MockSyftObject(data=idx)
-
-# for _ in range(10):
-# res = sqlite_store_partition.set(
-# root_verify_key, obj, ignore_duplicates=False
-# )
-# if res.is_ok():
-# break
-
-# if res.is_err():
-# return res
-
-# key = sqlite_store_partition.settings.store_key.with_obj(obj)
-
-# res = sqlite_store_partition.delete(root_verify_key, key)
-# if res.is_err():
-# return res
-# return None
-
-# errs = Parallel(n_jobs=thread_cnt)(
-# delayed(_kv_cbk)(idx) for idx in range(thread_cnt)
-# )
-# for execution_err in errs:
-# assert execution_err is None
-
-# sqlite_store_partition = sqlite_store_partition_fn(
-# root_verify_key, sqlite_workspace
-# )
-# stored_cnt = len(
-# sqlite_store_partition.all(
-# root_verify_key,
-# ).ok()
-# )
-# assert stored_cnt == 0
diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py
index b64d14be8e3..47452e9740e 100644
--- a/packages/syft/tests/syft/stores/store_fixtures_test.py
+++ b/packages/syft/tests/syft/stores/store_fixtures_test.py
@@ -1,72 +1,30 @@
# stdlib
-from collections.abc import Generator
-import os
-from pathlib import Path
-from secrets import token_hex
-import tempfile
import uuid
-# third party
-import pytest
-
# syft absolute
from syft.server.credentials import SyftVerifyKey
from syft.service.action.action_permissions import ActionObjectPermission
from syft.service.action.action_permissions import ActionPermission
-from syft.service.action.action_store import DictActionStore
-from syft.service.action.action_store import MongoActionStore
-from syft.service.action.action_store import SQLiteActionStore
-from syft.service.queue.queue_stash import QueueStash
from syft.service.user.user import User
from syft.service.user.user import UserCreate
from syft.service.user.user_roles import ServiceRole
from syft.service.user.user_stash import UserStash
-from syft.store.dict_document_store import DictDocumentStore
-from syft.store.dict_document_store import DictStoreConfig
-from syft.store.dict_document_store import DictStorePartition
+from syft.store.db.sqlite import SQLiteDBConfig
+from syft.store.db.sqlite import SQLiteDBManager
from syft.store.document_store import DocumentStore
-from syft.store.document_store import PartitionSettings
-from syft.store.locks import LockingConfig
-from syft.store.locks import NoLockingConfig
-from syft.store.locks import ThreadingLockingConfig
-from syft.store.mongo_client import MongoStoreClientConfig
-from syft.store.mongo_document_store import MongoDocumentStore
-from syft.store.mongo_document_store import MongoStoreConfig
-from syft.store.mongo_document_store import MongoStorePartition
-from syft.store.sqlite_document_store import SQLiteDocumentStore
-from syft.store.sqlite_document_store import SQLiteStoreClientConfig
-from syft.store.sqlite_document_store import SQLiteStoreConfig
-from syft.store.sqlite_document_store import SQLiteStorePartition
from syft.types.uid import UID
# relative
from .store_constants_test import TEST_SIGNING_KEY_NEW_ADMIN
from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN
-from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT
-from .store_mocks_test import MockObjectType
-
-MONGO_CLIENT_CACHE = None
-
-locking_scenarios = [
- "nop",
- "threading",
-]
-
-
-def str_to_locking_config(conf: str) -> LockingConfig:
- if conf == "nop":
- return NoLockingConfig()
- elif conf == "threading":
- return ThreadingLockingConfig()
- else:
- raise NotImplementedError(f"unknown locking config {conf}")
def document_store_with_admin(
server_uid: UID, verify_key: SyftVerifyKey
) -> DocumentStore:
- document_store = DictDocumentStore(
- server_uid=server_uid, root_verify_key=verify_key
+ config = SQLiteDBConfig()
+ document_store = SQLiteDBManager(
+ server_uid=server_uid, root_verify_key=verify_key, config=config
)
password = uuid.uuid4().hex
@@ -94,307 +52,3 @@ def document_store_with_admin(
)
return document_store
-
-
-@pytest.fixture(scope="function")
-def sqlite_workspace() -> Generator:
- sqlite_db_name = token_hex(8) + ".sqlite"
- root = os.getenv("SYFT_TEMP_ROOT", "syft")
- sqlite_workspace_folder = Path(
- tempfile.gettempdir(), root, "fixture_sqlite_workspace"
- )
- sqlite_workspace_folder.mkdir(parents=True, exist_ok=True)
-
- db_path = sqlite_workspace_folder / sqlite_db_name
-
- if db_path.exists():
- db_path.unlink()
-
- yield sqlite_workspace_folder, sqlite_db_name
-
- try:
- db_path.exists() and db_path.unlink()
- except BaseException as e:
- print("failed to cleanup sqlite db", e)
-
-
-def sqlite_store_partition_fn(
- root_verify_key,
- sqlite_workspace: tuple[Path, str],
- locking_config_name: str = "nop",
-):
- workspace, db_name = sqlite_workspace
- sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace)
-
- locking_config = str_to_locking_config(locking_config_name)
- store_config = SQLiteStoreConfig(
- client_config=sqlite_config, locking_config=locking_config
- )
-
- settings = PartitionSettings(name="test", object_type=MockObjectType)
-
- store = SQLiteStorePartition(
- UID(), root_verify_key, settings=settings, store_config=store_config
- )
-
- store.init_store().unwrap()
-
- return store
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def sqlite_store_partition(
- root_verify_key, sqlite_workspace: tuple[Path, str], request
-):
- locking_config_name = request.param
- store = sqlite_store_partition_fn(
- root_verify_key, sqlite_workspace, locking_config_name=locking_config_name
- )
-
- yield store
-
-
-def sqlite_document_store_fn(
- root_verify_key,
- sqlite_workspace: tuple[Path, str],
- locking_config_name: str = "nop",
-):
- workspace, db_name = sqlite_workspace
- sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace)
-
- locking_config = str_to_locking_config(locking_config_name)
- store_config = SQLiteStoreConfig(
- client_config=sqlite_config, locking_config=locking_config
- )
-
- return SQLiteDocumentStore(UID(), root_verify_key, store_config=store_config)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def sqlite_document_store(root_verify_key, sqlite_workspace: tuple[Path, str], request):
- locking_config_name = request.param
- store = sqlite_document_store_fn(
- root_verify_key, sqlite_workspace, locking_config_name=locking_config_name
- )
- yield store
-
-
-def sqlite_queue_stash_fn(
- root_verify_key,
- sqlite_workspace: tuple[Path, str],
- locking_config_name: str = "threading",
-):
- store = sqlite_document_store_fn(
- root_verify_key,
- sqlite_workspace,
- locking_config_name=locking_config_name,
- )
- return QueueStash(store=store)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def sqlite_queue_stash(root_verify_key, sqlite_workspace: tuple[Path, str], request):
- locking_config_name = request.param
- yield sqlite_queue_stash_fn(
- root_verify_key, sqlite_workspace, locking_config_name=locking_config_name
- )
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def sqlite_action_store(sqlite_workspace: tuple[Path, str], request):
- workspace, db_name = sqlite_workspace
- locking_config_name = request.param
-
- sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace)
-
- locking_config = str_to_locking_config(locking_config_name)
- store_config = SQLiteStoreConfig(
- client_config=sqlite_config,
- locking_config=locking_config,
- )
-
- ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
-
- server_uid = UID()
- document_store = document_store_with_admin(server_uid, ver_key)
-
- yield SQLiteActionStore(
- server_uid=server_uid,
- store_config=store_config,
- root_verify_key=ver_key,
- document_store=document_store,
- )
-
-
-def mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name: str = "mongo_db",
- locking_config_name: str = "nop",
-):
- mongo_config = MongoStoreClientConfig(client=mongo_client)
-
- locking_config = str_to_locking_config(locking_config_name)
-
- store_config = MongoStoreConfig(
- client_config=mongo_config,
- db_name=mongo_db_name,
- locking_config=locking_config,
- )
- settings = PartitionSettings(name="test", object_type=MockObjectType)
-
- return MongoStorePartition(
- UID(), root_verify_key, settings=settings, store_config=store_config
- )
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_store_partition(root_verify_key, mongo_client, request):
- mongo_db_name = token_hex(8)
- locking_config_name = request.param
-
- partition = mongo_store_partition_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- locking_config_name=locking_config_name,
- )
- yield partition
-
- # cleanup db
- try:
- mongo_client.drop_database(mongo_db_name)
- except BaseException as e:
- print("failed to cleanup mongo fixture", e)
-
-
-def mongo_document_store_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name: str = "mongo_db",
- locking_config_name: str = "nop",
-):
- locking_config = str_to_locking_config(locking_config_name)
- mongo_config = MongoStoreClientConfig(client=mongo_client)
- store_config = MongoStoreConfig(
- client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config
- )
-
- mongo_client.drop_database(mongo_db_name)
-
- return MongoDocumentStore(UID(), root_verify_key, store_config=store_config)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_document_store(root_verify_key, mongo_client, request):
- locking_config_name = request.param
- mongo_db_name = token_hex(8)
- yield mongo_document_store_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- locking_config_name=locking_config_name,
- )
-
-
-def mongo_queue_stash_fn(mongo_document_store):
- return QueueStash(store=mongo_document_store)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_queue_stash(root_verify_key, mongo_client, request):
- mongo_db_name = token_hex(8)
- locking_config_name = request.param
-
- store = mongo_document_store_fn(
- mongo_client,
- root_verify_key,
- mongo_db_name=mongo_db_name,
- locking_config_name=locking_config_name,
- )
- yield mongo_queue_stash_fn(store)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_action_store(mongo_client, request):
- mongo_db_name = token_hex(8)
- locking_config_name = request.param
- locking_config = str_to_locking_config(locking_config_name)
-
- mongo_config = MongoStoreClientConfig(client=mongo_client)
- store_config = MongoStoreConfig(
- client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config
- )
- ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
- server_uid = UID()
- document_store = document_store_with_admin(server_uid, ver_key)
- mongo_action_store = MongoActionStore(
- server_uid=server_uid,
- store_config=store_config,
- root_verify_key=ver_key,
- document_store=document_store,
- )
-
- yield mongo_action_store
-
-
-def dict_store_partition_fn(
- root_verify_key,
- locking_config_name: str = "nop",
-):
- locking_config = str_to_locking_config(locking_config_name)
- store_config = DictStoreConfig(locking_config=locking_config)
- settings = PartitionSettings(name="test", object_type=MockObjectType)
-
- return DictStorePartition(
- UID(), root_verify_key, settings=settings, store_config=store_config
- )
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def dict_store_partition(root_verify_key, request):
- locking_config_name = request.param
- yield dict_store_partition_fn(
- root_verify_key, locking_config_name=locking_config_name
- )
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def dict_action_store(request):
- locking_config_name = request.param
- locking_config = str_to_locking_config(locking_config_name)
-
- store_config = DictStoreConfig(locking_config=locking_config)
- ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
- server_uid = UID()
- document_store = document_store_with_admin(server_uid, ver_key)
-
- yield DictActionStore(
- server_uid=server_uid,
- store_config=store_config,
- root_verify_key=ver_key,
- document_store=document_store,
- )
-
-
-def dict_document_store_fn(root_verify_key, locking_config_name: str = "nop"):
- locking_config = str_to_locking_config(locking_config_name)
- store_config = DictStoreConfig(locking_config=locking_config)
- return DictDocumentStore(UID(), root_verify_key, store_config=store_config)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def dict_document_store(root_verify_key, request):
- locking_config_name = request.param
- yield dict_document_store_fn(
- root_verify_key, locking_config_name=locking_config_name
- )
-
-
-def dict_queue_stash_fn(dict_document_store):
- return QueueStash(store=dict_document_store)
-
-
-@pytest.fixture(scope="function")
-def dict_queue_stash(dict_document_store):
- yield dict_queue_stash_fn(dict_document_store)
diff --git a/packages/syft/tests/syft/types/errors_test.py b/packages/syft/tests/syft/types/errors_test.py
index 4ac185ba421..ca8e557ef11 100644
--- a/packages/syft/tests/syft/types/errors_test.py
+++ b/packages/syft/tests/syft/types/errors_test.py
@@ -5,6 +5,7 @@
import pytest
# syft absolute
+import syft
from syft.service.context import AuthedServiceContext
from syft.service.user.user_roles import ServiceRole
from syft.types.errors import SyftException
@@ -52,3 +53,24 @@ def test_get_message(role, private_msg, public_msg, expected_message):
mock_context.dev_mode = False
exception = SyftException(private_msg, public_message=public_msg)
assert exception.get_message(mock_context) == expected_message
+
+
+def test_syfterror_raise_works_in_pytest():
+ """
+    SyftError has its own exception handler that wasn't working in notebook testing environments;
+    this is just a sanity check to make sure it works in pytest.
+ """
+ with pytest.raises(SyftException):
+ raise SyftException(public_message="-")
+
+ with syft.raises(SyftException(public_message="-")):
+ raise SyftException(public_message="-")
+
+ # syft.raises works with wildcard
+ with syft.raises(SyftException(public_message="*test message*")):
+ raise SyftException(public_message="longer test message")
+
+ # syft.raises with different public message should raise
+ with pytest.raises(AssertionError):
+ with syft.raises(SyftException(public_message="*different message*")):
+ raise SyftException(public_message="longer test message")
diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py
index 9410cecd695..5de444197ba 100644
--- a/packages/syft/tests/syft/users/user_code_test.py
+++ b/packages/syft/tests/syft/users/user_code_test.py
@@ -16,6 +16,7 @@
from syft.service.response import SyftError
from syft.service.response import SyftSuccess
from syft.service.user.user import User
+from syft.service.user.user import UserView
from syft.service.user.user_roles import ServiceRole
from syft.types.errors import SyftException
@@ -46,12 +47,10 @@ def test_repr_markdown_not_throwing_error(guest_client: DatasiteClient) -> None:
assert result[0]._repr_markdown_()
-@pytest.mark.parametrize("delete_original_admin", [False, True])
def test_new_admin_can_list_user_code(
worker: Worker,
ds_client: DatasiteClient,
faker: Faker,
- delete_original_admin: bool,
) -> None:
root_client = worker.root_client
@@ -66,17 +65,16 @@ def test_new_admin_can_list_user_code(
admin = root_client.login(email=email, password=pw)
- root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN)
-
- if delete_original_admin:
- res = root_client.api.services.user.delete(root_client.account.id)
- assert not isinstance(res, SyftError)
+ result: UserView = root_client.api.services.user.update(
+ uid=admin.account.id, role=ServiceRole.ADMIN
+ )
+ assert result.role == ServiceRole.ADMIN
user_code_stash = worker.services.user_code.stash
- user_code = user_code_stash.get_all(user_code_stash.store.root_verify_key).ok()
+ user_codes = user_code_stash._data
- assert len(user_code) == len(admin.code.get_all())
- assert {c.id for c in user_code} == {c.id for c in admin.code}
+    assert len(admin.code.get_all()) == 1
+ assert {c.id for c in user_codes} == {c.id for c in admin.code}
def test_user_code(worker) -> None:
diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py
index 45e31da18fe..59e905ee657 100644
--- a/packages/syft/tests/syft/users/user_service_test.py
+++ b/packages/syft/tests/syft/users/user_service_test.py
@@ -9,7 +9,9 @@
from pytest import MonkeyPatch
# syft absolute
+import syft as sy
from syft import orchestra
+from syft.client.client import SyftClient
from syft.server.credentials import SyftVerifyKey
from syft.server.worker import Worker
from syft.service.context import AuthedServiceContext
@@ -208,7 +210,7 @@ def test_userservice_get_all_success(
expected_output = [x.to(UserView) for x in mock_get_all_output]
@as_result(StashException)
- def mock_get_all(credentials: SyftVerifyKey) -> list[User]:
+ def mock_get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]:
return mock_get_all_output
monkeypatch.setattr(user_service.stash, "get_all", mock_get_all)
@@ -222,23 +224,6 @@ def mock_get_all(credentials: SyftVerifyKey) -> list[User]:
)
-def test_userservice_get_all_error(
- monkeypatch: MonkeyPatch,
- user_service: UserService,
- authed_context: AuthedServiceContext,
-) -> None:
- @as_result(StashException)
- def mock_get_all(credentials: SyftVerifyKey) -> NoReturn:
- raise StashException
-
- monkeypatch.setattr(user_service.stash, "get_all", mock_get_all)
-
- with pytest.raises(StashException) as exc:
- user_service.get_all(authed_context)
-
- assert exc.type == StashException
-
-
def test_userservice_search(
monkeypatch: MonkeyPatch,
user_service: UserService,
@@ -246,13 +231,13 @@ def test_userservice_search(
guest_user: User,
) -> None:
@as_result(SyftException)
- def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> list[User]:
+ def get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]:
for key in kwargs.keys():
if hasattr(guest_user, key):
return [guest_user]
return []
- monkeypatch.setattr(user_service.stash, "find_all", mock_find_all)
+ monkeypatch.setattr(user_service.stash, "get_all", get_all)
expected_output = [guest_user.to(UserView)]
@@ -541,27 +526,10 @@ def mock_get_by_email(credentials: SyftVerifyKey, email: str) -> NoReturn:
assert exc.value.public_message == expected_output
-def test_userservice_admin_verify_key_error(
- monkeypatch: MonkeyPatch, user_service: UserService
-) -> None:
- expected_output = "failed to get admin verify_key"
-
- def mock_admin_verify_key() -> UID:
- raise SyftException(public_message=expected_output)
-
- monkeypatch.setattr(user_service.stash, "admin_verify_key", mock_admin_verify_key)
-
- with pytest.raises(SyftException) as exc:
- user_service.admin_verify_key()
-
- assert exc.type == SyftException
- assert exc.value.public_message == expected_output
-
-
def test_userservice_admin_verify_key_success(
monkeypatch: MonkeyPatch, user_service: UserService, worker
) -> None:
- response = user_service.admin_verify_key()
+ response = user_service.root_verify_key
assert isinstance(response, SyftVerifyKey)
assert response == worker.root_client.credentials.verify_key
@@ -586,7 +554,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> User:
new_callable=mock.PropertyMock,
return_value=settings_with_signup_enabled(worker),
):
- mock_worker = Worker.named(name="mock-server")
+ mock_worker = Worker.named(name="mock-server", db_url="sqlite://")
server_context = ServerServiceContext(server=mock_worker)
with pytest.raises(SyftException) as exc:
@@ -616,7 +584,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> NoReturn:
new_callable=mock.PropertyMock,
return_value=settings_with_signup_enabled(worker),
):
- mock_worker = Worker.named(name="mock-server")
+ mock_worker = Worker.named(name="mock-server", db_url="sqlite://")
server_context = ServerServiceContext(server=mock_worker)
with pytest.raises(StashException) as exc:
@@ -645,7 +613,7 @@ def mock_set(*args, **kwargs) -> User:
new_callable=mock.PropertyMock,
return_value=settings_with_signup_enabled(worker),
):
- mock_worker = Worker.named(name="mock-server")
+ mock_worker = Worker.named(name="mock-server", db_url="sqlite://")
server_context = ServerServiceContext(server=mock_worker)
monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email)
@@ -684,7 +652,7 @@ def mock_set(
new_callable=mock.PropertyMock,
return_value=settings_with_signup_enabled(worker),
):
- mock_worker = Worker.named(name="mock-server")
+ mock_worker = Worker.named(name="mock-server", db_url="sqlite://")
server_context = ServerServiceContext(server=mock_worker)
monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email)
@@ -790,3 +758,43 @@ def test_userservice_update_via_client_with_mixed_args():
email="new_user@openmined.org", password="newpassword"
)
assert user_client.account.name == "User name"
+
+
+def test_reset_password():
+ server = orchestra.launch(name="datasite-test", reset=True)
+
+ datasite_client = server.login(email="info@openmined.org", password="changethis")
+ datasite_client.register(
+ email="new_syft_user@openmined.org",
+ password="verysecurepassword",
+ password_verify="verysecurepassword",
+ name="New User",
+ )
+ guest_client: SyftClient = server.login_as_guest()
+ guest_client.forgot_password(email="new_syft_user@openmined.org")
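+    # the admin fulfils the reset request from the latest notification and obtains a temporary token for the user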
+ temp_token = datasite_client.users.request_password_reset(
+ datasite_client.notifications[-1].linked_obj.resolve.id
+ )
+ guest_client.reset_password(token=temp_token, new_password="Password123")
+ server.login(email="new_syft_user@openmined.org", password="Password123")
+
+
+def test_root_cannot_be_deleted():
+ server = orchestra.launch(name="datasite-test", reset=True)
+ datasite_client = server.login(email="info@openmined.org", password="changethis")
+
+ new_admin_email = "admin@openmined.org"
+ new_admin_pass = "changethis2"
+ datasite_client.register(
+ name="second admin",
+ email=new_admin_email,
+ password=new_admin_pass,
+ password_verify=new_admin_pass,
+ )
+ # update role
+ new_user_id = datasite_client.users.search(email=new_admin_email)[0].id
+ datasite_client.users.update(uid=new_user_id, role="admin")
+
+ new_admin_client = server.login(email=new_admin_email, password=new_admin_pass)
+ with sy.raises(sy.SyftException):
+ new_admin_client.users.delete(datasite_client.account.id)
diff --git a/packages/syft/tests/syft/users/user_stash_test.py b/packages/syft/tests/syft/users/user_stash_test.py
index ee8e4b1edc9..584e616d093 100644
--- a/packages/syft/tests/syft/users/user_stash_test.py
+++ b/packages/syft/tests/syft/users/user_stash_test.py
@@ -1,5 +1,6 @@
# third party
from faker import Faker
+import pytest
# syft absolute
from syft.server.credentials import SyftSigningKey
@@ -14,7 +15,7 @@
def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> User:
# prepare: add mock data
- result = user_stash.partition.set(root_datasite_client.credentials.verify_key, user)
+ result = user_stash.set(root_datasite_client.credentials.verify_key, user)
assert result.is_ok()
user = result.ok()
@@ -26,30 +27,29 @@ def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> Us
def test_userstash_set(
root_datasite_client, user_stash: UserStash, guest_user: User
) -> None:
- result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user)
- assert result.is_ok()
-
- created_user = result.ok()
+ created_user = user_stash.set(
+ root_datasite_client.credentials.verify_key, guest_user
+ ).unwrap()
assert isinstance(created_user, User)
assert guest_user == created_user
- assert guest_user.id in user_stash.partition.data
+ assert user_stash.exists(
+ root_datasite_client.credentials.verify_key, created_user.id
+ )
def test_userstash_set_duplicate(
root_datasite_client, user_stash: UserStash, guest_user: User
) -> None:
- result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user)
- assert result.is_ok()
+ _ = user_stash.set(root_datasite_client.credentials.verify_key, guest_user).unwrap()
+ original_count = len(user_stash._data)
- original_count = len(user_stash.partition.data)
+ with pytest.raises(SyftException) as exc:
+ _ = user_stash.set(
+ root_datasite_client.credentials.verify_key, guest_user
+ ).unwrap()
+    assert exc.value.public_message
- result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user)
- assert result.is_err()
- exc = result.err()
- assert type(exc) == SyftException
- assert exc.public_message
-
- assert len(user_stash.partition.data) == original_count
+ assert len(user_stash._data) == original_count
def test_userstash_get_by_uid(
@@ -171,11 +171,9 @@ def test_userstash_get_by_role(
# prepare: add mock data
user = add_mock_user(root_datasite_client, user_stash, guest_user)
- result = user_stash.get_by_role(
+ searched_user = user_stash.get_by_role(
root_datasite_client.credentials.verify_key, role=ServiceRole.GUEST
- )
- assert result.is_ok()
- searched_user = result.ok()
+ ).unwrap()
assert user == searched_user
diff --git a/packages/syft/tests/syft/users/user_test.py b/packages/syft/tests/syft/users/user_test.py
index 1a9830ed48c..4c727b4f1fe 100644
--- a/packages/syft/tests/syft/users/user_test.py
+++ b/packages/syft/tests/syft/users/user_test.py
@@ -1,5 +1,6 @@
# stdlib
from secrets import token_hex
+import time
# third party
from faker import Faker
@@ -388,6 +389,9 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non
admin_client = get_mock_client(worker.root_client, ServiceRole.ADMIN)
assert admin_client.account.role == ServiceRole.ADMIN
+ # wait for the user to be created for sorting purposes
+ time.sleep(0.01)
+
admin_client.register(
name="Sheldon Cooper",
email="sheldon@caltech.edu",
@@ -424,6 +428,7 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non
with pytest.raises(SyftException):
ds_client.account.update(role="guest")
+ with pytest.raises(SyftException):
ds_client.account.update(role="data_scientist")
# now we set sheldon's role to admin. Only now he can change his role
diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py
index ce0f027a101..f52772038cf 100644
--- a/packages/syft/tests/syft/worker_test.py
+++ b/packages/syft/tests/syft/worker_test.py
@@ -16,16 +16,17 @@
from syft.server.credentials import SyftVerifyKey
from syft.server.worker import Worker
from syft.service.action.action_object import ActionObject
-from syft.service.action.action_store import DictActionStore
+from syft.service.action.action_store import ActionObjectStash
from syft.service.context import AuthedServiceContext
from syft.service.queue.queue_stash import QueueItem
from syft.service.response import SyftError
from syft.service.user.user import User
from syft.service.user.user import UserCreate
from syft.service.user.user import UserView
+from syft.service.user.user_stash import UserStash
+from syft.store.db.sqlite import SQLiteDBManager
from syft.types.errors import SyftException
from syft.types.result import Ok
-from syft.types.uid import UID
test_signing_key_string = (
"b7803e90a6f3f4330afbd943cef3451c716b338b17a9cf40a0a309bc38bc366d"
@@ -75,30 +76,42 @@ def test_signing_key() -> None:
assert test_verify_key == test_verify_key_2
-def test_action_store() -> None:
+@pytest.fixture(
+ scope="function",
+ params=[
+        "TODOsqlite_address",
+        # "TODOpostgres_address", # will be used when we have postgres CI tests
+ ],
+)
+def action_object_stash() -> ActionObjectStash:
+ root_verify_key = SyftVerifyKey.from_string(test_verify_key_string)
+ db_manager = SQLiteDBManager.random(root_verify_key=root_verify_key)
+ stash = ActionObjectStash(store=db_manager)
+ _ = UserStash(store=db_manager)
+ stash.db.init_tables()
+ yield stash
+
+
+def test_action_store(action_object_stash: ActionObjectStash) -> None:
test_signing_key = SyftSigningKey.from_string(test_signing_key_string)
- action_store = DictActionStore(server_uid=UID())
- uid = UID()
+ test_verify_key = test_signing_key.verify_key
raw_data = np.array([1, 2, 3])
test_object = ActionObject.from_obj(raw_data)
+ uid = test_object.id
- set_result = action_store.set(
+ action_object_stash.set_or_update(
uid=uid,
- credentials=test_signing_key,
+ credentials=test_verify_key,
syft_object=test_object,
has_result_read_permission=True,
- )
- assert set_result.is_ok()
- test_object_result = action_store.get(uid=uid, credentials=test_signing_key)
- assert test_object_result.is_ok()
- assert (test_object == test_object_result.ok()).all()
+ ).unwrap()
+ from_stash = action_object_stash.get(uid=uid, credentials=test_verify_key).unwrap()
+ assert (test_object == from_stash).all()
test_verift_key_2 = SyftVerifyKey.from_string(test_verify_key_string_2)
- test_object_result_fail = action_store.get(uid=uid, credentials=test_verift_key_2)
- assert test_object_result_fail.is_err()
- exc = test_object_result_fail.err()
- assert type(exc) == SyftException
- assert "denied" in exc.public_message
+ with pytest.raises(SyftException) as exc:
+ action_object_stash.get(uid=uid, credentials=test_verift_key_2).unwrap()
+    assert "denied" in exc.value.public_message
def test_user_transform() -> None:
@@ -223,14 +236,6 @@ def post_add(context: Any, name: str, new_result: Any) -> Any:
action_object.syft_post_hooks__["__add__"] = []
-def test_worker_serde(worker) -> None:
- ser = sy.serialize(worker, to_bytes=True)
- de = sy.deserialize(ser, from_bytes=True)
-
- assert de.signing_key == worker.signing_key
- assert de.id == worker.id
-
-
@pytest.fixture(params=[0])
def worker_with_proc(request):
worker = Worker(
diff --git a/packages/syftcli/manifest.yml b/packages/syftcli/manifest.yml
index c17457f6e6b..4b8754f154f 100644
--- a/packages/syftcli/manifest.yml
+++ b/packages/syftcli/manifest.yml
@@ -6,7 +6,7 @@ dockerTag: 0.9.2-beta.3
images:
- docker.io/openmined/syft-frontend:0.9.2-beta.3
- docker.io/openmined/syft-backend:0.9.2-beta.3
- - docker.io/library/mongo:7.0.4
+ - docker.io/library/postgres:16.1
- docker.io/traefik:v2.11.0
configFiles:
diff --git a/scripts/dev_tools.sh b/scripts/dev_tools.sh
index 20a74b597e0..c23b56c05b9 100755
--- a/scripts/dev_tools.sh
+++ b/scripts/dev_tools.sh
@@ -23,15 +23,13 @@ function docker_list_exposed_ports() {
if [[ -z "$1" ]]; then
# list db, redis, rabbitmq, and seaweedfs ports
- docker_list_exposed_ports "db\|seaweedfs\|mongo"
+ docker_list_exposed_ports "db\|seaweedfs"
else
PORT=$1
if docker ps | grep ":${PORT}" | grep -q 'redis'; then
${command} redis://127.0.0.1:${PORT}
elif docker ps | grep ":${PORT}" | grep -q 'postgres'; then
${command} postgresql://postgres:changethis@127.0.0.1:${PORT}/app
- elif docker ps | grep ":${PORT}" | grep -q 'mongo'; then
- ${command} mongodb://root:example@127.0.0.1:${PORT}
else
${command} http://localhost:${PORT}
fi
diff --git a/scripts/reset_k8s.sh b/scripts/reset_k8s.sh
index d0d245be6f2..033cb24ed31 100755
--- a/scripts/reset_k8s.sh
+++ b/scripts/reset_k8s.sh
@@ -1,22 +1,49 @@
#!/bin/bash
-# WARNING: this will drop the 'app' database in your mongo-0 instance in the syft namespace
echo $1
-# Dropping the database on mongo-0
-if [ -z $1 ]; then
- MONGO_POD_NAME="mongo-0"
-else
- MONGO_POD_NAME=$1
-fi
+# Default pod name
+DEFAULT_POD_NAME="postgres-0"
+
+# Use the provided pod name or the default
+POSTGRES_POD_NAME=${1:-$DEFAULT_POD_NAME}
+
+# SQL commands to reset all tables
+RESET_COMMAND="
+DO \$\$
+DECLARE
+ r RECORD;
+BEGIN
+ -- Disable all triggers
+ SET session_replication_role = 'replica';
+
+ -- Truncate all tables in the current schema
+ FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = current_schema()) LOOP
+ EXECUTE 'TRUNCATE TABLE ' || quote_ident(r.tablename) || ' CASCADE';
+ END LOOP;
+
+ -- Re-enable all triggers
+ SET session_replication_role = 'origin';
+END \$\$;
+
+-- Reset all sequences
+DO \$\$
+DECLARE
+ r RECORD;
+BEGIN
+ FOR r IN (SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = current_schema()) LOOP
+ EXECUTE 'ALTER SEQUENCE ' || quote_ident(r.sequence_name) || ' RESTART WITH 1';
+ END LOOP;
+END \$\$;
+"
-DROPCMD="<&1
+echo "All tables in $POSTGRES_POD_NAME have been reset."
# Resetting the backend pod
BACKEND_POD=$(kubectl get pods -n syft -o jsonpath="{.items[*].metadata.name}" | tr ' ' '\n' | grep -E ".*backend.*")
diff --git a/scripts/reset_mongo.sh b/scripts/reset_mongo.sh
deleted file mode 100755
index ac1641f68e4..00000000000
--- a/scripts/reset_mongo.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-
-# WARNING: this will drop the app database in all your mongo dbs
-echo $1
-
-if [ -z $1 ]; then
- MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo)
-else
- MONGO_CONTAINER_NAME=$1
-fi
-
-DROPCMD="<&1
\ No newline at end of file
diff --git a/scripts/reset_network.sh b/scripts/reset_network.sh
deleted file mode 100755
index ce5f863ff14..00000000000
--- a/scripts/reset_network.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo)
-DROPCMD="<&1
-
-# flush the worker queue
-. ${BASH_SOURCE%/*}/flush_queue.sh
-
-# reset docker service to clear out weird network issues
-sudo service docker restart
-
-# make sure all containers start
-. ${BASH_SOURCE%/*}/../packages/grid/scripts/containers.sh
diff --git a/tox.ini b/tox.ini
index 958dc2a8772..6e37f408ded 100644
--- a/tox.ini
+++ b/tox.ini
@@ -467,7 +467,7 @@ commands =
; sleep 30
# wait for test-datasite-1
- bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
+ bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
@@ -721,12 +721,12 @@ commands =
sleep 30
# wait for test gateway 1
- bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft
+ bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft
# wait for test datasite 1
- bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
+ bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
@@ -813,7 +813,7 @@ commands =
sleep 30
# wait for test-datasite-1
- bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
+ bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
@@ -887,7 +887,7 @@ allowlist_externals =
setenv =
CLUSTER_NAME = {env:CLUSTER_NAME:syft}
CLUSTER_HTTP_PORT = {env:SERVER_PORT:8080}
-; Usage for posargs: names of the relevant services among {frontend backend proxy mongo seaweedfs registry}
+; Usage for posargs: names of the relevant services among {frontend backend proxy postgres seaweedfs registry}
commands =
bash -c "env; date; k3d version"
@@ -906,7 +906,7 @@ commands =
fi"
# Mongo
- bash -c "if echo '{posargs}' | grep -q 'mongo'; then echo 'Checking readiness of Mongo'; ./scripts/wait_for.sh service mongo --context k3d-$CLUSTER_NAME --namespace syft; fi"
+ bash -c "if echo '{posargs}' | grep -q 'postgres'; then echo 'Checking readiness of Postgres'; ./scripts/wait_for.sh service postgres --context k3d-$CLUSTER_NAME --namespace syft; fi"
# Proxy
bash -c "if echo '{posargs}' | grep -q 'proxy'; then echo 'Checking readiness of proxy'; ./scripts/wait_for.sh service proxy --context k3d-$CLUSTER_NAME --namespace syft; fi"
@@ -950,7 +950,7 @@ commands =
echo "Installing local helm charts"; \
if [[ "{posargs}" == "override" ]]; then \
echo "Overriding resourcesPreset"; \
- helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \
+ helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \
else \
helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \
fi \
@@ -960,14 +960,14 @@ commands =
helm repo update openmined; \
if [[ "{posargs}" == "override" ]]; then \
echo "Overriding resourcesPreset"; \
- helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \
+ helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \
else \
helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \
fi \
fi'
; wait for everything else to be loaded
- tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry
+ tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry
# Run Notebook tests
tox -e e2e.test.notebook
@@ -1234,7 +1234,9 @@ setenv=
DEVSPACE_PROFILE={env:DEVSPACE_PROFILE}
allowlist_externals =
tox
+ bash
commands =
+ bash -c "CLUSTER_NAME=${CLUSTER_NAME} tox -e dev.k8s.destroy"
tox -e dev.k8s.start
tox -e dev.k8s.{posargs:deploy}
@@ -1460,7 +1462,7 @@ commands =
'
; wait for everything else to be loaded
- tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry
+ tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry
bash -c 'python -c "import syft as sy; print(\"Migrating from syft version:\", sy.__version__)"'
@@ -1541,7 +1543,7 @@ commands =
sleep 30
; # wait for test-datasite-1
- bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
+ bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft
bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft