From abf44deca3c2f634a570a6e96ef777ff9f2b0713 Mon Sep 17 00:00:00 2001
From: Smartappli
Date: Sat, 3 Aug 2024 19:54:48 +0000
Subject: [PATCH] style fixes by ruff

---
 notebooks/api/0.8/00-load-data.ipynb | 4 +-
 .../api/0.8/02-review-code-and-approve.ipynb | 2 +-
 notebooks/api/0.8/04-pytorch-example.ipynb | 11 +-
 notebooks/api/0.8/05-custom-policy.ipynb | 32 +-
 .../api/0.8/06-multiple-code-requests.ipynb | 2 +-
 .../07-datasite-register-control-flow.ipynb | 2 +-
 notebooks/api/0.8/10-container-images.ipynb | 12 +-
 .../api/0.8/11-container-images-k8s.ipynb | 48 ++-
 .../api/0.8/12-custom-api-endpoint.ipynb | 17 +-
 .../api/0.8/13-forgot-user-password.ipynb | 7 +-
 .../01-uploading-private-data.ipynb | 8 +-
 .../data-owner/02-account-management.ipynb | 4 +-
 .../data-owner/03-messages-and-requests.ipynb | 4 +-
 .../data-owner/05-syft-services-api.ipynb | 2 +-
 .../03-working-with-private-datasets.ipynb | 6 +-
 .../data-scientist/04-syft-functions.ipynb | 4 +-
 .../05-messaging-and-requests.ipynb | 4 +-
 .../deployments/01-deploy-python.ipynb | 2 +-
 .../tutorials/hello-syft/01-hello-syft.ipynb | 8 +-
 .../model-auditing/colab/01-user-log.ipynb | 11 +-
 .../00-data-owner-upload-data.ipynb | 7 +-
 .../01-data-scientist-submit-code.ipynb | 14 +-
 .../02-data-owner-review-approve-code.ipynb | 4 +-
 .../03-data-scientist-download-results.ipynb | 11 +-
 .../tutorials/model-training/mnist_dataset.py | 7 +-
 .../pandas-cookbook/00_cache_test.py | 3 +-
 .../01-reading-from-a-csv.ipynb | 7 +-
 ...lecting-data-finding-common-complain.ipynb | 16 +-
 ...orough-has-the-most-noise-complaints.ipynb | 16 +-
 ...-weekday-bike-most-groupby-aggregate.ipynb | 10 +-
 ...ing-dataframes-scraping-weather-data.ipynb | 13 +-
 ...rations-which-month-was-the-snowiest.ipynb | 9 +-
 .../07-cleaning-up-messy-data.ipynb | 13 +-
 .../08-how-to-deal-with-timestamps.ipynb | 13 +-
 .../0-prepare-migration-data.ipynb | 8 +-
 .../1-dump-database-to-file.ipynb | 2 +-
 .../2-migrate-from-file.ipynb | 2 +-
 packages/grid/backend/grid/api/new/new.py | 3 +-
 packages/grid/backend/grid/api/router.py | 3 +-
 packages/grid/backend/grid/bootstrap.py | 13 +-
 packages/grid/backend/grid/core/config.py | 21 +-
 packages/grid/backend/grid/core/server.py | 28 +-
 packages/grid/backend/grid/main.py | 5 +-
 .../attestation/server/attestation_main.py | 8 +-
 .../attestation/server/cpu_attestation.py | 4 +-
 .../attestation/server/gpu_attestation.py | 2 +-
 packages/grid/helm/generate_helm_notes.py | 9 +-
 packages/grid/seaweedfs/src/api.py | 11 +-
 packages/grid/seaweedfs/src/automount.py | 6 +-
 packages/grid/seaweedfs/src/buckets.py | 10 +-
 packages/grid/seaweedfs/src/mount.py | 26 +-
 packages/grid/seaweedfs/src/mount_cmd.py | 5 +-
 packages/grid/seaweedfs/src/mount_options.py | 7 +-
 packages/grid/seaweedfs/src/util.py | 2 -
 packages/grid/seaweedfs/tests/conftest.py | 4 +-
 .../grid/seaweedfs/tests/mount_cmd_test.py | 10 +-
 .../seaweedfs/tests/mount_options_test.py | 37 +-
 packages/grid/seaweedfs/tests/mount_test.py | 22 +-
 packages/syft/src/syft/__init__.py | 86 ++---
 packages/syft/src/syft/client/api.py | 170 ++++-----
 packages/syft/src/syft/client/client.py | 123 +++----
 packages/syft/src/syft/client/connection.py | 3 +-
 .../syft/src/syft/client/datasite_client.py | 61 ++-
 .../syft/src/syft/client/enclave_client.py | 14 +-
 .../syft/src/syft/client/gateway_client.py | 17 +-
 packages/syft/src/syft/client/registry.py | 48 +--
 packages/syft/src/syft/client/search.py | 16 +-
 packages/syft/src/syft/client/syncing.py | 29 +-
 .../syft/src/syft/custom_worker/builder.py | 31 +-
 .../src/syft/custom_worker/builder_docker.py | 12 +-
 .../src/syft/custom_worker/builder_k8s.py | 44 +--
 .../src/syft/custom_worker/builder_types.py | 3 +-
 .../syft/src/syft/custom_worker/config.py | 9 +-
 packages/syft/src/syft/custom_worker/k8s.py | 18 +-
 .../syft/src/syft/custom_worker/runner_k8s.py | 22 +-
 packages/syft/src/syft/custom_worker/utils.py | 2 +-
 packages/syft/src/syft/dev/prof.py | 2 +
 packages/syft/src/syft/exceptions/user.py | 2 +-
 packages/syft/src/syft/orchestra.py | 37 +-
 .../syft/src/syft/protocol/data_protocol.py | 45 +--
 packages/syft/src/syft/serde/__init__.py | 2 +-
 packages/syft/src/syft/serde/array.py | 3 +-
 packages/syft/src/syft/serde/arrow.py | 19 +-
 packages/syft/src/syft/serde/capnp.py | 3 +-
 packages/syft/src/syft/serde/deserialize.py | 3 +-
 .../syft/src/syft/serde/lib_permissions.py | 2 +-
 .../src/syft/serde/lib_service_registry.py | 28 +-
 packages/syft/src/syft/serde/mock.py | 2 +-
 packages/syft/src/syft/serde/recursive.py | 50 ++-
 .../src/syft/serde/recursive_primitives.py | 55 ++-
 packages/syft/src/syft/serde/serializable.py | 6 +-
 packages/syft/src/syft/serde/signature.py | 13 +-
 packages/syft/src/syft/serde/third_party.py | 54 ++-
 packages/syft/src/syft/server/credentials.py | 7 +-
 packages/syft/src/syft/server/routes.py | 41 +--
 packages/syft/src/syft/server/run.py | 7 +-
 packages/syft/src/syft/server/server.py | 286 ++++++++-------
 .../syft/src/syft/server/service_registry.py | 10 +-
 packages/syft/src/syft/server/utils.py | 11 +-
 packages/syft/src/syft/server/uvicorn.py | 22 +-
 .../syft/src/syft/server/worker_settings.py | 7 +-
 .../syft/service/action/action_data_empty.py | 3 +-
 .../syft/service/action/action_endpoint.py | 10 +-
 .../src/syft/service/action/action_object.py | 247 +++++++------
 .../syft/service/action/action_permissions.py | 2 +-
 .../src/syft/service/action/action_service.py | 234 ++++++------
 .../src/syft/service/action/action_store.py | 80 ++--
 .../src/syft/service/action/action_types.py | 10 +-
 .../syft/src/syft/service/action/numpy.py | 11 +-
 .../syft/src/syft/service/action/pandas.py | 9 +-
 packages/syft/src/syft/service/action/plan.py | 17 +-
 .../src/syft/service/action/verification.py | 18 +-
 packages/syft/src/syft/service/api/api.py | 88 +++--
 .../syft/src/syft/service/api/api_service.py | 89 +++--
 .../syft/src/syft/service/api/api_stash.py | 14 +-
 .../attestation/attestation_service.py | 20 +-
 .../service/blob_storage/remote_profile.py | 9 +-
 .../src/syft/service/blob_storage/service.py | 72 ++--
 .../src/syft/service/blob_storage/stash.py | 6 +-
 .../src/syft/service/blob_storage/util.py | 5 +-
 .../syft/src/syft/service/code/code_parse.py | 2 +-
 .../src/syft/service/code/status_service.py | 32 +-
 .../syft/src/syft/service/code/user_code.py | 217 ++++++-----
 .../src/syft/service/code/user_code_parse.py | 4 +-
 .../syft/service/code/user_code_service.py | 121 +++---
 .../src/syft/service/code/user_code_stash.py | 32 +-
 packages/syft/src/syft/service/code/utils.py | 15 +-
 .../syft/service/code_history/code_history.py | 12 +-
 .../code_history/code_history_service.py | 57 +--
 .../code_history/code_history_stash.py | 20 +-
 packages/syft/src/syft/service/context.py | 19 +-
 .../syft/service/data_subject/data_subject.py | 20 +-
 .../data_subject/data_subject_member.py | 6 +-
 .../data_subject_member_service.py | 35 +-
 .../data_subject/data_subject_service.py | 39 +-
 .../syft/src/syft/service/dataset/dataset.py | 105 +++---
 .../syft/service/dataset/dataset_service.py | 56 ++-
 .../src/syft/service/dataset/dataset_stash.py | 25 +-
 .../src/syft/service/job/html_template.py | 5 +-
 .../syft/src/syft/service/job/job_service.py | 67 ++--
 .../syft/src/syft/service/job/job_stash.py | 140 ++++---
 packages/syft/src/syft/service/log/log.py | 5 +-
 .../syft/src/syft/service/log/log_service.py | 14 +-
 .../syft/src/syft/service/log/log_stash.py | 6 +-
 .../syft/service/metadata/metadata_service.py | 5 +-
 .../syft/service/metadata/server_metadata.py | 14 +-
 .../service/migration/migration_service.py | 104 +++---
 .../migration/object_migration_state.py | 46 +--
 .../service/network/association_request.py | 35 +-
 .../syft/service/network/network_service.py | 262 +++++------
 .../service/network/rathole_config_builder.py | 50 ++-
 .../src/syft/service/network/rathole_toml.py | 11 -
 .../service/network/reverse_tunnel_service.py | 6 +-
 .../syft/src/syft/service/network/routes.py | 39 +-
 .../src/syft/service/network/server_peer.py | 96 ++---
 .../syft/src/syft/service/network/utils.py | 24 +-
 .../service/notification/email_templates.py | 9 +-
 .../notification/notification_service.py | 62 ++--
 .../notification/notification_stash.py | 41 +--
 .../service/notification/notifications.py | 22 +-
 .../src/syft/service/notifier/notifier.py | 39 +-
 .../syft/service/notifier/notifier_enums.py | 3 +-
 .../syft/service/notifier/notifier_service.py | 79 ++--
 .../syft/service/notifier/notifier_stash.py | 22 +-
 .../src/syft/service/notifier/smtp_client.py | 14 +-
 .../object_search/object_migration_state.py | 20 +-
 .../src/syft/service/output/output_service.py | 61 ++-
 .../syft/src/syft/service/policy/policy.py | 157 ++++----
 .../src/syft/service/policy/policy_service.py | 14 +-
 .../syft/service/policy/user_policy_stash.py | 17 +-
 .../syft/src/syft/service/project/project.py | 136 ++++---
 .../syft/service/project/project_service.py | 100 +++--
 .../src/syft/service/project/project_stash.py | 20 +-
 .../syft/src/syft/service/queue/base_queue.py | 6 +-
 packages/syft/src/syft/service/queue/queue.py | 34 +-
 .../src/syft/service/queue/queue_service.py | 8 +-
 .../src/syft/service/queue/queue_stash.py | 35 +-
 .../syft/src/syft/service/queue/zmq_queue.py | 97 +++--
 .../syft/src/syft/service/request/request.py | 150 ++++----
 .../syft/service/request/request_service.py | 82 ++---
 .../src/syft/service/request/request_stash.py | 19 +-
 packages/syft/src/syft/service/response.py | 10 +-
 packages/syft/src/syft/service/service.py | 65 ++--
 .../src/syft/service/settings/settings.py | 34 +-
 .../syft/service/settings/settings_service.py | 69 ++--
 .../syft/service/settings/settings_stash.py | 12 +-
 .../syft/src/syft/service/sync/diff_state.py | 145 ++++----
 .../src/syft/service/sync/resolve_widget.py | 86 +++--
 .../src/syft/service/sync/sync_service.py | 45 +--
 .../syft/src/syft/service/sync/sync_stash.py | 15 +-
 .../syft/src/syft/service/sync/sync_state.py | 20 +-
 packages/syft/src/syft/service/user/user.py | 58 ++-
 .../syft/src/syft/service/user/user_roles.py | 8 +-
 .../src/syft/service/user/user_service.py | 198 +++++-----
 .../syft/src/syft/service/user/user_stash.py | 42 +--
 packages/syft/src/syft/service/warnings.py | 7 +-
 .../syft/service/worker/image_identifier.py | 7 +-
 .../src/syft/service/worker/image_registry.py | 5 +-
 .../service/worker/image_registry_service.py | 16 +-
 .../service/worker/image_registry_stash.py | 17 +-
 .../syft/src/syft/service/worker/utils.py | 59 ++-
 .../syft/src/syft/service/worker/worker.py | 3 +-
 .../src/syft/service/worker/worker_image.py | 8 +-
 .../service/worker/worker_image_service.py | 63 ++--
 .../syft/service/worker/worker_image_stash.py | 27 +-
 .../src/syft/service/worker/worker_pool.py | 22 +-
 .../service/worker/worker_pool_service.py | 116 +++---
 .../syft/service/worker/worker_pool_stash.py | 21 +-
 .../src/syft/service/worker/worker_service.py | 77 ++--
 .../src/syft/service/worker/worker_stash.py | 34 +-
 packages/syft/src/syft/store/__init__.py | 3 +-
 .../src/syft/store/blob_storage/__init__.py | 49 ++-
 .../src/syft/store/blob_storage/on_disk.py | 38 +-
 .../src/syft/store/blob_storage/seaweedfs.py | 62 ++--
 .../src/syft/store/dict_document_store.py | 17 +-
 .../syft/src/syft/store/document_store.py | 107 +++---
 .../syft/src/syft/store/kv_document_store.py | 95 ++---
 packages/syft/src/syft/store/linked_obj.py | 14 +-
 packages/syft/src/syft/store/locks.py | 54 +--
 packages/syft/src/syft/store/mongo_client.py | 161 ++++----
 packages/syft/src/syft/store/mongo_codecs.py | 6 +-
 .../src/syft/store/mongo_document_store.py | 161 ++++----
 .../src/syft/store/sqlite_document_store.py | 55 +--
 packages/syft/src/syft/types/base.py | 3 +-
 packages/syft/src/syft/types/blob_storage.py | 59 ++-
 packages/syft/src/syft/types/datetime.py | 11 +-
 packages/syft/src/syft/types/dicttuple.py | 27 +-
 packages/syft/src/syft/types/identity.py | 1 -
 packages/syft/src/syft/types/server_url.py | 4 +-
 .../syft/src/syft/types/syft_metaclass.py | 3 +-
 .../syft/src/syft/types/syft_migration.py | 9 +-
 packages/syft/src/syft/types/syft_object.py | 128 +++----
 .../src/syft/types/syft_object_registry.py | 23 +-
 .../syft/src/syft/types/syncable_object.py | 8 +-
 packages/syft/src/syft/types/transforms.py | 26 +-
 packages/syft/src/syft/types/twin_object.py | 15 +-
 packages/syft/src/syft/types/uid.py | 26 +-
 .../syft/src/syft/util/_std_stream_capture.py | 40 +-
 packages/syft/src/syft/util/assets.py | 2 +-
 packages/syft/src/syft/util/autoreload.py | 4 +-
 packages/syft/src/syft/util/decorators.py | 11 +-
 packages/syft/src/syft/util/env.py | 3 +-
 .../syft/src/syft/util/experimental_flags.py | 4 +-
 packages/syft/src/syft/util/markdown.py | 4 +-
 packages/syft/src/syft/util/misc_objs.py | 6 +-
 .../syft/util/notebook_ui/components/base.py | 5 +-
 .../syft/util/notebook_ui/components/sync.py | 16 +-
 .../components/tabulator_template.py | 36 +-
 .../syft/src/syft/util/notebook_ui/icons.py | 3 +-
 packages/syft/src/syft/util/patch_ipython.py | 23 +-
 packages/syft/src/syft/util/schema.py | 15 +-
 packages/syft/src/syft/util/table.py | 26 +-
 packages/syft/src/syft/util/telemetry.py | 12 +-
 .../syft/src/syft/util/trace_decorator.py | 18 +-
 packages/syft/src/syft/util/util.py | 81 ++--
 .../syft/src/syft/util/version_compare.py | 12 +-
 packages/syft/tests/__init__.py | 2 +-
 packages/syft/tests/conftest.py | 109 +++---
 packages/syft/tests/mongomock/__init__.py | 5 +-
 packages/syft/tests/mongomock/__init__.pyi | 20 +-
 packages/syft/tests/mongomock/aggregate.py | 225 ++++++------
 .../syft/tests/mongomock/codec_options.py | 14 +-
 packages/syft/tests/mongomock/collection.py | 346 +++++++++---------
 .../syft/tests/mongomock/command_cursor.py | 2 +-
 packages/syft/tests/mongomock/database.py | 52 +--
 packages/syft/tests/mongomock/filtering.py | 39 +-
 packages/syft/tests/mongomock/gridfs.py | 12 +-
 packages/syft/tests/mongomock/helpers.py | 64 ++--
 packages/syft/tests/mongomock/mongo_client.py | 22 +-
 .../syft/tests/mongomock/not_implemented.py | 2 +-
 packages/syft/tests/mongomock/object_id.py | 4 +-
 packages/syft/tests/mongomock/patch.py | 15 +-
 packages/syft/tests/mongomock/read_concern.py | 2 +-
 .../syft/tests/mongomock/read_preferences.py | 6 +-
 packages/syft/tests/mongomock/results.py | 14 +-
 packages/syft/tests/mongomock/store.py | 8 +-
 packages/syft/tests/mongomock/thread.py | 2 +-
 .../syft/tests/mongomock/write_concern.py | 6 +-
 packages/syft/tests/syft/action_test.py | 15 +-
 packages/syft/tests/syft/api_test.py | 6 +-
 packages/syft/tests/syft/assets_test.py | 4 +-
 .../syft/blob_storage/blob_storage_test.py | 33 +-
 .../syft/tests/syft/code_verification_test.py | 19 +-
 .../tests/syft/custom_worker/config_test.py | 47 +--
 .../tests/syft/dataset/dataset_stash_test.py | 5 +-
 packages/syft/tests/syft/dataset/fixtures.py | 30 +-
 packages/syft/tests/syft/eager_test.py | 12 +-
 packages/syft/tests/syft/hash_test.py | 8 +-
 packages/syft/tests/syft/lineage_id_test.py | 3 +-
 packages/syft/tests/syft/locks_test.py | 18 +-
 .../syft/migrations/data_migration_test.py | 9 +-
 .../migrations/protocol_communication_test.py | 39 +-
 packages/syft/tests/syft/notebook_ui_test.py | 3 +-
 .../syft/tests/syft/notifications/fixtures.py | 30 +-
 .../notifications/notification_serde_test.py | 2 +-
 .../notification_service_test.py | 37 +-
 .../notifications/notification_stash_test.py | 73 ++--
 .../syft/tests/syft/project/project_test.py | 12 +-
 packages/syft/tests/syft/request/fixtures.py | 14 +-
 .../request/request_code_accept_deny_test.py | 18 +-
 .../request/request_code_permissions_test.py | 4 +-
 .../tests/syft/request/request_stash_test.py | 18 +-
 packages/syft/tests/syft/serde/fixtures.py | 4 +-
 .../tests/syft/serde/numpy_functions_test.py | 9 +-
 packages/syft/tests/syft/serializable_test.py | 17 +-
 .../syft/service/action/action_object_test.py | 65 ++--
 .../service/action/action_service_test.py | 2 +-
 .../syft/service/action/action_types_test.py | 3 +-
 .../service/dataset/dataset_service_test.py | 39 +-
 .../tests/syft/service/jobs/job_stash_test.py | 7 +-
 .../service/sync/sync_resolve_single_test.py | 38 +-
 .../tests/syft/service_permission_test.py | 6 +-
 packages/syft/tests/syft/settings/fixtures.py | 36 +-
 .../syft/tests/syft/settings/metadata_test.py | 4 +-
 .../syft/settings/settings_serde_test.py | 2 +-
 .../syft/settings/settings_service_test.py | 30 +-
 .../syft/settings/settings_stash_test.py | 7 +-
 .../tests/syft/stores/action_store_test.py | 20 +-
 .../syft/tests/syft/stores/base_stash_test.py | 93 ++---
 .../syft/stores/dict_document_store_test.py | 41 +--
 .../syft/stores/kv_document_store_test.py | 29 +-
 .../syft/stores/mongo_document_store_test.py | 138 ++++---
 .../tests/syft/stores/permissions_test.py | 12 +-
 .../tests/syft/stores/queue_stash_test.py | 5 +-
 .../syft/stores/sqlite_document_store_test.py | 49 ++-
 .../tests/syft/stores/store_fixtures_test.py | 120 +++---
 .../tests/syft/stores/store_mocks_test.py | 8 +-
 .../syft/transforms/transform_methods_test.py | 57 +--
 .../tests/syft/transforms/transforms_test.py | 9 +-
 .../syft/tests/syft/types/dicttuple_test.py | 34 +-
 packages/syft/tests/syft/uid_test.py | 16 +-
 packages/syft/tests/syft/users/fixtures.py | 88 ++---
 .../tests/syft/users/local_execution_test.py | 4 +-
 .../syft/tests/syft/users/user_code_test.py | 33 +-
 .../syft/tests/syft/users/user_serde_test.py | 2 +-
 .../tests/syft/users/user_service_test.py | 44 +--
 .../syft/tests/syft/users/user_stash_test.py | 45 ++-
 packages/syft/tests/syft/users/user_test.py | 58 ++-
 .../worker_pool/worker_pool_service_test.py | 20 +-
 .../tests/syft/worker_pool/worker_test.py | 5 +-
 packages/syft/tests/syft/worker_test.py | 31 +-
 packages/syft/tests/syft/zmq_queue_test.py | 41 ++-
 packages/syft/tests/utils/custom_markers.py | 2 +-
 packages/syftcli/setup.py | 5 +-
 packages/syftcli/syftcli/bundle/create.py | 30 +-
 packages/syftcli/syftcli/cli.py | 3 +-
 packages/syftcli/syftcli/core/console.py | 3 +-
 .../syftcli/syftcli/core/container_engine.py | 17 +-
 packages/syftcli/syftcli/core/proc.py | 15 +-
 packages/syftcli/syftcli/core/register.py | 2 +-
 packages/syftcli/syftcli/core/syft_repo.py | 8 +-
 packages/syftcli/syftcli/core/syft_version.py | 4 +-
 packages/syftcli/tests/hello_test.py | 3 +-
 packages/syftcli/tests/version_test.py | 3 +-
 scripts/container_log_collector.py | 5 +-
 scripts/convert_to_pypi_readme.py | 4 +-
 scripts/create_syftcli_config.py | 2 +-
 scripts/generate_canonical_names.py | 20 +-
 scripts/patch_hosts.py | 14 +-
 scripts/print_fd.py | 8 +-
 tests/integration/conftest.py | 12 +-
 .../container_workload/blob_storage_test.py | 6 +-
 .../container_workload/pool_image_test.py | 42 +--
 .../frontend/frontend_start_test.py | 4 +-
 tests/integration/local/enclave_local_test.py | 2 +-
 tests/integration/local/gateway_local_test.py | 55 ++-
 tests/integration/local/job_test.py | 28 +-
 .../local/request_multiple_nodes_test.py | 6 +-
 tests/integration/local/syft_function_test.py | 15 +-
 .../local/syft_worker_deletion_test.py | 16 +-
 tests/integration/local/twin_api_sync_test.py | 30 +-
 tests/integration/network/client_test.py | 4 +-
 tests/integration/network/gateway_test.py | 195 +++++-----
 373 files changed, 5653 insertions(+), 6329 deletions(-)

diff --git a/notebooks/api/0.8/00-load-data.ipynb b/notebooks/api/0.8/00-load-data.ipynb
index 1c8a6d3ff6a..74e111d9e58 100644
--- a/notebooks/api/0.8/00-load-data.ipynb
+++ b/notebooks/api/0.8/00-load-data.ipynb
@@ -68,7 +68,7 @@
    "source": [
     "# Launch a fresh datasite server named \"test-datasite-1\" in dev mode on the local machine\n",
     "server = sy.orchestra.launch(\n",
-    "    name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True\n",
+    "    name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True,\n",
     ")"
    ]
   },
@@ -435,7 +435,7 @@
    "source": [
     "ctf = sy.Asset(name=\"canada_trade_flow\")\n",
     "ctf.set_description(\n",
-    "    \"Canada trade flow represents export & import of different commodities to other countries\"\n",
+    "    \"Canada trade flow represents export & import of different commodities to other countries\",\n",
     ")"
    ]
   },
diff --git a/notebooks/api/0.8/02-review-code-and-approve.ipynb b/notebooks/api/0.8/02-review-code-and-approve.ipynb
index dc0327dc306..59edd869743 100644
--- a/notebooks/api/0.8/02-review-code-and-approve.ipynb
+++ b/notebooks/api/0.8/02-review-code-and-approve.ipynb
@@ -410,7 +410,7 @@
     "    reason=(\n",
     "        \"The Submitted UserCode does not add differential privacy to the output.\"\n",
     "        \"Kindly add differential privacy and resubmit the code.\"\n",
-    "    )\n",
+    "    ),\n",
     ")\n",
     "result"
    ]
   },
diff --git a/notebooks/api/0.8/04-pytorch-example.ipynb b/notebooks/api/0.8/04-pytorch-example.ipynb
index 66cb5a7df54..6275ee89793 100644
--- a/notebooks/api/0.8/04-pytorch-example.ipynb
+++ b/notebooks/api/0.8/04-pytorch-example.ipynb
@@ -24,12 +24,11 @@
    "outputs": [],
    "source": [
     "# third party\n",
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "\n",
     "# syft absolute\n",
     "import syft as sy\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "from torch import nn\n",
     "\n",
     "sy.requires(SYFT_VERSION)"
    ]
   },
@@ -247,15 +246,15 @@
    "source": [
     "@sy.syft_function(\n",
     "    input_policy=sy.ExactMatch(\n",
-    "        weights=weight_datasite_obj.id, data=train_datasite_obj.id\n",
+    "        weights=weight_datasite_obj.id, data=train_datasite_obj.id,\n",
    "    ),\n",
    "
output_policy=sy.SingleExecutionExactOutput(),\n", ")\n", "def train_mlp(weights, data):\n", " # third party\n", " import torch\n", - " import torch.nn as nn\n", " import torch.nn.functional as F\n", + " from torch import nn\n", "\n", " class MLP(nn.Module):\n", " def __init__(self, out_dims):\n", diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index 9854efe851e..299cadf4c36 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -42,7 +42,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True\n", + " name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True,\n", ")" ] }, @@ -66,7 +66,7 @@ "outputs": [], "source": [ "datasite_client.register(\n", - " email=\"newuser@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\"\n", + " email=\"newuser@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\",\n", ")" ] }, @@ -244,12 +244,10 @@ "outputs": [], "source": [ "# third party\n", - "from result import Err\n", - "from result import Ok\n", + "from result import Err, Ok\n", "\n", "# syft absolute\n", - "from syft.client.api import AuthedServiceContext\n", - "from syft.client.api import ServerIdentity\n", + "from syft.client.api import AuthedServiceContext, ServerIdentity\n", "\n", "\n", "class CustomExactMatch(sy.CustomInputPolicy):\n", @@ -261,7 +259,7 @@ "\n", " try:\n", " allowed_inputs = self.allowed_ids_only(\n", - " allowed_inputs=self.inputs, kwargs=kwargs, context=context\n", + " allowed_inputs=self.inputs, kwargs=kwargs, context=context,\n", " )\n", " results = self.retrieve_from_db(\n", " code_item_id=code_item_id,\n", @@ -286,7 +284,7 @@ " # but we are not modifying the permissions of the private data\n", "\n", " root_context = AuthedServiceContext(\n", - " server=context.server, credentials=context.server.verify_key\n", + " server=context.server, credentials=context.server.verify_key,\n", " )\n", " if context.server.server_type == ServerType.DATASITE:\n", " for var_name, arg_id in allowed_inputs.items():\n", @@ -301,7 +299,7 @@ " code_inputs[var_name] = kwarg_value.ok()\n", " else:\n", " raise Exception(\n", - " f\"Invalid Server Type for Code Submission:{context.server.server_type}\"\n", + " f\"Invalid Server Type for Code Submission:{context.server.server_type}\",\n", " )\n", " return Ok(code_inputs)\n", "\n", @@ -312,8 +310,7 @@ " context,\n", " ):\n", " # syft absolute\n", - " from syft import ServerType\n", - " from syft import UID\n", + " from syft import UID, ServerType\n", "\n", " if context.server.server_type == ServerType.DATASITE:\n", " server_identity = ServerIdentity(\n", @@ -324,7 +321,7 @@ " allowed_inputs = allowed_inputs.get(server_identity, {})\n", " else:\n", " raise Exception(\n", - " f\"Invalid Server Type for Code Submission:{context.server.server_type}\"\n", + " f\"Invalid Server Type for Code Submission:{context.server.server_type}\",\n", " )\n", " filtered_kwargs = {}\n", " for key in allowed_inputs.keys():\n", @@ -336,7 +333,7 @@ "\n", " if uid != allowed_inputs[key]:\n", " raise Exception(\n", - " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\",\n", " )\n", " filtered_kwargs[key] = value\n", " return filtered_kwargs\n", @@ -369,7 +366,7 @@ " not_approved_kwargs = set(expected_input_kwargs) - 
set(permitted_input_kwargs)\n", " if len(not_approved_kwargs) > 0:\n", " return Err(\n", - " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n", + " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\",\n", " )\n", " return Ok(True)\n", "\n", @@ -381,8 +378,7 @@ " context,\n", "):\n", " # syft absolute\n", - " from syft import ServerType\n", - " from syft import UID\n", + " from syft import UID, ServerType\n", " from syft.client.api import ServerIdentity\n", "\n", " if context.server.server_type == ServerType.DATASITE:\n", @@ -394,7 +390,7 @@ " allowed_inputs = allowed_inputs.get(server_identity, {})\n", " else:\n", " raise Exception(\n", - " f\"Invalid Server Type for Code Submission:{context.server.server_type}\"\n", + " f\"Invalid Server Type for Code Submission:{context.server.server_type}\",\n", " )\n", " filtered_kwargs = {}\n", " for key in allowed_inputs.keys():\n", @@ -406,7 +402,7 @@ "\n", " if uid != allowed_inputs[key]:\n", " raise Exception(\n", - " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\",\n", " )\n", " filtered_kwargs[key] = value\n", " return filtered_kwargs" diff --git a/notebooks/api/0.8/06-multiple-code-requests.ipynb b/notebooks/api/0.8/06-multiple-code-requests.ipynb index 70ce3d055bf..9cc87da2adb 100644 --- a/notebooks/api/0.8/06-multiple-code-requests.ipynb +++ b/notebooks/api/0.8/06-multiple-code-requests.ipynb @@ -42,7 +42,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"test-datasite-1\", port=\"auto\", reset=True, dev_mode=True\n", + " name=\"test-datasite-1\", port=\"auto\", reset=True, dev_mode=True,\n", ")" ] }, diff --git a/notebooks/api/0.8/07-datasite-register-control-flow.ipynb b/notebooks/api/0.8/07-datasite-register-control-flow.ipynb index d4c10dd0cae..fe6fa0c55d8 100644 --- a/notebooks/api/0.8/07-datasite-register-control-flow.ipynb +++ b/notebooks/api/0.8/07-datasite-register-control-flow.ipynb @@ -59,7 +59,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True\n", + " name=\"test-datasite-1\", port=\"auto\", dev_mode=True, reset=True,\n", ")" ] }, diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index 72eb72c367d..3b92418df81 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -228,7 +228,7 @@ "outputs": [], "source": [ "submit_result = datasite_client.api.services.worker_image.submit(\n", - " worker_config=docker_config\n", + " worker_config=docker_config,\n", ")" ] }, @@ -474,7 +474,7 @@ "source": [ "if running_as_container:\n", " assert workerimage.image_hash == get_image_hash(\n", - " workerimage.built_image_tag\n", + " workerimage.built_image_tag,\n", " ), \"Worker Image image_hash does not match with built image hash\"" ] }, @@ -627,7 +627,7 @@ " assert status.error is None\n", " if running_as_container:\n", " assert status.worker.image.image_hash == get_image_hash(\n", - " workerimage.built_image_tag\n", + " workerimage.built_image_tag,\n", " ), \"Worker Pool Image image_hash does not match with built image hash\"" ] }, @@ -762,7 +762,7 @@ "outputs": [], "source": [ "worker_delete_res = datasite_client.api.services.worker.delete(\n", - " uid=second_worker.id, force=True\n", + " uid=second_worker.id, force=True,\n", ")" ] }, @@ -1096,7 
+1096,7 @@ "outputs": [], "source": [ "submit_result = datasite_client.api.services.worker_image.submit(\n", - " worker_config=docker_config_2\n", + " worker_config=docker_config_2,\n", ")\n", "submit_result" ] @@ -1159,7 +1159,7 @@ "source": [ "opendp_pool_name = \"second-opendp-pool\"\n", "pool_create_request = datasite_client.api.services.worker_pool.pool_creation_request(\n", - " pool_name=opendp_pool_name, num_workers=2, image_uid=workerimage_2.id\n", + " pool_name=opendp_pool_name, num_workers=2, image_uid=workerimage_2.id,\n", ")\n", "pool_create_request" ] diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb index 2a6ca77dabb..200b0716a0b 100644 --- a/notebooks/api/0.8/11-container-images-k8s.ipynb +++ b/notebooks/api/0.8/11-container-images-k8s.ipynb @@ -20,6 +20,7 @@ "source": [ "# stdlib\n", "import os\n", + "from getpass import getpass # noqa\n", "\n", "# third party\n", "import kr8s\n", @@ -29,9 +30,6 @@ "# syft absolute\n", "import syft as sy\n", "\n", - "from getpass import getpass # noqa\n", - "\n", - "\n", "sy.requires(SYFT_VERSION)\n", "\n", "# syft absolute\n", @@ -60,7 +58,7 @@ "def get_statefulset_by_pool_name(pool_name):\n", " kr8s_client = get_kr8s_client()\n", " pool_list = kr8s_client.get(\n", - " \"statefulsets\", label_selector={\"app.kubernetes.io/component\": pool_name}\n", + " \"statefulsets\", label_selector={\"app.kubernetes.io/component\": pool_name},\n", " )\n", " if len(pool_list) == 0:\n", " return None\n", @@ -152,7 +150,7 @@ "outputs": [], "source": [ "result = datasite_client.api.services.worker_pool.scale(\n", - " number=3, pool_name=\"default-pool\"\n", + " number=3, pool_name=\"default-pool\",\n", ")\n", "assert not isinstance(result, sy.SyftError), str(result)\n", "result" @@ -200,7 +198,7 @@ "outputs": [], "source": [ "default_pool_scale_res = datasite_client.api.services.worker_pool.scale(\n", - " number=1, pool_name=\"default-pool\"\n", + " number=1, pool_name=\"default-pool\",\n", ")\n", "assert not isinstance(default_pool_scale_res, sy.SyftError), str(default_pool_scale_res)\n", "default_pool_scale_res" @@ -226,7 +224,7 @@ "outputs": [], "source": [ "default_worker_pool = datasite_client.api.services.worker_pool.get_by_name(\n", - " pool_name=\"default-pool\"\n", + " pool_name=\"default-pool\",\n", ")\n", "default_worker_pool" ] @@ -301,7 +299,7 @@ "outputs": [], "source": [ "submit_result = datasite_client.api.services.worker_image.submit(\n", - " worker_config=docker_config\n", + " worker_config=docker_config,\n", ")\n", "submit_result" ] @@ -667,10 +665,10 @@ "), \"Labels not found in custom pool pod metadata\"\n", "\n", "assert is_subset_dict(\n", - " custom_pool_pod_annotations, custom_pool_pod_metadata.annotations\n", + " custom_pool_pod_annotations, custom_pool_pod_metadata.annotations,\n", "), \"Annotations do not match in Custom pool pod metadata\"\n", "assert is_subset_dict(\n", - " custom_pool_pod_labels, custom_pool_pod_metadata.labels\n", + " custom_pool_pod_labels, custom_pool_pod_metadata.labels,\n", "), \"Labels do not match in Custom pool pod metadata\"" ] }, @@ -698,7 +696,7 @@ ")\n", "\n", "assert worker_pool is not None, str(\n", - " [worker_pool.__dict__ for worker_pool in worker_pool_list]\n", + " [worker_pool.__dict__ for worker_pool in worker_pool_list],\n", ")\n", "assert len(worker_pool.workers) == 3" ] @@ -712,7 +710,7 @@ "source": [ "# We can filter pools based on the image id upon which the pools were built\n", "filtered_result = 
datasite_client.api.services.worker_pool.filter_by_image_id(\n", - " image_uid=workerimage.id\n", + " image_uid=workerimage.id,\n", ")\n", "filtered_result" ] @@ -957,7 +955,7 @@ "source": [ "# Scale Down the workers\n", "custom_pool_scale_res = datasite_client.api.services.worker_pool.scale(\n", - " number=1, pool_name=worker_pool_name\n", + " number=1, pool_name=worker_pool_name,\n", ")\n", "assert not isinstance(custom_pool_scale_res, sy.SyftError), str(custom_pool_scale_res)\n", "custom_pool_scale_res" @@ -1006,7 +1004,7 @@ "source": [ "submit_result = None\n", "submit_result = datasite_client.api.services.worker_image.submit(\n", - " worker_config=docker_config_opendp\n", + " worker_config=docker_config_opendp,\n", ")\n", "submit_result" ] @@ -1110,7 +1108,7 @@ "assert workerimage_opendp.image_hash is not None, str(workerimage_opendp.__dict__)\n", "\n", "assert _images[workerimage_opendp.built_image_tag] == workerimage_opendp, str(\n", - " workerimage_opendp\n", + " workerimage_opendp,\n", ")\n", "\n", "workerimage_opendp" @@ -1236,10 +1234,10 @@ "\n", "\n", "assert is_subset_dict(\n", - " opendp_pod_annotations, opendp_pool_pod_metadata.annotations\n", + " opendp_pod_annotations, opendp_pool_pod_metadata.annotations,\n", "), \"Annotations do not match in opendp pool pod metadata\"\n", "assert is_subset_dict(\n", - " opendp_pod_labels, opendp_pool_pod_metadata.labels\n", + " opendp_pod_labels, opendp_pool_pod_metadata.labels,\n", "), \"Labels do not match in opendp pool pod metadata\"" ] }, @@ -1252,7 +1250,7 @@ "source": [ "# Scale Down the workers\n", "opendp_pool_scale_res = datasite_client.api.services.worker_pool.scale(\n", - " number=1, pool_name=pool_name_opendp\n", + " number=1, pool_name=pool_name_opendp,\n", ")\n", "assert not isinstance(opendp_pool_scale_res, sy.SyftError), str(opendp_pool_scale_res)\n", "opendp_pool_scale_res" @@ -1303,10 +1301,10 @@ "source": [ "pool_name_recordlinkage = \"recordlinkage-pool\"\n", "recordlinkage_pod_annotations = {\n", - " \"test-recordlinkage-pool\": \"Test annotation for recordlinkage pool\"\n", + " \"test-recordlinkage-pool\": \"Test annotation for recordlinkage pool\",\n", "}\n", "recordlinkage_pod_labels = {\n", - " \"test-recordlinkage-pool\": \"test_label_for_recordlinkage_pool\"\n", + " \"test-recordlinkage-pool\": \"test_label_for_recordlinkage_pool\",\n", "}\n", "pool_image_create_request = datasite_client.api.services.worker_pool.create_image_and_pool_request(\n", " pool_name=pool_name_recordlinkage,\n", @@ -1329,7 +1327,7 @@ "outputs": [], "source": [ "assert not isinstance(pool_image_create_request, sy.SyftError), str(\n", - " pool_image_create_request\n", + " pool_image_create_request,\n", ")" ] }, @@ -1412,14 +1410,14 @@ "\n", "\n", "assert is_subset_dict(\n", - " recordlinkage_pod_annotations, recordlinkage_pool_pod_metadata.annotations\n", + " recordlinkage_pod_annotations, recordlinkage_pool_pod_metadata.annotations,\n", "), \"Annotations not found in recordlinkage pool pod metadata\"\n", "assert (\n", " \"labels\" in recordlinkage_pool_pod_metadata\n", "), \"Labels not found in recordlinkage pool pod metadata\"\n", "\n", "assert is_subset_dict(\n", - " recordlinkage_pod_labels, recordlinkage_pool_pod_metadata.labels\n", + " recordlinkage_pod_labels, recordlinkage_pool_pod_metadata.labels,\n", "), \"Annotations do not match in recordlinkage pool pod metadata\"" ] }, @@ -1471,10 +1469,10 @@ "source": [ "# Scale down the workers\n", "recordlinkage_pool_scale_res = datasite_client.api.services.worker_pool.scale(\n", - " 
number=1, pool_name=pool_name_recordlinkage\n", + " number=1, pool_name=pool_name_recordlinkage,\n", ")\n", "assert not isinstance(recordlinkage_pool_scale_res, sy.SyftError), str(\n", - " recordlinkage_pool_scale_res\n", + " recordlinkage_pool_scale_res,\n", ")\n", "recordlinkage_pool_scale_res" ] diff --git a/notebooks/api/0.8/12-custom-api-endpoint.ipynb b/notebooks/api/0.8/12-custom-api-endpoint.ipynb index 478446bcd17..027dde16aaa 100644 --- a/notebooks/api/0.8/12-custom-api-endpoint.ipynb +++ b/notebooks/api/0.8/12-custom-api-endpoint.ipynb @@ -25,8 +25,7 @@ "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import SyftError\n", - "from syft import SyftSuccess\n", + "from syft import SyftError, SyftSuccess\n", "\n", "server = sy.orchestra.launch(\n", " name=\"test-datasite-1\",\n", @@ -286,7 +285,7 @@ "outputs": [], "source": [ "result = datasite_guest.code.job_function(\n", - " endpoint=datasite_client.api.services.third.query\n", + " endpoint=datasite_client.api.services.third.query,\n", ")\n", "result" ] @@ -298,7 +297,7 @@ "outputs": [], "source": [ "result = datasite_guest.code.job_function(\n", - " endpoint=datasite_client.api.services.third.query\n", + " endpoint=datasite_client.api.services.third.query,\n", ")\n", "result" ] @@ -319,7 +318,7 @@ "outputs": [], "source": [ "result = datasite_guest.code.job_function(\n", - " endpoint=datasite_client.api.services.third.query\n", + " endpoint=datasite_client.api.services.third.query,\n", ")\n", "result" ] @@ -428,7 +427,7 @@ "\n", "\n", "response = datasite_client.api.services.api.update(\n", - " endpoint_path=\"test.update\", mock_function=updated_public_function\n", + " endpoint_path=\"test.update\", mock_function=updated_public_function,\n", ")\n", "assert isinstance(response, SyftSuccess), response\n", "response" @@ -455,7 +454,7 @@ "\n", "\n", "response = datasite_client.api.services.api.update(\n", - " endpoint_path=\"test.update\", private_function=updated_private_function\n", + " endpoint_path=\"test.update\", private_function=updated_private_function,\n", ")\n", "assert isinstance(response, SyftSuccess), response\n", "response" @@ -539,7 +538,7 @@ "\n", "\n", "response = datasite_client.api.services.api.update(\n", - " endpoint_path=\"test.update\", mock_function=bad_public_function\n", + " endpoint_path=\"test.update\", mock_function=bad_public_function,\n", ")\n", "assert isinstance(response, SyftError), response" ] @@ -558,7 +557,7 @@ "outputs": [], "source": [ "response = datasite_client.api.services.api.update(\n", - " endpoint_path=\"nonexistent\", mock_function=bad_public_function\n", + " endpoint_path=\"nonexistent\", mock_function=bad_public_function,\n", ")\n", "assert isinstance(response, SyftError), response" ] diff --git a/notebooks/api/0.8/13-forgot-user-password.ipynb b/notebooks/api/0.8/13-forgot-user-password.ipynb index 8ad3cdf0918..b09293e2abb 100644 --- a/notebooks/api/0.8/13-forgot-user-password.ipynb +++ b/notebooks/api/0.8/13-forgot-user-password.ipynb @@ -27,8 +27,7 @@ "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import SyftError\n", - "from syft import SyftSuccess\n", + "from syft import SyftError, SyftSuccess\n", "\n", "server = sy.orchestra.launch(\n", " name=\"test-datasite-1\",\n", @@ -113,7 +112,7 @@ "outputs": [], "source": [ "temp_token = datasite_client.users.request_password_reset(\n", - " datasite_client.notifications[-1].linked_obj.resolve.id\n", + " datasite_client.notifications[-1].linked_obj.resolve.id,\n", ")\n", "\n", "if not isinstance(temp_token, 
str):\n", @@ -149,7 +148,7 @@ "outputs": [], "source": [ "new_user_session = server.login(\n", - " email=\"new_syft_user@openmined.org\", password=\"Password123\"\n", + " email=\"new_syft_user@openmined.org\", password=\"Password123\",\n", ")\n", "\n", "if isinstance(new_user_session, SyftError):\n", diff --git a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb index 71567b558b0..a9c8dea21ea 100644 --- a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb +++ b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb @@ -51,7 +51,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"private-data-example-datasite-1\", port=\"auto\", reset=True\n", + " name=\"private-data-example-datasite-1\", port=\"auto\", reset=True,\n", ")" ] }, @@ -133,7 +133,7 @@ " summary=\"Contains private and mock versions of data\",\n", " description=dataset_markdown_description,\n", " asset_list=[\n", - " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", + " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1])),\n", " ],\n", ")\n", "\n", @@ -290,8 +290,8 @@ " name=\"my dataset2\",\n", " asset_list=[\n", " sy.Asset(\n", - " name=\"my asset2\", data=np.array([1, 2, 3]), mock=sy.ActionObject.empty()\n", - " )\n", + " name=\"my asset2\", data=np.array([1, 2, 3]), mock=sy.ActionObject.empty(),\n", + " ),\n", " ],\n", ")" ] diff --git a/notebooks/tutorials/data-owner/02-account-management.ipynb b/notebooks/tutorials/data-owner/02-account-management.ipynb index 66e01be2644..9c628723206 100644 --- a/notebooks/tutorials/data-owner/02-account-management.ipynb +++ b/notebooks/tutorials/data-owner/02-account-management.ipynb @@ -49,7 +49,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"account-management-example-datasite-1\", port=8041, reset=True\n", + " name=\"account-management-example-datasite-1\", port=8041, reset=True,\n", ")" ] }, @@ -205,7 +205,7 @@ "outputs": [], "source": [ "updated_user = client.users.update(\n", - " uid=new_user.id, role=ServiceRole.DATA_SCIENTIST, password=\"123\"\n", + " uid=new_user.id, role=ServiceRole.DATA_SCIENTIST, password=\"123\",\n", ")" ] }, diff --git a/notebooks/tutorials/data-owner/03-messages-and-requests.ipynb b/notebooks/tutorials/data-owner/03-messages-and-requests.ipynb index 40868ce167d..0fb0a12e672 100644 --- a/notebooks/tutorials/data-owner/03-messages-and-requests.ipynb +++ b/notebooks/tutorials/data-owner/03-messages-and-requests.ipynb @@ -49,7 +49,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"messages-requests-example-datasite-1-do\", port=7021, reset=True\n", + " name=\"messages-requests-example-datasite-1-do\", port=7021, reset=True,\n", ")" ] }, @@ -92,7 +92,7 @@ "dataset = sy.Dataset(\n", " name=\"my dataset\",\n", " asset_list=[\n", - " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", + " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1])),\n", " ],\n", ")\n", "admin_client.upload_dataset(dataset)" diff --git a/notebooks/tutorials/data-owner/05-syft-services-api.ipynb b/notebooks/tutorials/data-owner/05-syft-services-api.ipynb index d891bc16cdf..d2b8a90507d 100644 --- a/notebooks/tutorials/data-owner/05-syft-services-api.ipynb +++ b/notebooks/tutorials/data-owner/05-syft-services-api.ipynb @@ -57,7 +57,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - 
" name=\"services-api-example-datasite-1\", port=\"auto\", reset=True\n", + " name=\"services-api-example-datasite-1\", port=\"auto\", reset=True,\n", ")" ] }, diff --git a/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb b/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb index 727928b60d8..939d1bea87a 100644 --- a/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb +++ b/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb @@ -49,7 +49,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"private-datasets-example-datasite-1\", port=8062, reset=True\n", + " name=\"private-datasets-example-datasite-1\", port=8062, reset=True,\n", ")" ] }, @@ -100,7 +100,7 @@ "dataset = sy.Dataset(\n", " name=\"my dataset\",\n", " asset_list=[\n", - " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", + " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1])),\n", " ],\n", ")" ] @@ -251,7 +251,7 @@ "outputs": [], "source": [ "@sy.syft_function(\n", - " input_policy=sy.ExactMatch(inp=asset), output_policy=sy.SingleExecutionExactOutput()\n", + " input_policy=sy.ExactMatch(inp=asset), output_policy=sy.SingleExecutionExactOutput(),\n", ")\n", "def add_pow(inp):\n", " x = inp + 3\n", diff --git a/notebooks/tutorials/data-scientist/04-syft-functions.ipynb b/notebooks/tutorials/data-scientist/04-syft-functions.ipynb index b0787b96356..63355d5f098 100644 --- a/notebooks/tutorials/data-scientist/04-syft-functions.ipynb +++ b/notebooks/tutorials/data-scientist/04-syft-functions.ipynb @@ -49,7 +49,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"syft-functions-example-datasite-1\", port=7022, reset=True\n", + " name=\"syft-functions-example-datasite-1\", port=7022, reset=True,\n", ")" ] }, @@ -102,7 +102,7 @@ "dataset = sy.Dataset(\n", " name=\"my dataset\",\n", " asset_list=[\n", - " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", + " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1])),\n", " ],\n", ")\n", "admin_client.upload_dataset(dataset)" diff --git a/notebooks/tutorials/data-scientist/05-messaging-and-requests.ipynb b/notebooks/tutorials/data-scientist/05-messaging-and-requests.ipynb index 68aa03d284a..656d167dbf0 100644 --- a/notebooks/tutorials/data-scientist/05-messaging-and-requests.ipynb +++ b/notebooks/tutorials/data-scientist/05-messaging-and-requests.ipynb @@ -49,7 +49,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"messages-requests-example-datasite-1-ds\", port=7023, reset=True\n", + " name=\"messages-requests-example-datasite-1-ds\", port=7023, reset=True,\n", ")" ] }, @@ -92,7 +92,7 @@ "dataset = sy.Dataset(\n", " name=\"my dataset\",\n", " asset_list=[\n", - " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1]))\n", + " sy.Asset(name=\"my asset\", data=np.array([1, 2, 3]), mock=np.array([1, 1, 1])),\n", " ],\n", ")\n", "admin_client.upload_dataset(dataset)" diff --git a/notebooks/tutorials/deployments/01-deploy-python.ipynb b/notebooks/tutorials/deployments/01-deploy-python.ipynb index 0dedeb13ea5..1db0913671e 100644 --- a/notebooks/tutorials/deployments/01-deploy-python.ipynb +++ b/notebooks/tutorials/deployments/01-deploy-python.ipynb @@ -81,7 +81,7 @@ "outputs": [], "source": [ "server = sy.orchestra.launch(\n", - " name=\"dev-mode-example-datasite-1\", port=8020, reset=True, 
dev_mode=True\n", + " name=\"dev-mode-example-datasite-1\", port=8020, reset=True, dev_mode=True,\n", ")" ] }, diff --git a/notebooks/tutorials/hello-syft/01-hello-syft.ipynb b/notebooks/tutorials/hello-syft/01-hello-syft.ipynb index cc2da4bf58f..5b51bbf0307 100644 --- a/notebooks/tutorials/hello-syft/01-hello-syft.ipynb +++ b/notebooks/tutorials/hello-syft/01-hello-syft.ipynb @@ -153,13 +153,13 @@ " {\n", " \"Patient_ID\": [\"011\", \"015\", \"022\", \"034\", \"044\"],\n", " \"Age\": [40, 39, 35, 60, 25],\n", - " }\n", + " },\n", " ),\n", " mock=pd.DataFrame(\n", - " {\"Patient_ID\": [\"1\", \"2\", \"3\", \"4\", \"5\"], \"Age\": [50, 49, 45, 70, 35]}\n", + " {\"Patient_ID\": [\"1\", \"2\", \"3\", \"4\", \"5\"], \"Age\": [50, 49, 45, 70, 35]},\n", " ),\n", " mock_is_real=False,\n", - " )\n", + " ),\n", " ],\n", ")\n", "root_datasite_client.upload_dataset(dataset)" @@ -559,7 +559,7 @@ " \"ds_client.api.services.code\" in completions3,\n", " \"ds_client.api.code\" in completions4,\n", " \"ds_client.api.parse_raw\" not in completions4, # no pydantic completions on api\n", - " ]\n", + " ],\n", ")" ] }, diff --git a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb index e97442b5dac..08ef38342b4 100644 --- a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb +++ b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb @@ -234,7 +234,7 @@ "outputs": [], "source": [ "main_contributor = sy.Contributor(\n", - " name=\"Jeffrey Salazar\", role=\"Dataset Creator\", email=\"jsala@ailab.com\"\n", + " name=\"Jeffrey Salazar\", role=\"Dataset Creator\", email=\"jsala@ailab.com\",\n", ")\n", "\n", "gpt2_user_log = sy.Dataset(\n", @@ -248,7 +248,7 @@ " contributors=[main_contributor],\n", " data=model_log,\n", " mock=mock_model_log,\n", - " )\n", + " ),\n", " ],\n", ")" ] @@ -381,7 +381,7 @@ "indices, inputs = mock.id.tolist(), mock[\"result\"].tolist()\n", "toxicity_results = toxicity.compute(predictions=inputs)\n", "mock_result = pd.DataFrame(\n", - " toxicity_results[\"toxicity\"], index=indices, columns=[\"toxicity\"]\n", + " toxicity_results[\"toxicity\"], index=indices, columns=[\"toxicity\"],\n", ")" ] }, @@ -412,8 +412,7 @@ "source": [ "@sy.syft_function_single_use(data=dataset.assets[0])\n", "def model_output_analysis(data):\n", - " \"\"\"\n", - " Evaluate the model's quantify the toxicity of the input texts using the R4 Target Model,\n", + " \"\"\"Evaluate the model's quantify the toxicity of the input texts using the R4 Target Model,\n", " a pretrained hate speech classification model\n", " Evaluate the model's estimated language polarity towards and social perceptions of a demographic\n", " (e.g. 
gender, race, sexual orientation).\n", @@ -426,7 +425,7 @@ " indices, inputs = data.id.tolist(), data[\"result\"].tolist()\n", " toxicity_results = toxicity.compute(predictions=inputs)\n", " return pd.DataFrame(\n", - " toxicity_results[\"toxicity\"], index=indices, columns=[\"toxicity\"]\n", + " toxicity_results[\"toxicity\"], index=indices, columns=[\"toxicity\"],\n", " )" ] }, diff --git a/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb b/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb index adca3805b12..0779bede98f 100644 --- a/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb +++ b/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb @@ -9,15 +9,14 @@ "source": [ "# third party\n", "import matplotlib.pyplot as plt\n", - "\n", - "# relative import\n", - "from mnist_dataset import mnist\n", - "from mnist_dataset import mnist_raw\n", "import numpy as np\n", "\n", "# syft absolute\n", "import syft as sy\n", "\n", + "# relative import\n", + "from mnist_dataset import mnist, mnist_raw\n", + "\n", "print(f\"{sy.__version__ = }\")" ] }, diff --git a/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb b/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb index 13e52c83015..485ad111219 100644 --- a/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb +++ b/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb @@ -191,8 +191,7 @@ "def mnist_3_linear_layers_torch(mnist_images, mnist_labels):\n", " # third party\n", " import torch\n", - " import torch.nn as nn\n", - " import torch.optim as optim\n", + " from torch import nn, optim\n", " from torch.utils.data import TensorDataset\n", "\n", " # Convert NumPy arrays to PyTorch tensors\n", @@ -202,7 +201,7 @@ " custom_dataset = TensorDataset(images_tensor, labels_tensor)\n", " # Define the data loader\n", " train_loader = torch.utils.data.DataLoader(\n", - " custom_dataset, batch_size=4, shuffle=True\n", + " custom_dataset, batch_size=4, shuffle=True,\n", " )\n", "\n", " # Define the neural network class\n", @@ -275,7 +274,7 @@ "outputs": [], "source": [ "train_accs, params = mnist_3_linear_layers_torch(\n", - " mnist_images=mock_images, mnist_labels=mock_labels\n", + " mnist_images=mock_images, mnist_labels=mock_labels,\n", ")" ] }, @@ -332,15 +331,14 @@ "source": [ "@sy.syft_function(\n", " input_policy=sy.ExactMatch(\n", - " mnist_images=mock_images_ptr, mnist_labels=mock_labels_ptr\n", + " mnist_images=mock_images_ptr, mnist_labels=mock_labels_ptr,\n", " ),\n", " output_policy=sy.SingleExecutionExactOutput(),\n", ")\n", "def mnist_3_linear_layers_torch(mnist_images, mnist_labels):\n", " # third party\n", " import torch\n", - " import torch.nn as nn\n", - " import torch.optim as optim\n", + " from torch import nn, optim\n", " from torch.utils.data import TensorDataset\n", "\n", " # Convert NumPy arrays to PyTorch tensors\n", @@ -350,7 +348,7 @@ " custom_dataset = TensorDataset(images_tensor, labels_tensor)\n", " # Define the data loader\n", " train_loader = torch.utils.data.DataLoader(\n", - " custom_dataset, batch_size=4, shuffle=True\n", + " custom_dataset, batch_size=4, shuffle=True,\n", " )\n", "\n", " # Define the neural network class\n", diff --git a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb index 5606ec79111..07283460b16 100644 --- 
a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb +++ b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb @@ -168,7 +168,7 @@ "outputs": [], "source": [ "mock_train_accs, mock_params = users_function(\n", - " mnist_images=mock_images, mnist_labels=mock_labels\n", + " mnist_images=mock_images, mnist_labels=mock_labels,\n", ")" ] }, @@ -224,7 +224,7 @@ "outputs": [], "source": [ "train_accs, params = users_function(\n", - " mnist_images=private_images, mnist_labels=private_labels\n", + " mnist_images=private_images, mnist_labels=private_labels,\n", ")" ] }, diff --git a/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb b/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb index 250a9f23dcc..5f4ba5da5c0 100644 --- a/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb +++ b/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb @@ -8,11 +8,10 @@ "outputs": [], "source": [ "# third party\n", - "from mnist_dataset import mnist\n", - "import torch\n", - "\n", "# syft absolute\n", - "import syft as sy" + "import syft as sy\n", + "import torch\n", + "from mnist_dataset import mnist" ] }, { @@ -75,7 +74,7 @@ "outputs": [], "source": [ "result = ds_client.code.mnist_3_linear_layers_torch(\n", - " mnist_images=training_images, mnist_labels=training_labels\n", + " mnist_images=training_images, mnist_labels=training_labels,\n", ")" ] }, @@ -156,7 +155,7 @@ "outputs": [], "source": [ "# third party\n", - "import torch.nn as nn\n", + "from torch import nn\n", "\n", "\n", "class MLP(nn.Module):\n", diff --git a/notebooks/tutorials/model-training/mnist_dataset.py b/notebooks/tutorials/model-training/mnist_dataset.py index 77b7c2c7afe..713ad12c9d9 100644 --- a/notebooks/tutorials/model-training/mnist_dataset.py +++ b/notebooks/tutorials/model-training/mnist_dataset.py @@ -1,5 +1,4 @@ -""" -Code for the MNIST dataset +"""Code for the MNIST dataset Source: https://github.com/google/jax/blob/main/examples/datasets.py """ @@ -7,9 +6,9 @@ import array import gzip import os -from os import path import struct import urllib.request +from os import path # third party import numpy as np @@ -52,7 +51,7 @@ def parse_images(filename): with gzip.open(filename, "rb") as fh: _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( - num_data, rows, cols + num_data, rows, cols, ) for filename in [ diff --git a/notebooks/tutorials/pandas-cookbook/00_cache_test.py b/notebooks/tutorials/pandas-cookbook/00_cache_test.py index 7d7c1ae4bd5..45af3386a28 100644 --- a/notebooks/tutorials/pandas-cookbook/00_cache_test.py +++ b/notebooks/tutorials/pandas-cookbook/00_cache_test.py @@ -7,8 +7,7 @@ def test_cache_download() -> None: import pandas as pd # syft absolute - from syft.util.util import PANDAS_DATA - from syft.util.util import autocache + from syft.util.util import PANDAS_DATA, autocache encoding = {"bikes.csv": "ISO-8859-1"} diff --git a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb index c0e76617809..25a5e5e52a2 100644 --- a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb +++ b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb @@ -97,9 +97,10 @@ "# stdlib\n", "from datetime import timedelta\n", "\n", + "import pandas as pd\n", + "\n", "# third party\n", "from dateutil.parser import parse\n", - 
"import pandas as pd\n", "from pandas._libs.tslibs.timestamps import Timestamp\n", "\n", "# syft absolute\n", @@ -295,7 +296,7 @@ }, "outputs": [], "source": [ - "# todo: give user data scientist role" + "# TODO: give user data scientist role" ] }, { @@ -514,7 +515,7 @@ "outputs": [], "source": [ "@sy.syft_function(\n", - " input_policy=sy.ExactMatch(df=asset), output_policy=sy.SingleExecutionExactOutput()\n", + " input_policy=sy.ExactMatch(df=asset), output_policy=sy.SingleExecutionExactOutput(),\n", ")\n", "def get_column(df):\n", " return df[\"Berri 1\"]" diff --git a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb index 434741c19fc..72d9032fd73 100644 --- a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb +++ b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb @@ -149,7 +149,7 @@ "source": [ "# because of mixed types we specify dtype to prevent any errors\n", "complaints = pd.read_csv(\n", - " sy.autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\"\n", + " sy.autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\",\n", ")" ] }, @@ -219,7 +219,7 @@ " \"X Coordinate (State Plane)\": lambda x: randint(1, 1000000),\n", " \"Y Coordinate (State Plane)\": lambda x: randint(1, 1000000),\n", " \"Complaint Type\": lambda x: random.choice(\n", - " [\"Illegal Parking\", \"Noise - Street/Sidewalk\", \"'Animal in a Park'\"]\n", + " [\"Illegal Parking\", \"Noise - Street/Sidewalk\", \"'Animal in a Park'\"],\n", " ),\n", " \"Descriptor\": lambda x: random.choice(\n", " [\n", @@ -227,7 +227,7 @@ " \"Branches Damaged\",\n", " \"Broken Fence\",\n", " \"Broken Glass\",\n", - " ]\n", + " ],\n", " ),\n", " \"School Number\": lambda x: random.choice(\n", " [\n", @@ -241,7 +241,7 @@ " \"B102\",\n", " \"B109\",\n", " \"B111\",\n", - " ]\n", + " ],\n", " ),\n", " \"Bridge Highway Segment\": lambda x: random.choice(\n", " [\n", @@ -250,7 +250,7 @@ " \"GrandCentral Pkwy/VanWyck Expwy/College Point Blvd (Exit 22 A-E)\",\n", " \"Hamilton Ave (Exit 2A) - Gowanus Expwy (I-278) (Exit 1)\",\n", " \"Harding Ave (Exit 9) - Throgs Neck Br\",\n", - " ]\n", + " ],\n", " ),\n", "}" ] @@ -295,7 +295,7 @@ " values = list(set(complaints[col]))\n", " mock_func = lambda x: random.choice(values) # noqa: E731,B023\n", " else:\n", - " for trigger in fake_triggers.keys():\n", + " for trigger in fake_triggers:\n", " if trigger in col:\n", " mock_func = fake_triggers[trigger]\n", " mock_data[col] = [mock_func(None) for x in range(len(complaints))]" @@ -337,7 +337,7 @@ "dataset = sy.Dataset(\n", " name=\"test\",\n", " asset_list=[\n", - " sy.Asset(name=\"complaints\", data=complaints, mock=mock, mock_is_real=False)\n", + " sy.Asset(name=\"complaints\", data=complaints, mock=mock, mock_is_real=False),\n", " ],\n", ")\n", "datasite_client.upload_dataset(dataset)" @@ -369,7 +369,7 @@ " website=\"https://www.caltech.edu/\",\n", ")\n", "\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "\n", "guest_datasite_client = server.client\n", "\n", diff --git a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb index c81eb08e469..5a714603853 100644 --- a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb +++ 
b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb @@ -137,7 +137,7 @@ "source": [ "# because of mixed types we specify dtype to prevent any errors\n", "complaints = pd.read_csv(\n", - " sy.autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\"\n", + " sy.autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\",\n", ")" ] }, @@ -234,7 +234,7 @@ " \"X Coordinate (State Plane)\": lambda x: randint(1, 1000000),\n", " \"Y Coordinate (State Plane)\": lambda x: randint(1, 1000000),\n", " \"Complaint Type\": lambda x: random.choice(\n", - " [\"Illegal Parking\", \"Noise - Street/Sidewalk\", \"'Animal in a Park'\"]\n", + " [\"Illegal Parking\", \"Noise - Street/Sidewalk\", \"'Animal in a Park'\"],\n", " ),\n", " \"Descriptor\": lambda x: random.choice(\n", " [\n", @@ -242,7 +242,7 @@ " \"Branches Damaged\",\n", " \"Broken Fence\",\n", " \"Broken Glass\",\n", - " ]\n", + " ],\n", " ),\n", " \"School Number\": lambda x: random.choice(\n", " [\n", @@ -256,7 +256,7 @@ " \"B102\",\n", " \"B109\",\n", " \"B111\",\n", - " ]\n", + " ],\n", " ),\n", " \"Bridge Highway Segment\": lambda x: random.choice(\n", " [\n", @@ -265,7 +265,7 @@ " \"GrandCentral Pkwy/VanWyck Expwy/College Point Blvd (Exit 22 A-E)\",\n", " \"Hamilton Ave (Exit 2A) - Gowanus Expwy (I-278) (Exit 1)\",\n", " \"Harding Ave (Exit 9) - Throgs Neck Br\",\n", - " ]\n", + " ],\n", " ),\n", "}" ] @@ -310,7 +310,7 @@ " values = list(set(complaints[col]))\n", " mock_func = lambda x: random.choice(values) # noqa: E731,B023\n", " else:\n", - " for trigger in fake_triggers.keys():\n", + " for trigger in fake_triggers:\n", " if trigger in col:\n", " mock_func = fake_triggers[trigger]\n", " mock_data[col] = [mock_func(None) for x in range(len(complaints))]" @@ -352,7 +352,7 @@ "dataset = sy.Dataset(\n", " name=\"bikes\",\n", " asset_list=[\n", - " sy.Asset(name=\"complaints\", data=complaints, mock=mock, mock_is_real=False)\n", + " sy.Asset(name=\"complaints\", data=complaints, mock=mock, mock_is_real=False),\n", " ],\n", ")\n", "datasite_client.upload_dataset(dataset)" @@ -384,7 +384,7 @@ " website=\"https://www.caltech.edu/\",\n", ")\n", "\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "\n", "guest_datasite_client = server.client\n", "\n", diff --git a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb index 3f99c0b2cf9..6f4086bee4c 100644 --- a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb +++ b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb @@ -89,15 +89,15 @@ "# stdlib\n", "from datetime import timedelta\n", "\n", + "import pandas as pd\n", + "\n", "# third party\n", "from dateutil.parser import parse\n", - "import pandas as pd\n", "from pandas._libs.tslibs.timestamps import Timestamp\n", "\n", "# syft absolute\n", "from syft.service.project.project import Project\n", - "from syft.util.util import PANDAS_DATA\n", - "from syft.util.util import autocache" + "from syft.util.util import PANDAS_DATA, autocache" ] }, { @@ -240,7 +240,7 @@ " website=\"https://www.caltech.edu/\",\n", ")\n", "\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "\n", "guest_datasite_client = server.client\n", "\n", @@ -577,7 +577,7 @@ }, "outputs": [], "source": [ - "# Todo, fix indexes in function" + "# TODO, fix indexes in 
function" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb index 0878c0a9cfe..eaf1b1a0935 100644 --- a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb @@ -127,8 +127,8 @@ "outputs": [], "source": [ "# stdlib\n", - "from datetime import timedelta\n", "import random\n", + "from datetime import timedelta\n", "\n", "# third party\n", "from dateutil.parser import parse\n", @@ -136,8 +136,7 @@ "\n", "# syft absolute\n", "from syft.service.project.project import Project\n", - "from syft.util.util import PANDAS_DATA\n", - "from syft.util.util import autocache" + "from syft.util.util import PANDAS_DATA, autocache" ] }, { @@ -150,7 +149,7 @@ "outputs": [], "source": [ "weather_2012_final = pd.read_csv(\n", - " autocache(f\"{PANDAS_DATA}/weather_2012.csv\"), index_col=\"Date/Time\"\n", + " autocache(f\"{PANDAS_DATA}/weather_2012.csv\"), index_col=\"Date/Time\",\n", ")" ] }, @@ -276,7 +275,7 @@ " )\n", "\n", " assets.append(\n", - " sy.Asset(name=f\"weather{month}\", data=weather, mock=mock, mock_is_real=False)\n", + " sy.Asset(name=f\"weather{month}\", data=weather, mock=mock, mock_is_real=False),\n", " )" ] }, @@ -339,7 +338,7 @@ " website=\"https://www.caltech.edu/\",\n", ")\n", "\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "\n", "guest_datasite_client = server.client\n", "\n", @@ -773,7 +772,7 @@ "source": [ "@sy.syft_function(\n", " input_policy=sy.ExactMatch(\n", - " month1df=ds.assets[\"weather1\"], month2df=ds.assets[\"weather2\"]\n", + " month1df=ds.assets[\"weather1\"], month2df=ds.assets[\"weather2\"],\n", " ),\n", " output_policy=sy.SingleExecutionExactOutput(),\n", ")\n", diff --git a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb index d7c4cafd3d0..0b248f3dcaf 100644 --- a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb +++ b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb @@ -127,8 +127,8 @@ "outputs": [], "source": [ "# stdlib\n", - "from datetime import timedelta\n", "import random\n", + "from datetime import timedelta\n", "\n", "# third party\n", "from dateutil.parser import parse\n", @@ -136,8 +136,7 @@ "\n", "# syft absolute\n", "from syft.service.project.project import Project\n", - "from syft.util.util import PANDAS_DATA\n", - "from syft.util.util import autocache" + "from syft.util.util import PANDAS_DATA, autocache" ] }, { @@ -284,7 +283,7 @@ "dataset = sy.Dataset(\n", " name=\"test\",\n", " asset_list=[\n", - " sy.Asset(name=\"weather\", data=weather_2012_final, mock=mock, mock_is_real=False)\n", + " sy.Asset(name=\"weather\", data=weather_2012_final, mock=mock, mock_is_real=False),\n", " ],\n", ")\n", "root_datasite_client.upload_dataset(dataset)" @@ -327,7 +326,7 @@ " institution=\"Caltech\",\n", " website=\"https://www.caltech.edu/\",\n", ")\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "guest_datasite_client = server.client\n", "guest_client = guest_datasite_client.login(email=\"jane@caltech.edu\", password=\"abc123\")" ] diff --git 
a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb index 303e0a808d4..735b361c541 100644 --- a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb @@ -104,8 +104,7 @@ "\n", "# syft absolute\n", "from syft.service.project.project import Project\n", - "from syft.util.util import PANDAS_DATA\n", - "from syft.util.util import autocache\n", + "from syft.util.util import PANDAS_DATA, autocache\n", "\n", "# Make the graphs a bit prettier, and bigger\n", "plt.style.use(\"ggplot\")\n", @@ -146,7 +145,7 @@ "outputs": [], "source": [ "service_requests = pd.read_csv(\n", - " autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\"\n", + " autocache(f\"{PANDAS_DATA}/311-service-requests.csv\"), dtype=\"unicode\",\n", ")" ] }, @@ -304,7 +303,7 @@ " data=service_requests,\n", " mock=mock,\n", " mock_is_real=False,\n", - " )\n", + " ),\n", " ],\n", ")\n", "root_datasite_client.upload_dataset(dataset)" @@ -335,7 +334,7 @@ " institution=\"Caltech\",\n", " website=\"https://www.caltech.edu/\",\n", ")\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist role\n", "guest_datasite_client = server.client\n", "guest_client = guest_datasite_client.login(email=\"jane@caltech.edu\", password=\"abc123\")" ] @@ -461,7 +460,7 @@ "outputs": [], "source": [ "na_values = [\"NO CLUE\", \"N/A\", \"0\"]\n", - "requests.replace(na_values, np.NaN)" + "requests.replace(na_values, np.nan)" ] }, { @@ -740,7 +739,7 @@ "\n", " df[\"Incident Zip\"] = fix_zip_codes(df[\"Incident Zip\"])\n", " result = df[\"Incident Zip\"].unique()\n", - " # todo, we are adding list(result) here to fix serialization errors\n", + " # TODO, we are adding list(result) here to fix serialization errors\n", " return list(result)" ] }, diff --git a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb index 2e40e221bfb..574040bac5c 100644 --- a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb +++ b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb @@ -114,8 +114,7 @@ "\n", "# syft absolute\n", "from syft.service.project.project import Project\n", - "from syft.util.util import PANDAS_DATA\n", - "from syft.util.util import autocache\n", + "from syft.util.util import PANDAS_DATA, autocache\n", "\n", "plt.style.use(\"ggplot\")\n", "plt.rcParams[\"figure.figsize\"] = (15, 3)\n", @@ -208,8 +207,8 @@ "outputs": [], "source": [ "# stdlib\n", - "from datetime import timedelta\n", "import random\n", + "from datetime import timedelta\n", "from random import randint\n", "\n", "# third party\n", @@ -250,7 +249,7 @@ " \"ubuntu-extras-keyring\",\n", " \"libbsd0\",\n", " \"libxres-dev\",\n", - " ]\n", + " ],\n", " ),\n", " \"mru-program\": lambda: random.choice(\n", " [\n", @@ -258,7 +257,7 @@ " \"/usr/bin/onboard\",\n", " \"/lib/init/upstart-job\",\n", " \"/usr/bin/page\",\n", - " ]\n", + " ],\n", " ),\n", " \"tag\": lambda: random.choice([\"\", \"\", \"nan\"]),\n", "}" @@ -325,7 +324,7 @@ " data=popcon,\n", " mock=mock,\n", " mock_is_real=False,\n", - " )\n", + " ),\n", " ],\n", ")\n", "root_datasite_client.upload_dataset(dataset)" @@ -369,7 +368,7 @@ " institution=\"Caltech\",\n", " website=\"https://www.caltech.edu/\",\n", ")\n", - "# todo: give user data scientist role\n", + "# TODO: give user data scientist 
role\n", "guest_datasite_client = server.client\n", "guest_client = guest_datasite_client.login(email=\"jane@caltech.edu\", password=\"abc123\")" ] diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index ed8a496c407..0d478dac30b 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -43,7 +43,7 @@ "# this notebook should only be used to run the latest deployed version of syft\n", "# the notebooks after this (1a/1b and 2), will test migrating from that latest version\n", "print(\n", - " f\"latest deployed version: {latest_deployed_version}, installed version: {sy.__version__}\"\n", + " f\"latest deployed version: {latest_deployed_version}, installed version: {sy.__version__}\",\n", ")\n", "# assert (\n", "# latest_deployed_version == sy.__version__\n", @@ -93,7 +93,7 @@ "outputs": [], "source": [ "client.register(\n", - " email=\"ds@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\"\n", + " email=\"ds@openmined.org\", name=\"John Doe\", password=\"pw\", password_verify=\"pw\",\n", ")" ] }, @@ -131,7 +131,7 @@ " mock=np.array([10, 11, 12, 13, 14]),\n", " data=np.array([[15, 16, 17, 18, 19] for _ in range(100_000)]),\n", " mock_is_real=True,\n", - " )\n", + " ),\n", " ],\n", ")\n", "\n", @@ -212,7 +212,7 @@ "metadata": {}, "outputs": [], "source": [ - "# todo: add more data" + "# TODO: add more data" ] }, { diff --git a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb index ad247d91a38..f3e5da16981 100644 --- a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb +++ b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb @@ -102,7 +102,7 @@ "assert blob_path.exists()\n", "assert yaml_path.exists()\n", "\n", - "print(f\"Saved migration data to {str(blob_path.resolve())}\")" + "print(f\"Saved migration data to {blob_path.resolve()!s}\")" ] }, { diff --git a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb index 993894948d4..02c0a2ab6c9 100644 --- a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb @@ -71,7 +71,7 @@ "outputs": [], "source": [ "blob_path = Path(\"./migration.blob\")\n", - "print(f\"Loading migration data from {str(blob_path.resolve())}\")\n", + "print(f\"Loading migration data from {blob_path.resolve()!s}\")\n", "\n", "res = client.load_migration_data(blob_path)\n", "assert isinstance(res, sy.SyftSuccess)" diff --git a/packages/grid/backend/grid/api/new/new.py b/packages/grid/backend/grid/api/new/new.py index 3ff61bb041b..7b2630d1b56 100644 --- a/packages/grid/backend/grid/api/new/new.py +++ b/packages/grid/backend/grid/api/new/new.py @@ -1,7 +1,6 @@ # syft absolute -from syft.server.routes import make_routes - # server absolute from grid.core.server import worker +from syft.server.routes import make_routes router = make_routes(worker=worker) diff --git a/packages/grid/backend/grid/api/router.py b/packages/grid/backend/grid/api/router.py index 8412869db53..ea9e2d2f9d1 100644 --- a/packages/grid/backend/grid/api/router.py +++ b/packages/grid/backend/grid/api/router.py @@ -1,5 +1,4 @@ -""" -Add each api routes to the application main router. 
+"""Add each api routes to the application main router. Accesing a specific URL the user would be redirected to the correct router and the specific request handler. """ diff --git a/packages/grid/backend/grid/bootstrap.py b/packages/grid/backend/grid/bootstrap.py index 914411ec864..61b10691c73 100644 --- a/packages/grid/backend/grid/bootstrap.py +++ b/packages/grid/backend/grid/bootstrap.py @@ -1,9 +1,9 @@ # stdlib import argparse -from collections.abc import Callable import json import os import uuid +from collections.abc import Callable # third party from nacl.encoding import HexEncoder @@ -109,10 +109,9 @@ def validate_uid(server_uid: str) -> str: def get_credential( - key: str, validation_func: Callable, generation_func: Callable + key: str, validation_func: Callable, generation_func: Callable, ) -> str: - """ - This method will try to get a credential and if it isn't supplied or doesn't exist + """This method will try to get a credential and if it isn't supplied or doesn't exist it will generate one and save it. If the one supplied doesn't match the one saved it will raise an Exception. """ @@ -141,7 +140,7 @@ def get_credential( def get_private_key() -> str: return get_credential( - SERVER_PRIVATE_KEY, validate_private_key, generate_private_key + SERVER_PRIVATE_KEY, validate_private_key, generate_private_key, ) @@ -159,10 +158,10 @@ def delete_credential_file() -> None: parser.add_argument("--private_key", action="store_true", help="Get Private Key") parser.add_argument("--uid", action="store_true", help="Get UID") parser.add_argument( - "--file", action="store_true", help="Generate credentials as file" + "--file", action="store_true", help="Generate credentials as file", ) parser.add_argument( - "--debug", action="store_true", help="Show ENV and file credentials" + "--debug", action="store_true", help="Show ENV and file credentials", ) args = parser.parse_args() diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 0cd6e026e03..339e8f0f417 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -4,13 +4,8 @@ from typing import Any # third party -from pydantic import AnyHttpUrl -from pydantic import EmailStr -from pydantic import HttpUrl -from pydantic import field_validator -from pydantic import model_validator -from pydantic_settings import BaseSettings -from pydantic_settings import SettingsConfigDict +from pydantic import AnyHttpUrl, EmailStr, HttpUrl, field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self _truthy = {"yes", "y", "true", "t", "on", "1"} @@ -18,7 +13,7 @@ def _distutils_strtoint(s: str) -> int: - """implements the deprecated distutils.util.strtoint""" + """Implements the deprecated distutils.util.strtoint""" ls = s.lower() if ls in _truthy: return 1 @@ -81,14 +76,14 @@ def get_project_name(self) -> Self: EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 EMAIL_TEMPLATES_DIR: str = os.path.expandvars( - "$HOME/app/grid/email-templates/build" + "$HOME/app/grid/email-templates/build", ) EMAILS_ENABLED: bool = False @model_validator(mode="after") def get_emails_enabled(self) -> Self: self.EMAILS_ENABLED = bool( - self.SMTP_HOST and self.SMTP_PORT and self.EMAILS_FROM_EMAIL + self.SMTP_HOST and self.SMTP_PORT and self.EMAILS_FROM_EMAIL, ) return self @@ -113,7 +108,7 @@ def get_emails_enabled(self) -> Self: S3_ROOT_PWD: str | None = os.getenv("S3_ROOT_PWD", "admin") S3_REGION: str = os.getenv("S3_REGION", 
"us-east-1") S3_PRESIGNED_TIMEOUT_SECS: int = int( - os.getenv("S3_PRESIGNED_TIMEOUT_SECS", 1800) + os.getenv("S3_PRESIGNED_TIMEOUT_SECS", 1800), ) # 30 minutes in seconds SEAWEED_MOUNT_PORT: int = int(os.getenv("SEAWEED_MOUNT_PORT", 4001)) @@ -153,11 +148,11 @@ def get_emails_enabled(self) -> Self: ) ASSOCIATION_TIMEOUT: int = 10 ASSOCIATION_REQUEST_AUTO_APPROVAL: bool = str_to_bool( - os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False") + os.getenv("ASSOCIATION_REQUEST_AUTO_APPROVAL", "False"), ) MIN_SIZE_BLOB_STORAGE_MB: int = int(os.getenv("MIN_SIZE_BLOB_STORAGE_MB", 1)) REVERSE_TUNNEL_ENABLED: bool = str_to_bool( - os.getenv("REVERSE_TUNNEL_ENABLED", "false") + os.getenv("REVERSE_TUNNEL_ENABLED", "false"), ) model_config = SettingsConfigDict(case_sensitive=True) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 9802eece8e8..c662b1d7364 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -1,23 +1,21 @@ # syft absolute from syft.abstract_server import ServerType -from syft.server.datasite import Datasite -from syft.server.datasite import Server +from syft.server.datasite import Datasite, Server from syft.server.enclave import Enclave from syft.server.gateway import Gateway -from syft.server.server import get_default_bucket_name -from syft.server.server import get_enable_warnings -from syft.server.server import get_server_name -from syft.server.server import get_server_side_type -from syft.server.server import get_server_type -from syft.server.server import get_server_uid_env -from syft.service.queue.zmq_queue import ZMQClientConfig -from syft.service.queue.zmq_queue import ZMQQueueConfig -from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig -from syft.store.blob_storage.seaweedfs import SeaweedFSConfig +from syft.server.server import ( + get_default_bucket_name, + get_enable_warnings, + get_server_name, + get_server_side_type, + get_server_type, + get_server_uid_env, +) +from syft.service.queue.zmq_queue import ZMQClientConfig, ZMQQueueConfig +from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig, SeaweedFSConfig from syft.store.mongo_client import MongoStoreClientConfig from syft.store.mongo_document_store import MongoStoreConfig -from syft.store.sqlite_document_store import SQLiteStoreClientConfig -from syft.store.sqlite_document_store import SQLiteStoreConfig +from syft.store.sqlite_document_store import SQLiteStoreClientConfig, SQLiteStoreConfig from syft.types.uid import UID # server absolute @@ -31,7 +29,7 @@ def queue_config() -> ZMQQueueConfig: queue_port=settings.QUEUE_PORT, n_consumers=settings.N_CONSUMERS, consumer_service=settings.CONSUMER_SERVICE_NAME, - ) + ), ) return queue_config diff --git a/packages/grid/backend/grid/main.py b/packages/grid/backend/grid/main.py index 497a2dd7a90..9512ea861cc 100644 --- a/packages/grid/backend/grid/main.py +++ b/packages/grid/backend/grid/main.py @@ -1,6 +1,6 @@ # stdlib -from contextlib import asynccontextmanager import logging +from contextlib import asynccontextmanager from typing import Any # third party @@ -70,8 +70,7 @@ async def lifespan(app: FastAPI) -> Any: response_class=JSONResponse, ) def healthcheck() -> dict[str, str]: - """ - Currently, all service backends must satisfy either of the following requirements to + """Currently, all service backends must satisfy either of the following requirements to pass the HTTP health checks sent to it from the GCE loadbalancer: 1. 
Respond with a 200 on '/'. The content does not matter.
    2. Expose an arbitrary url as a readiness probe on the pods backing the Service.
diff --git a/packages/grid/enclave/attestation/server/attestation_main.py b/packages/grid/enclave/attestation/server/attestation_main.py
index fb0658d7151..e44854b66be 100644
--- a/packages/grid/enclave/attestation/server/attestation_main.py
+++ b/packages/grid/enclave/attestation/server/attestation_main.py
@@ -7,9 +7,11 @@
 from loguru import logger

 # relative
-from .attestation_models import CPUAttestationResponseModel
-from .attestation_models import GPUAttestationResponseModel
-from .attestation_models import ResponseModel
+from .attestation_models import (
+    CPUAttestationResponseModel,
+    GPUAttestationResponseModel,
+    ResponseModel,
+)
 from .cpu_attestation import attest_cpu
 from .gpu_attestation import attest_gpu

diff --git a/packages/grid/enclave/attestation/server/cpu_attestation.py b/packages/grid/enclave/attestation/server/cpu_attestation.py
index 356157a2e0f..dafe4bbc4fb 100644
--- a/packages/grid/enclave/attestation/server/cpu_attestation.py
+++ b/packages/grid/enclave/attestation/server/cpu_attestation.py
@@ -8,7 +8,7 @@ def attest_cpu() -> tuple[str, str]:
     # Fetch report from Microsoft Attestation library
     cpu_report = subprocess.run(
-        ["/app/AttestationClient"], capture_output=True, text=True
+        ["/app/AttestationClient"], capture_output=True, text=True, check=False,
     )
     logger.debug(f"Stdout: {cpu_report.stdout}")
     logger.debug(f"Stderr: {cpu_report.stderr}")
@@ -20,7 +20,7 @@ def attest_cpu() -> tuple[str, str]:
     # Fetch token from Microsoft Attestation library
     cpu_token = subprocess.run(
-        ["/app/AttestationClient", "-o", "token"], capture_output=True, text=True
+        ["/app/AttestationClient", "-o", "token"], capture_output=True, text=True, check=False,
     )
     logger.debug(f"Stdout: {cpu_token.stdout}")
     logger.debug(f"Stderr: {cpu_token.stderr}")
diff --git a/packages/grid/enclave/attestation/server/gpu_attestation.py b/packages/grid/enclave/attestation/server/gpu_attestation.py
index 38eccc8a6df..791d4cffbfa 100644
--- a/packages/grid/enclave/attestation/server/gpu_attestation.py
+++ b/packages/grid/enclave/attestation/server/gpu_attestation.py
@@ -29,7 +29,7 @@ def attest_gpu() -> tuple[str, str]:
     logger.info("[RemoteGPUTest] server name : {}", client.get_name())

     client.add_verifier(
-        attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, ""
+        attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, "",
     )

     # Step 1: Redirect stdout
diff --git a/packages/grid/helm/generate_helm_notes.py b/packages/grid/helm/generate_helm_notes.py
index 6869a79a97d..a9cebb8119f 100644
--- a/packages/grid/helm/generate_helm_notes.py
+++ b/packages/grid/helm/generate_helm_notes.py
@@ -1,13 +1,12 @@
 # stdlib
 import json
 import os
-from pathlib import Path
 import sys
+from pathlib import Path


 def add_notes(helm_chart_template_dir: str) -> None:
     """Add notes or information post helm install or upgrade."""
-
     notes = """
 Thank you for installing {{ .Chart.Name }}.
 Your release is named {{ .Release.Name }}.
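
Note on the attestation hunks above: ruff's subprocess-run-without-check rule
(PLW1510) flags subprocess.run(...) calls that never consult the exit status,
and its autofix is the explicit check=False seen here. A minimal sketch of the
two behaviors, assuming a POSIX `false` binary on PATH; this snippet is
illustrative only and not part of the patched code:

    import subprocess

    # check=False: a non-zero exit is not an error; the caller must
    # inspect returncode itself (this is what the hunks make explicit)
    res = subprocess.run(["false"], capture_output=True, text=True, check=False)
    assert res.returncode != 0

    # check=True: a non-zero exit raises CalledProcessError instead
    try:
        subprocess.run(["false"], check=True)
    except subprocess.CalledProcessError as exc:
        print(f"command failed with exit code {exc.returncode}")
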
@@ -37,15 +36,15 @@ def get_protocol_changes() -> str: "../../", "syft/src/syft/protocol", "protocol_version.json", - ) - ) + ), + ), ) protocol_changes = "" if protocol_path.exists(): dev_protocol_changes = json.loads(protocol_path.read_text()).get("dev", {}) protocol_changes = json.dumps( - dev_protocol_changes.get("object_versions", {}), indent=4 + dev_protocol_changes.get("object_versions", {}), indent=4, ) protocol_changelog = f""" diff --git a/packages/grid/seaweedfs/src/api.py b/packages/grid/seaweedfs/src/api.py index b3d48bab313..894428efdec 100644 --- a/packages/grid/seaweedfs/src/api.py +++ b/packages/grid/seaweedfs/src/api.py @@ -1,11 +1,10 @@ # stdlib import logging -from pathlib import Path import subprocess +from pathlib import Path # third party -from fastapi import FastAPI -from fastapi import HTTPException +from fastapi import FastAPI, HTTPException # first party from src.automount import automount @@ -49,7 +48,7 @@ def mount(opts: MountOptions, overwrite: bool = False) -> dict: raise HTTPException(status_code=400, detail=str(e)) except subprocess.CalledProcessError as e: logger.error( - f"Mount error: code={e.returncode} stdout={e.stdout} stderr={e.stderr}" + f"Mount error: code={e.returncode} stdout={e.stdout} stderr={e.stderr}", ) raise HTTPException(status_code=500, detail=str(e)) except Exception as e: @@ -66,7 +65,7 @@ def configure_azure(first_res: dict) -> str: bucket_name = first_res["bucket_name"] logger.info( - f"Configuring azure bucket name={bucket_name} remote={remote_name} container={container_name}" + f"Configuring azure bucket name={bucket_name} remote={remote_name} container={container_name}", ) # popen a daemon process @@ -79,6 +78,6 @@ def configure_azure(first_res: dict) -> str: bucket_name, container_name, account_key, - ] + ], check=False, ) return str(res.returncode) diff --git a/packages/grid/seaweedfs/src/automount.py b/packages/grid/seaweedfs/src/automount.py index 790e42a0a3e..dac4f94bdc3 100644 --- a/packages/grid/seaweedfs/src/automount.py +++ b/packages/grid/seaweedfs/src/automount.py @@ -1,8 +1,8 @@ # stdlib import logging import logging.config -from pathlib import Path import subprocess +from pathlib import Path # third party import yaml @@ -24,7 +24,7 @@ def automount(automount_conf: Path, mount_conf_dir: Path) -> None: try: logger.info( f"Auto mount type={mount_opts.remote_bucket.type} " - f"bucket={mount_opts.remote_bucket.bucket_name}" + f"bucket={mount_opts.remote_bucket.bucket_name}", ) result = mount_bucket( mount_opts, @@ -36,7 +36,7 @@ def automount(automount_conf: Path, mount_conf_dir: Path) -> None: logger.info(e) except subprocess.CalledProcessError as e: logger.error( - f"Mount error: code={e.returncode} stdout={e.stdout} stderr={e.stderr}" + f"Mount error: code={e.returncode} stdout={e.stdout} stderr={e.stderr}", ) except Exception as e: logger.error(f"Unhandled exception: {e}") diff --git a/packages/grid/seaweedfs/src/buckets.py b/packages/grid/seaweedfs/src/buckets.py index 4d5184f3938..cf56e69e496 100644 --- a/packages/grid/seaweedfs/src/buckets.py +++ b/packages/grid/seaweedfs/src/buckets.py @@ -1,14 +1,11 @@ # stdlib -from enum import Enum -from enum import unique import json +from enum import Enum, unique from pathlib import Path from typing import Any # third party -from pydantic import BaseModel -from pydantic import Field -from pydantic import field_validator +from pydantic import BaseModel, Field, field_validator __all__ = [ "BucketType", @@ -37,7 +34,7 @@ class BaseBucket(BaseModel): @property def 
full_bucket_name(self) -> str: - raise NotImplementedError() + raise NotImplementedError def check_creds(v: Any) -> Any: @@ -48,7 +45,6 @@ def check_creds(v: Any) -> Any: def check_and_read_creds(v: Any) -> Any: """Check if creds are provided as a path to a JSON file, load them if so.""" - v = check_creds(v) if isinstance(v, str | Path): return json.loads(Path(v).read_text()) diff --git a/packages/grid/seaweedfs/src/mount.py b/packages/grid/seaweedfs/src/mount.py index 3a04fb65716..3203dcd2e65 100644 --- a/packages/grid/seaweedfs/src/mount.py +++ b/packages/grid/seaweedfs/src/mount.py @@ -1,22 +1,21 @@ # stdlib -from hashlib import sha256 import logging -from pathlib import Path import re import shutil import subprocess +from hashlib import sha256 +from pathlib import Path from typing import Any # relative -from .buckets import AzureCreds -from .buckets import BucketType -from .buckets import GCSCreds -from .buckets import S3Creds -from .mount_cmd import MountCmdArgs -from .mount_cmd import SupervisordConfArgs -from .mount_cmd import create_mount_cmd -from .mount_cmd import create_supervisord_conf -from .mount_cmd import create_sync_cmd +from .buckets import AzureCreds, BucketType, GCSCreds, S3Creds +from .mount_cmd import ( + MountCmdArgs, + SupervisordConfArgs, + create_mount_cmd, + create_supervisord_conf, + create_sync_cmd, +) from .mount_options import MountOptions from .util import dict_upper_keys @@ -46,7 +45,6 @@ def mount_bucket( overwrite: bool = False, ) -> dict: """Mount a remote bucket in seaweedfs""" - # create a seaweedfs safe config name swfs_config_name = seaweed_safe_config_name( remote_name=opts.remote_bucket.type.value, @@ -72,7 +70,7 @@ def mount_bucket( remote_bucket=opts.remote_bucket.bucket_name, remote_type=opts.remote_bucket.type.value, remote_creds=get_remote_cred_args(opts.remote_bucket.type), - ) + ), ) proc = subprocess.run( mount_shell_cmd, @@ -102,7 +100,7 @@ def mount_bucket( SupervisordConfArgs( name=supervisord_conf_name, command=f'sh -c "{sync_cmd}"', - ) + ), ) # write the config to the supervisord include directory mount_conf_file.write_text(mount_conf) diff --git a/packages/grid/seaweedfs/src/mount_cmd.py b/packages/grid/seaweedfs/src/mount_cmd.py index 647cadc195e..497d4282c3d 100644 --- a/packages/grid/seaweedfs/src/mount_cmd.py +++ b/packages/grid/seaweedfs/src/mount_cmd.py @@ -1,8 +1,7 @@ # stdlib # third party -from pydantic import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field __MOUNT_CMD_TEMPLATE = """ . 
./scripts/wait_for_swfs.sh && @@ -46,7 +45,6 @@ class SupervisordConfArgs(BaseModel): def create_mount_cmd(args: MountCmdArgs) -> str: """Generate the seaweedfs mount command""" - return __MOUNT_CMD_TEMPLATE.format(**args.model_dump()).replace("\n", " ").strip() @@ -57,5 +55,4 @@ def create_sync_cmd(local_bucket: str) -> str: def create_supervisord_conf(args: SupervisordConfArgs) -> str: """Generate the supervisord configuration for a command""" - return __SUPERVISORD_TEMPLATE.format(**args.model_dump()).strip() diff --git a/packages/grid/seaweedfs/src/mount_options.py b/packages/grid/seaweedfs/src/mount_options.py index 317fe65c905..ba8d62167f0 100644 --- a/packages/grid/seaweedfs/src/mount_options.py +++ b/packages/grid/seaweedfs/src/mount_options.py @@ -2,13 +2,10 @@ from typing import Any # third party -from pydantic import BaseModel -from pydantic import field_validator +from pydantic import BaseModel, field_validator # relative -from .buckets import AzureBucket -from .buckets import GCSBucket -from .buckets import S3Bucket +from .buckets import AzureBucket, GCSBucket, S3Bucket __all__ = ["MountOptions"] diff --git a/packages/grid/seaweedfs/src/util.py b/packages/grid/seaweedfs/src/util.py index 0c96b809759..d877c3fc184 100644 --- a/packages/grid/seaweedfs/src/util.py +++ b/packages/grid/seaweedfs/src/util.py @@ -1,10 +1,8 @@ def dict_upper_keys(env_dict: dict[str, str]) -> dict[str, str]: """Convert all keys in a dictionary to uppercase""" - return {key.upper(): val for key, val in env_dict.items()} def dict_lower_keys(env_dict: dict[str, str]) -> dict[str, str]: """Convert all keys in a dictionary to lowercase""" - return {key.lower(): val for key, val in env_dict.items()} diff --git a/packages/grid/seaweedfs/tests/conftest.py b/packages/grid/seaweedfs/tests/conftest.py index 3feef9c85c3..02cc5b949e6 100644 --- a/packages/grid/seaweedfs/tests/conftest.py +++ b/packages/grid/seaweedfs/tests/conftest.py @@ -1,7 +1,7 @@ # stdlib +import shutil from pathlib import Path from secrets import token_hex -import shutil from tempfile import gettempdir # third party @@ -10,7 +10,7 @@ __all__ = ["random_path"] -@pytest.fixture +@pytest.fixture() def random_path() -> Path: # type: ignore path = Path(gettempdir(), f"{token_hex(8)}") yield path diff --git a/packages/grid/seaweedfs/tests/mount_cmd_test.py b/packages/grid/seaweedfs/tests/mount_cmd_test.py index 4489989cc34..3eaaba37bde 100644 --- a/packages/grid/seaweedfs/tests/mount_cmd_test.py +++ b/packages/grid/seaweedfs/tests/mount_cmd_test.py @@ -6,10 +6,12 @@ import pytest # first party -from src.mount_cmd import MountCmdArgs -from src.mount_cmd import SupervisordConfArgs -from src.mount_cmd import create_mount_cmd -from src.mount_cmd import create_supervisord_conf +from src.mount_cmd import ( + MountCmdArgs, + SupervisordConfArgs, + create_mount_cmd, + create_supervisord_conf, +) def test_mount_cmd() -> None: diff --git a/packages/grid/seaweedfs/tests/mount_options_test.py b/packages/grid/seaweedfs/tests/mount_options_test.py index 4fe42e6cecc..2595e351cfb 100644 --- a/packages/grid/seaweedfs/tests/mount_options_test.py +++ b/packages/grid/seaweedfs/tests/mount_options_test.py @@ -4,12 +4,7 @@ from secrets import token_hex # first party -from src.buckets import AzureBucket -from src.buckets import AzureCreds -from src.buckets import GCSBucket -from src.buckets import GCSCreds -from src.buckets import S3Bucket -from src.buckets import S3Creds +from src.buckets import AzureBucket, AzureCreds, GCSBucket, GCSCreds, S3Bucket, S3Creds from 
src.mount_options import MountOptions @@ -20,14 +15,11 @@ def test_mount_options_s3(random_path: Path) -> None: "aws_secret_access_key": token_hex(8), } opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "s3", "bucket_name": token_hex(8), "creds": creds_obj, }, - } ) assert isinstance(opts.remote_bucket, S3Bucket) @@ -36,14 +28,11 @@ def test_mount_options_s3(random_path: Path) -> None: # test creds as path random_path.write_text(json.dumps(creds_obj)) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "s3", "bucket_name": token_hex(8), "creds": str(random_path), }, - } ) assert isinstance(opts.remote_bucket, S3Bucket) @@ -62,14 +51,11 @@ def test_mount_options_gcs(random_path: Path) -> None: "universe_domain": "googleapis.com", } opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "gcs", "bucket_name": token_hex(8), "creds": creds_obj, }, - } ) assert isinstance(opts.remote_bucket, GCSBucket) @@ -78,14 +64,11 @@ def test_mount_options_gcs(random_path: Path) -> None: # test creds as path random_path.write_text(json.dumps(creds_obj)) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "gcs", "bucket_name": token_hex(8), "creds": random_path, }, - } ) assert isinstance(opts.remote_bucket, GCSBucket) @@ -99,14 +82,11 @@ def test_mount_options_azure(random_path: Path) -> None: "azure_account_key": token_hex(8), } opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "azure", "container_name": token_hex(8), "creds": creds_obj, }, - } ) assert isinstance(opts.remote_bucket, AzureBucket) assert isinstance(opts.remote_bucket.creds, AzureCreds) @@ -114,14 +94,11 @@ def test_mount_options_azure(random_path: Path) -> None: # test creds as path random_path.write_text(json.dumps(creds_obj)) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "azure", "container_name": token_hex(8), "creds": random_path, }, - } ) assert isinstance(opts.remote_bucket, AzureBucket) assert isinstance(opts.remote_bucket.creds, AzureCreds) diff --git a/packages/grid/seaweedfs/tests/mount_test.py b/packages/grid/seaweedfs/tests/mount_test.py index 12f4102f268..654dcab56a9 100644 --- a/packages/grid/seaweedfs/tests/mount_test.py +++ b/packages/grid/seaweedfs/tests/mount_test.py @@ -1,8 +1,8 @@ # stdlib -from pathlib import Path import re -from secrets import token_hex import shutil +from pathlib import Path +from secrets import token_hex from subprocess import CompletedProcess # third party @@ -11,8 +11,7 @@ # first party # from src.mount import create_mount_dotenv -from src.mount import mount_bucket -from src.mount import seaweed_safe_config_name +from src.mount import mount_bucket, seaweed_safe_config_name from src.mount_options import MountOptions @@ -29,9 +28,7 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: fake_process.register([fake_process.any()], callback=subprocess_cb) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "s3", "bucket_name": token_hex(8), "creds": { @@ -39,7 +36,6 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: 
"aws_secret_access_key": token_hex(8), }, }, - } ) result = mount_bucket(opts, random_path) conf_path = result["path"] @@ -60,9 +56,7 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: fake_process.register([fake_process.any()], callback=subprocess_cb) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "gcs", "bucket_name": token_hex(8), "creds": { @@ -74,7 +68,6 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: "universe_domain": "googleapis.com", }, }, - } ) result = mount_bucket(opts, random_path) @@ -97,9 +90,7 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: fake_process.register([fake_process.any()], callback=subprocess_cb) opts = MountOptions( - **{ - "local_bucket": token_hex(8), - "remote_bucket": { + local_bucket=token_hex(8), remote_bucket={ "type": "azure", "container_name": token_hex(8), "creds": { @@ -107,7 +98,6 @@ def subprocess_cb(process: FakePopen) -> CompletedProcess: "azure_account_key": token_hex(8), }, }, - } ) result = mount_bucket(opts, random_path) conf_path = result["path"] diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 44f2efc202c..04efc3c18f0 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -1,38 +1,29 @@ __version__ = "0.9.1-beta.1" # stdlib -from collections.abc import Callable import pathlib -from pathlib import Path import sys +from collections.abc import Callable +from pathlib import Path from typing import Any # relative -from .abstract_server import ServerSideType -from .abstract_server import ServerType -from .client.client import connect -from .client.client import login -from .client.client import login_as_guest -from .client.client import register +from .abstract_server import ServerSideType, ServerType +from .client.client import connect, login, login_as_guest, register from .client.datasite_client import DatasiteClient from .client.gateway_client import GatewayClient -from .client.registry import DatasiteRegistry -from .client.registry import EnclaveRegistry -from .client.registry import NetworkRegistry -from .client.search import Search -from .client.search import SearchResults -from .client.syncing import compare_clients -from .client.syncing import compare_states -from .client.syncing import sync -from .client.user_settings import UserSettings -from .client.user_settings import settings -from .custom_worker.config import DockerWorkerConfig -from .custom_worker.config import PrebuiltWorkerConfig +from .client.registry import DatasiteRegistry, EnclaveRegistry, NetworkRegistry +from .client.search import Search, SearchResults +from .client.syncing import compare_clients, compare_states, sync +from .client.user_settings import UserSettings, settings +from .custom_worker.config import DockerWorkerConfig, PrebuiltWorkerConfig from .orchestra import Orchestra as orchestra -from .protocol.data_protocol import bump_protocol_version -from .protocol.data_protocol import check_or_stage_protocol -from .protocol.data_protocol import get_data_protocol -from .protocol.data_protocol import stage_protocol_changes +from .protocol.data_protocol import ( + bump_protocol_version, + check_or_stage_protocol, + get_data_protocol, + stage_protocol_changes, +) from .serde import NOTHING from .serde.deserialize import _deserialize as deserialize from .serde.serializable import serializable @@ -46,47 +37,43 @@ from .server.worker import Worker from 
.service.action.action_data_empty import ActionDataEmpty from .service.action.action_object import ActionObject -from .service.action.plan import Plan -from .service.action.plan import planify -from .service.api.api import api_endpoint -from .service.api.api import api_endpoint_method +from .service.action.plan import Plan, planify +from .service.api.api import api_endpoint, api_endpoint_method from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint -from .service.code.user_code import UserCodeStatus -from .service.code.user_code import syft_function -from .service.code.user_code import syft_function_single_use +from .service.code.user_code import ( + UserCodeStatus, + syft_function, + syft_function_single_use, +) from .service.data_subject import DataSubjectCreate as DataSubject from .service.dataset.dataset import Contributor from .service.dataset.dataset import CreateAsset as Asset from .service.dataset.dataset import CreateDataset as Dataset from .service.notification.notifications import NotificationStatus from .service.policy.policy import CreatePolicyRuleConstant as Constant -from .service.policy.policy import CustomInputPolicy -from .service.policy.policy import CustomOutputPolicy -from .service.policy.policy import ExactMatch -from .service.policy.policy import MixedInputPolicy -from .service.policy.policy import SingleExecutionExactOutput -from .service.policy.policy import UserInputPolicy -from .service.policy.policy import UserOutputPolicy +from .service.policy.policy import ( + CustomInputPolicy, + CustomOutputPolicy, + ExactMatch, + MixedInputPolicy, + SingleExecutionExactOutput, + UserInputPolicy, + UserOutputPolicy, +) from .service.project.project import ProjectSubmit as Project from .service.request.request import SubmitRequest as Request -from .service.response import SyftError -from .service.response import SyftNotReady -from .service.response import SyftSuccess +from .service.response import SyftError, SyftNotReady, SyftSuccess from .service.user.roles import Roles as roles from .service.user.user_service import UserService from .stable_version import LATEST_STABLE_SYFT from .types.twin_object import TwinObject from .types.uid import UID -from .util import filterwarnings -from .util import options -from .util.autoreload import disable_autoreload -from .util.autoreload import enable_autoreload +from .util import filterwarnings, options +from .util.autoreload import disable_autoreload, enable_autoreload from .util.commit import __commit__ from .util.patch_ipython import patch_ipython from .util.telemetry import instrument -from .util.util import autocache -from .util.util import get_nb_secrets -from .util.util import get_root_data_path +from .util.util import autocache, get_nb_secrets, get_root_data_path from .util.version_compare import make_requires requires = make_requires(LATEST_STABLE_SYFT, __version__) @@ -103,7 +90,8 @@ def module_property(func: Any) -> Callable: """Decorator to turn module functions into properties. - Function names must be prefixed with an underscore.""" + Function names must be prefixed with an underscore. 
+ """ module = sys.modules[func.__module__] def base_getattr(name: str) -> None: diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 0ff9299bdfe..9c690d4aabf 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -1,69 +1,53 @@ # future from __future__ import annotations +import inspect +import types + # stdlib from collections import OrderedDict from collections.abc import Callable -import inspect -from inspect import Parameter -from inspect import signature -import types -from typing import Any -from typing import TYPE_CHECKING -from typing import _GenericAlias -from typing import cast -from typing import get_args -from typing import get_origin +from inspect import Parameter, signature +from typing import TYPE_CHECKING, Any, _GenericAlias, cast, get_args, get_origin # third party from nacl.exceptions import BadSignatureError -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import EmailStr -from pydantic import TypeAdapter -from result import OkErr -from result import Result -from typeguard import TypeCheckError -from typeguard import check_type +from pydantic import BaseModel, ConfigDict, EmailStr, TypeAdapter +from result import OkErr, Result +from typeguard import TypeCheckError, check_type # relative from ..abstract_server import AbstractServer -from ..protocol.data_protocol import PROTOCOL_TYPE -from ..protocol.data_protocol import get_data_protocol -from ..protocol.data_protocol import migrate_args_and_kwargs +from ..protocol.data_protocol import ( + PROTOCOL_TYPE, + get_data_protocol, + migrate_args_and_kwargs, +) from ..serde.deserialize import _deserialize from ..serde.serializable import serializable from ..serde.serialize import _serialize -from ..serde.signature import Signature -from ..serde.signature import signature_remove_context -from ..serde.signature import signature_remove_self -from ..server.credentials import SyftSigningKey -from ..server.credentials import SyftVerifyKey -from ..service.context import AuthedServiceContext -from ..service.context import ChangeContext +from ..serde.signature import Signature, signature_remove_context, signature_remove_self +from ..server.credentials import SyftSigningKey, SyftVerifyKey +from ..service.context import AuthedServiceContext, ChangeContext from ..service.metadata.server_metadata import ServerMetadataJSON -from ..service.response import SyftAttributeError -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..service.service import UserLibConfigRegistry -from ..service.service import UserServiceConfigRegistry +from ..service.response import SyftAttributeError, SyftError, SyftSuccess +from ..service.service import UserLibConfigRegistry, UserServiceConfigRegistry from ..service.user.user_roles import ServiceRole -from ..service.warnings import APIEndpointWarning -from ..service.warnings import WarningContext +from ..service.warnings import APIEndpointWarning, WarningContext from ..types.cache_object import CachedSyftObject from ..types.identity import Identity -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftBaseObject -from ..types.syft_object import SyftMigrationRegistry -from ..types.syft_object import SyftObject -from ..types.uid import LineageID -from ..types.uid import UID +from ..types.syft_object import ( + SYFT_OBJECT_VERSION_1, + SyftBaseObject, + SyftMigrationRegistry, + SyftObject, +) +from ..types.uid import UID, 
LineageID from ..util.autoreload import autoreload_enabled from ..util.markdown import as_markdown_python_code from ..util.notebook_ui.components.tabulator_template import build_tabulator_table from ..util.telemetry import instrument -from ..util.util import index_syft_by_module_name -from ..util.util import prompt_warning_message +from ..util.util import index_syft_by_module_name, prompt_warning_message from .connection import ServerConnection if TYPE_CHECKING: @@ -185,7 +169,7 @@ def message(self) -> SyftAPICall: if self.cached_deseralized_message is None: self.cached_deseralized_message = _deserialize( - blob=self.serialized_message, from_bytes=True + blob=self.serialized_message, from_bytes=True, ) return self.cached_deseralized_message @@ -194,7 +178,7 @@ def message(self) -> SyftAPICall: def is_valid(self) -> Result[SyftSuccess, SyftError]: try: _ = self.credentials.verify_key.verify( - self.serialized_message, self.signature + self.serialized_message, self.signature, ) except BadSignatureError: return SyftError(message="BadSignatureError") @@ -274,7 +258,7 @@ def __ipython_inspector_signature_override__(self) -> Signature | None: return self.signature def prepare_args_and_kwargs( - self, args: list | tuple, kwargs: dict[str, Any] + self, args: list | tuple, kwargs: dict[str, Any], ) -> SyftError | tuple[tuple, dict[str, Any]]: # Validate and migrate args and kwargs res = validate_callable_args_and_kwargs(args, kwargs, self.signature) @@ -283,17 +267,17 @@ def prepare_args_and_kwargs( args, kwargs = res args, kwargs = migrate_args_and_kwargs( - to_protocol=self.communication_protocol, args=args, kwargs=kwargs + to_protocol=self.communication_protocol, args=args, kwargs=kwargs, ) return tuple(args), kwargs def function_call( - self, path: str, *args: Any, cache_result: bool = True, **kwargs: Any + self, path: str, *args: Any, cache_result: bool = True, **kwargs: Any, ) -> Any: if "blocking" in self.signature.parameters: raise Exception( - f"Signature {self.signature} can't have 'blocking' kwarg because it's reserved" + f"Signature {self.signature} can't have 'blocking' kwarg because it's reserved", ) blocking = True @@ -321,7 +305,7 @@ def function_call( allowed = self.warning.show() if self.warning else True if not allowed: - return + return None result = self.make_call(api_call=api_call, cache_result=cache_result) # TODO: annotate this on the service method decorator @@ -332,7 +316,7 @@ def function_call( self.refresh_api_callback() result, _ = migrate_args_and_kwargs( - [result], kwargs={}, to_latest_protocol=True + [result], kwargs={}, to_latest_protocol=True, ) result = result[0] return result @@ -348,7 +332,7 @@ def mock(self) -> Any: class PrivateCustomAPIReference: def __call__(self, *args: Any, **kwargs: Any) -> Any: return remote_func.function_call( - "api.call_public_in_jobs", *args, **kwargs + "api.call_public_in_jobs", *args, **kwargs, ) @property @@ -357,7 +341,7 @@ def context(self) -> Any: return PrivateCustomAPIReference() return SyftError( - message="This function doesn't support mock/private calls as it's not custom." 
+            message="This function doesn't support mock/private calls as it's not custom.",
         )

     @property
@@ -368,7 +352,7 @@ def private(self) -> Any:

         class PrivateCustomAPIReference:
             def __call__(self, *args: Any, **kwargs: Any) -> Any:
                 return remote_func.function_call(
-                    "api.call_private_in_jobs", *args, **kwargs
+                    "api.call_private_in_jobs", *args, **kwargs,
                 )

             @property
@@ -377,7 +361,7 @@ def context(self) -> Any:
             return PrivateCustomAPIReference()

         return SyftError(
-            message="This function doesn't support mock/private calls as it's not custom."
+            message="This function doesn't support mock/private calls as it's not custom.",
         )

     def custom_function_actionobject_id(self) -> UID | SyftError:
@@ -445,7 +429,7 @@ class RemoteUserCodeFunction(RemoteFunction):
     api: SyftAPI

     def prepare_args_and_kwargs(
-        self, args: list | tuple, kwargs: dict[str, Any]
+        self, args: list | tuple, kwargs: dict[str, Any],
     ) -> tuple[tuple, dict[str, Any]] | SyftError:
         # relative
         from ..service.action.action_object import convert_to_pointers
@@ -474,7 +458,7 @@ def prepare_args_and_kwargs(
             )

         args, kwargs = migrate_args_and_kwargs(
-            to_protocol=self.communication_protocol, args=args, kwargs=kwargs
+            to_protocol=self.communication_protocol, args=args, kwargs=kwargs,
         )

         return tuple(args), kwargs
@@ -512,7 +496,7 @@ def generate_remote_function(
 ) -> RemoteFunction:
     if "blocking" in signature.parameters:
         raise Exception(
-            f"Signature {signature} can't have 'blocking' kwarg because it's reserved"
+            f"Signature {signature} can't have 'blocking' kwarg because it's reserved",
         )

     # UserCodes are always code.call with a user_code_id
@@ -557,7 +541,7 @@ def generate_remote_lib_function(
 ) -> Any:
     if "blocking" in signature.parameters:
         raise Exception(
-            f"Signature {signature} can't have 'blocking' kwarg because its reserved"
+            f"Signature {signature} can't have 'blocking' kwarg because it's reserved",
         )

     def wrapper(*args: Any, **kwargs: Any) -> SyftError | Any:
@@ -588,12 +572,14 @@ def wrapper(*args: Any, **kwargs: Any) -> SyftError | Any:
         _valid_kwargs.update(pre_kwargs)

         # relative
-        from ..service.action.action_object import Action
-        from ..service.action.action_object import ActionType
-        from ..service.action.action_object import convert_to_pointers
+        from ..service.action.action_object import (
+            Action,
+            ActionType,
+            convert_to_pointers,
+        )

         action_args, action_kwargs = convert_to_pointers(
-            api, wrapper_server_uid, _valid_args, _valid_kwargs
+            api, wrapper_server_uid, _valid_args, _valid_kwargs,
         )

         # e.g. numpy.array -> numpy, array
@@ -664,7 +650,7 @@ def has_submodule(self, name: str) -> bool:
         return False

     def _add_submodule(
-        self, attr_name: str, module_or_func: Callable | APIModule
+        self, attr_name: str, module_or_func: Callable | APIModule,
     ) -> None:
         setattr(self, attr_name, module_or_func)
         self._modules.append(attr_name)
@@ -697,7 +683,7 @@ def __getattr__(self, name: str) -> Any:
         raise SyftAttributeError(
             f"'APIModule' api{self.path} object has no submodule or method '{name}', "
             "you may not have permission to access the module you are trying to access."
-            "If you think this is an error, try calling `client.refresh()` to update the API."
+ "If you think this is an error, try calling `client.refresh()` to update the API.", ) def __getitem__(self, key: str | int) -> Any: @@ -729,7 +715,7 @@ def recursively_get_submodules( x.path for x in children if isinstance(x, RemoteFunction) ] views.append( - APISubModulesView(submodule=submodule_name, endpoints=child_paths) + APISubModulesView(submodule=submodule_name, endpoints=child_paths), ) return build_tabulator_table(views) @@ -757,7 +743,7 @@ def downgrade_signature(signature: Signature, object_versions: dict) -> Signatur migrated_parameters = [] for parameter in signature.parameters.values(): annotation = unwrap_and_migrate_annotation( - parameter.annotation, object_versions + parameter.annotation, object_versions, ) migrated_parameter = Parameter( name=parameter.name, @@ -768,7 +754,7 @@ def downgrade_signature(signature: Signature, object_versions: dict) -> Signatur migrated_parameters.append(migrated_parameter) migrated_return_annotation = unwrap_and_migrate_annotation( - signature.return_annotation, object_versions + signature.return_annotation, object_versions, ) try: @@ -792,7 +778,7 @@ def unwrap_and_migrate_annotation(annotation: Any, object_versions: dict) -> Any and annotation.__canonical_name__ in object_versions ): downgrade_to_version = int( - max(object_versions[annotation.__canonical_name__]) + max(object_versions[annotation.__canonical_name__]), ) downgrade_klass_name = SyftMigrationRegistry.__migration_version_registry__[ annotation.__canonical_name__ @@ -819,8 +805,7 @@ def unwrap_and_migrate_annotation(annotation: Any, object_versions: dict) -> Any def result_needs_api_update(api_call_result: Any) -> bool: # relative - from ..service.request.request import Request - from ..service.request.request import UserCodeStatusChange + from ..service.request.request import Request, UserCodeStatusChange if isinstance(api_call_result, Request) and any( isinstance(x, UserCodeStatusChange) for x in api_call_result.changes @@ -839,7 +824,7 @@ def result_needs_api_update(api_call_result: Any) -> bool: "server_name", "lib_endpoints", "communication_protocol", - ] + ], ) class SyftAPI(SyftObject): # version @@ -879,7 +864,7 @@ def __getattr__(self, name: str) -> Any: raise SyftAttributeError( f"'SyftAPI' object has no submodule or method '{name}', " "you may not have permission to access the module you are trying to access." - "If you think this is an error, try calling `client.refresh()` to update the API." 
+ "If you think this is an error, try calling `client.refresh()` to update the API.", ) @staticmethod @@ -902,7 +887,7 @@ def for_user( endpoints: dict[str, APIEndpoint] = {} lib_endpoints: dict[str, LibEndpoint] = {} warning_context = WarningContext( - server=server, role=role, credentials=user_verify_key + server=server, role=role, credentials=user_verify_key, ) # If server uses a higher protocol version than client, then @@ -912,13 +897,13 @@ def for_user( signature_needs_downgrade = True else: signature_needs_downgrade = server.current_protocol != "dev" and int( - server.current_protocol + server.current_protocol, ) > int(communication_protocol) data_protocol = get_data_protocol() if signature_needs_downgrade: object_version_for_protocol = data_protocol.get_object_versions( - communication_protocol + communication_protocol, ) for ( @@ -970,7 +955,7 @@ def for_user( # 🟡 TODO 35: fix root context context = AuthedServiceContext(server=server, credentials=user_verify_key) method = server.get_method_with_context( - UserCodeService.get_all_for_user, context + UserCodeService.get_all_for_user, context, ) code_items = method() @@ -1061,10 +1046,9 @@ def update_api(self, api_call_result: Any) -> None: self.refresh_api_callback() def _add_route( - self, api_module: APIModule, endpoint: APIEndpoint, endpoint_method: Callable + self, api_module: APIModule, endpoint: APIEndpoint, endpoint_method: Callable, ) -> None: """Recursively create a module path to the route endpoint.""" - _modules = endpoint.module_path.split(".")[:-1] + [endpoint.name] _self = api_module @@ -1078,7 +1062,7 @@ def _add_route( _self._add_submodule( module, APIModule( - path=submodule_path, refresh_callback=self.refresh_api_callback + path=submodule_path, refresh_callback=self.refresh_api_callback, ), ) _self = getattr(_self, module) @@ -1086,7 +1070,7 @@ def _add_route( def generate_endpoints(self) -> None: def build_endpoint_tree( - endpoints: dict[str, LibEndpoint], communication_protocol: PROTOCOL_TYPE + endpoints: dict[str, LibEndpoint], communication_protocol: PROTOCOL_TYPE, ) -> APIModule: api_module = APIModule(path="", refresh_callback=self.refresh_api_callback) for v in endpoints.values(): @@ -1123,10 +1107,10 @@ def build_endpoint_tree( if self.lib_endpoints is not None: self.libs = build_endpoint_tree( - self.lib_endpoints, self.communication_protocol + self.lib_endpoints, self.communication_protocol, ) self.api_module = build_endpoint_tree( - self.endpoints, self.communication_protocol + self.endpoints, self.communication_protocol, ) @property @@ -1159,7 +1143,7 @@ def __repr__(self) -> str: for func_name in module_or_func._modules: func = getattr(module_or_func, func_name) sig = getattr( - func, "__ipython_inspector_signature_override__", "" + func, "__ipython_inspector_signature_override__", "", ) _repr_str += f"{module_path_str}.{func_name}{sig}\n\n" return _repr_str @@ -1168,8 +1152,7 @@ def __repr__(self) -> str: # code from here: # https://github.com/ipython/ipython/blob/339c0d510a1f3cb2158dd8c6e7f4ac89aa4c89d8/IPython/core/oinspect.py#L370 def _render_signature(obj_signature: Signature, obj_name: str) -> str: - """ - This was mostly taken from inspect.Signature.__str__. + """This was mostly taken from inspect.Signature.__str__. Look there for the comments. The only change is to add linebreaks when this gets too long. 
""" @@ -1211,7 +1194,8 @@ def _render_signature(obj_signature: Signature, obj_name: str) -> str: def _getdef(self: Any, obj: Any, oname: str = "") -> str | None: """Return the call signature for any callable object. If any exception is generated, None is returned instead and the - exception is suppressed.""" + exception is suppressed. + """ try: return _render_signature(signature(obj), oname) except: # noqa: E722 @@ -1222,7 +1206,7 @@ def monkey_patch_getdef(self: Any, obj: Any, oname: str = "") -> str | None: try: if hasattr(obj, "__ipython_inspector_signature_override__"): return _render_signature( - obj.__ipython_inspector_signature_override__, oname + obj.__ipython_inspector_signature_override__, oname, ) return _getdef(self, obj, oname) except Exception: @@ -1250,7 +1234,7 @@ def from_api(api: SyftAPI) -> ServerIdentity: # stores the name root verify key of the datasite server if api.connection is None: raise ValueError( - "{api}'s connection is None. Can't get the server identity" + "{api}'s connection is None. Can't get the server identity", ) server_metadata = api.connection.get_server_metadata(api.signing_key) return ServerIdentity( @@ -1277,7 +1261,7 @@ def from_server(cls, server: Server) -> ServerIdentity: verify_key=server.signing_key.verify_key, ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ServerIdentity): return False return ( @@ -1294,7 +1278,7 @@ def __repr__(self) -> str: def validate_callable_args_and_kwargs( - args: list, kwargs: dict, signature: Signature + args: list, kwargs: dict, signature: Signature, ) -> tuple[list, dict] | SyftError: _valid_kwargs = {} if "kwargs" in signature.parameters: @@ -1303,7 +1287,7 @@ def validate_callable_args_and_kwargs( for key, value in kwargs.items(): if key not in signature.parameters: return SyftError( - message=f"""Invalid parameter: `{key}`. Valid Parameters: {list(signature.parameters)}""" + message=f"""Invalid parameter: `{key}`. 
Valid Parameters: {list(signature.parameters)}""", ) param = signature.parameters[key] if isinstance(param.annotation, str): @@ -1329,7 +1313,7 @@ def validate_callable_args_and_kwargs( except Exception: _type_str = getattr(t, "__name__", str(t)) return SyftError( - message=f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`" + message=f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`", ) _valid_kwargs[key] = value diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 243a3d5dc9a..9f8fd0025fc 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -3,66 +3,51 @@ # stdlib import base64 -from collections.abc import Callable -from collections.abc import Generator -from collections.abc import Iterable -from enum import Enum -from getpass import getpass import json import logging -from typing import Any -from typing import TYPE_CHECKING -from typing import cast +from collections.abc import Callable, Generator, Iterable +from enum import Enum +from getpass import getpass +from typing import TYPE_CHECKING, Any, cast + +import requests # third party from argon2 import PasswordHasher -from cachetools import TTLCache -from cachetools import cached +from cachetools import TTLCache, cached from pydantic import field_validator -import requests -from requests import Response -from requests import Session +from requests import Response, Session from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry # type: ignore[import-untyped] from typing_extensions import Self # relative from .. import __version__ -from ..abstract_server import AbstractServer -from ..abstract_server import ServerSideType -from ..abstract_server import ServerType -from ..protocol.data_protocol import DataProtocol -from ..protocol.data_protocol import PROTOCOL_TYPE -from ..protocol.data_protocol import get_data_protocol +from ..abstract_server import AbstractServer, ServerSideType, ServerType +from ..protocol.data_protocol import PROTOCOL_TYPE, DataProtocol, get_data_protocol from ..serde.deserialize import _deserialize from ..serde.serializable import serializable from ..serde.serialize import _serialize -from ..server.credentials import SyftSigningKey -from ..server.credentials import SyftVerifyKey -from ..server.credentials import UserLoginCredentials +from ..server.credentials import SyftSigningKey, SyftVerifyKey, UserLoginCredentials from ..service.context import ServerServiceContext -from ..service.metadata.server_metadata import ServerMetadata -from ..service.metadata.server_metadata import ServerMetadataJSON -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..service.user.user import UserCreate -from ..service.user.user import UserPrivateKey -from ..service.user.user import UserView +from ..service.metadata.server_metadata import ServerMetadata, ServerMetadataJSON +from ..service.response import SyftError, SyftSuccess +from ..service.user.user import UserCreate, UserPrivateKey, UserView from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..types.server_url import ServerURL from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.uid import UID from ..util.telemetry import instrument -from ..util.util import prompt_warning_message -from ..util.util import thread_ident -from ..util.util import verify_tls -from .api import APIModule -from .api import 
APIRegistry -from .api import SignedSyftAPICall -from .api import SyftAPI -from .api import SyftAPICall -from .api import debox_signed_syftapicall_response +from ..util.util import prompt_warning_message, thread_ident, verify_tls +from .api import ( + APIModule, + APIRegistry, + SignedSyftAPICall, + SyftAPI, + SyftAPICall, + debox_signed_syftapicall_response, +) from .connection import ServerConnection from .protocol import SyftProtocol @@ -193,7 +178,7 @@ def session(self) -> Session: return self.session_cache def _make_get( - self, path: str, params: dict | None = None, stream: bool = False + self, path: str, params: dict | None = None, stream: bool = False, ) -> bytes | Iterable: if params is None: return self._make_get_no_params(path, stream=stream) @@ -217,7 +202,7 @@ def _make_get( ) if response.status_code != 200: raise requests.ConnectionError( - f"Failed to fetch {url}. Response returned with code {response.status_code}" + f"Failed to fetch {url}. Response returned with code {response.status_code}", ) # upgrade to tls if available @@ -245,7 +230,7 @@ def _make_get_no_params(self, path: str, stream: bool = False) -> bytes | Iterab ) if response.status_code != 200: raise requests.ConnectionError( - f"Failed to fetch {url}. Response returned with code {response.status_code}" + f"Failed to fetch {url}. Response returned with code {response.status_code}", ) # upgrade to tls if available @@ -257,7 +242,7 @@ def _make_get_no_params(self, path: str, stream: bool = False) -> bytes | Iterab return response.content def _make_put( - self, path: str, data: bytes | Generator, stream: bool = False + self, path: str, data: bytes | Generator, stream: bool = False, ) -> Response: url = self.url @@ -277,7 +262,7 @@ def _make_put( ) if response.status_code != 200: raise requests.ConnectionError( - f"Failed to fetch {url}. Response returned with code {response.status_code}" + f"Failed to fetch {url}. Response returned with code {response.status_code}", ) # upgrade to tls if available @@ -309,7 +294,7 @@ def _make_post( ) if response.status_code != 200: raise requests.ConnectionError( - f"Failed to fetch {url}. Response returned with code {response.status_code}" + f"Failed to fetch {url}. Response returned with code {response.status_code}", ) # upgrade to tls if available @@ -320,12 +305,12 @@ def _make_post( def stream_data(self, credentials: SyftSigningKey) -> Response: url = self.url.with_path(self.routes.STREAM.value) response = self.session.get( - str(url), verify=verify_tls(), proxies={}, stream=True, headers=self.headers + str(url), verify=verify_tls(), proxies={}, stream=True, headers=self.headers, ) return response def get_server_metadata( - self, credentials: SyftSigningKey + self, credentials: SyftSigningKey, ) -> ServerMetadataJSON | SyftError: if self.proxy_target_uid: response = forward_message_to_proxy( @@ -424,7 +409,7 @@ def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError: if response.status_code != 200: raise requests.ConnectionError( - f"Failed to fetch metadata. Response returned with code {response.status_code}" + f"Failed to fetch metadata. 
Response returned with code {response.status_code}", ) result = _deserialize(response.content, from_bytes=True) @@ -471,7 +456,7 @@ def with_proxy(self, proxy_target_uid: UID) -> Self: return PythonConnection(server=self.server, proxy_target_uid=proxy_target_uid) def get_server_metadata( - self, credentials: SyftSigningKey + self, credentials: SyftSigningKey, ) -> ServerMetadataJSON | SyftError: if self.proxy_target_uid: response = forward_message_to_proxy( @@ -497,7 +482,7 @@ def get_api( communication_protocol: int, metadata: ServerMetadataJSON | None = None, ) -> SyftAPI: - # todo: its a bit odd to identify a user by its verify key maybe? + # TODO: its a bit odd to identify a user by its verify key maybe? if self.proxy_target_uid: obj = forward_message_to_proxy( self.make_call, @@ -527,10 +512,10 @@ def get_cache_key(self) -> str: def exchange_credentials(self, email: str, password: str) -> UserPrivateKey | None: context = self.server.get_unauthed_context( - login_credentials=UserLoginCredentials(email=email, password=password) + login_credentials=UserLoginCredentials(email=email, password=password), ) method = self.server.get_method_with_context( - UserService.exchange_credentials, context + UserService.exchange_credentials, context, ) result = method() return result @@ -638,7 +623,7 @@ def post_init(self) -> None: self._fetch_server_metadata(self.credentials) self.metadata = cast(ServerMetadataJSON, self.metadata) self.communication_protocol = self._get_communication_protocol( - self.metadata.supported_protocols + self.metadata.supported_protocols, ) def set_headers(self, headers: dict[str, str]) -> None | SyftError: @@ -647,11 +632,11 @@ def set_headers(self, headers: dict[str, str]) -> None | SyftError: return None return SyftError( # type: ignore message="Incompatible connection type." - + f"Expected HTTPConnection, got {type(self.connection)}" + + f"Expected HTTPConnection, got {type(self.connection)}", ) def _get_communication_protocol( - self, protocols_supported_by_server: list + self, protocols_supported_by_server: list, ) -> int | str: data_protocol: DataProtocol = get_data_protocol() protocols_supported_by_client: list[PROTOCOL_TYPE] = ( @@ -660,12 +645,12 @@ def _get_communication_protocol( self.current_protocol = data_protocol.latest_version common_protocols = set(protocols_supported_by_client).intersection( - protocols_supported_by_server + protocols_supported_by_server, ) if len(common_protocols) == 0: raise Exception( - "No common communication protocol found between the client and the server." 
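# A standalone sketch of the negotiation _get_communication_protocol performs
# around this hunk; `negotiate_protocol` and the plain lists are stand-ins for
# the real client/server protocol tables, but the "dev" preference and max()
# fallback mirror the code shown below:
def negotiate_protocol(client_protocols: list, server_protocols: list) -> int | str:
    common = set(client_protocols).intersection(server_protocols)
    if not common:
        raise Exception(
            "No common communication protocol found between the client and the server.",
        )
    if "dev" in common:  # both sides on dev: use the in-development protocol
        return "dev"
    return max(common)  # otherwise the newest mutually supported release

# e.g. negotiate_protocol([1, 2, "dev"], [2, 3]) returns 2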
+ "No common communication protocol found between the client and the server.", ) if "dev" in common_protocols: @@ -673,7 +658,7 @@ def _get_communication_protocol( return max(common_protocols) def create_project( - self, name: str, description: str, user_email_address: str + self, name: str, description: str, user_email_address: str, ) -> Any: # relative from ..service.project.project import ProjectSubmit @@ -778,7 +763,7 @@ def exchange_route( ) else: raise ValueError( - f"Invalid Route Exchange SyftProtocol: {protocol}.Supported protocols are {SyftProtocol.all()}" + f"Invalid Route Exchange SyftProtocol: {protocol}.Supported protocols are {SyftProtocol.all()}", ) @property @@ -841,7 +826,7 @@ def login_as_guest(self) -> Self: if self.metadata is not None: print( f"Logged into <{self.name}: {self.metadata.server_side_type.capitalize()}-side " - f"{self.metadata.server_type.capitalize()}> as GUEST" + f"{self.metadata.server_type.capitalize()}> as GUEST", ) return _guest_client @@ -853,7 +838,7 @@ def login_as(self, email: str) -> Self: if self.metadata is not None: print( f"Logged into <{self.name}: {self.metadata.server_side_type.capitalize()}-side " - f"{self.metadata.server_type.capitalize()}> as {email}" + f"{self.metadata.server_type.capitalize()}> as {email}", ) return self.__class__( @@ -877,7 +862,7 @@ def login( if register: self.register( - email=email, password=password, password_verify=password, **kwargs + email=email, password=password, password_verify=password, **kwargs, ) user_private_key = self.connection.login(email=email, password=password) @@ -901,7 +886,7 @@ def login( if signing_key is not None and client.metadata is not None: print( f"Logged into <{client.name}: {client.metadata.server_side_type.capitalize()} side " - f"{client.metadata.server_type.capitalize()}> as <{email}>" + f"{client.metadata.server_type.capitalize()}> as <{email}>", ) # relative from ..server.server import get_default_root_password @@ -992,7 +977,7 @@ def register( "host datasets with private information." 
) if self.metadata.show_warnings and not prompt_warning_message( - message=message + message=message, ): return None @@ -1004,7 +989,7 @@ def register( def __hash__(self) -> int: return hash(self.id) + hash(self.connection) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SyftClient): return False return ( @@ -1121,7 +1106,7 @@ def login_as_guest( if verbose and _client.metadata is not None: print( f"Logged into <{_client.name}: {_client.metadata.server_side_type.capitalize()}-" - f"side {_client.metadata.server_type.capitalize()}> as GUEST" + f"side {_client.metadata.server_type.capitalize()}> as GUEST", ) return _client.guest() @@ -1163,7 +1148,7 @@ def login( ) if _client_cache: print( - f"Using cached client for {_client.name} as <{login_credentials.email}>" + f"Using cached client for {_client.name} as <{login_credentials.email}>", ) _client = _client_cache @@ -1185,7 +1170,7 @@ class SyftClientSessionCache: @classmethod def _get_key(cls, email: str, password: str, connection: str) -> str: key = cls.__cache_key_format__.format( - email=email, password=password, connection=connection + email=email, password=password, connection=connection, ) ph = PasswordHasher() return ph.hash(key) @@ -1214,14 +1199,14 @@ def add_client_by_uid_and_verify_key( @classmethod def get_client_by_uid_and_verify_key( - cls, verify_key: SyftVerifyKey, server_uid: UID + cls, verify_key: SyftVerifyKey, server_uid: UID, ) -> SyftClient | None: hash_key = str(server_uid) + str(verify_key) return cls.__client_cache__.get(hash_key, None) @classmethod def get_client( - cls, email: str, password: str, connection: ServerConnection + cls, email: str, password: str, connection: ServerConnection, ) -> SyftClient | None: # we have some bugs here so lets disable until they are fixed. 
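        # Note on the disabled lookup: `_get_key` above runs the formatted
        # "{email}-{password}-{connection}" string through argon2's
        # PasswordHasher.hash, which generates a fresh random salt per call,
        # so the same credentials never produce the same key twice -- one
        # plausible reason this cache lookup is switched off and falls
        # through to None.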
return None diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py index 1bb0ea6ceb8..d29e8fe3b7d 100644 --- a/packages/syft/src/syft/client/connection.py +++ b/packages/syft/src/syft/client/connection.py @@ -2,8 +2,7 @@ from typing import Any # relative -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftObject +from ..types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ..types.uid import UID diff --git a/packages/syft/src/syft/client/datasite_client.py b/packages/syft/src/syft/client/datasite_client.py index fdd30bbed03..06aa26f5d86 100644 --- a/packages/syft/src/syft/client/datasite_client.py +++ b/packages/syft/src/syft/client/datasite_client.py @@ -3,12 +3,11 @@ # stdlib import logging -from pathlib import Path import re -from string import Template import traceback -from typing import TYPE_CHECKING -from typing import cast +from pathlib import Path +from string import Template +from typing import TYPE_CHECKING, cast # third party import markdown @@ -19,15 +18,13 @@ from ..abstract_server import ServerSideType from ..serde.serializable import serializable from ..service.action.action_object import ActionObject -from ..service.code_history.code_history import CodeHistoriesDict -from ..service.code_history.code_history import UsersCodeHistoriesDict -from ..service.dataset.dataset import Contributor -from ..service.dataset.dataset import CreateAsset -from ..service.dataset.dataset import CreateDataset +from ..service.code_history.code_history import ( + CodeHistoriesDict, + UsersCodeHistoriesDict, +) +from ..service.dataset.dataset import Contributor, CreateAsset, CreateDataset from ..service.migration.object_migration_state import MigrationData -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..service.response import SyftWarning +from ..service.response import SyftError, SyftSuccess, SyftWarning from ..service.sync.diff_state import ResolvedSyncState from ..service.sync.sync_state import SyncState from ..service.user.roles import Roles @@ -35,12 +32,9 @@ from ..types.blob_storage import BlobFile from ..types.uid import UID from ..util.misc_objs import HTMLObject -from ..util.util import get_mb_size -from ..util.util import prompt_warning_message +from ..util.util import get_mb_size, prompt_warning_message from .api import APIModule -from .client import SyftClient -from .client import login -from .client import login_as_guest +from .client import SyftClient, login, login_as_guest from .connection import ServerConnection from .protocol import SyftProtocol @@ -70,7 +64,7 @@ def _contains_subdir(dir: Path) -> bool: def add_default_uploader( - user: UserView, obj: CreateDataset | CreateAsset + user: UserView, obj: CreateDataset | CreateAsset, ) -> CreateDataset | CreateAsset: uploader = None for contributor in obj.contributors: @@ -129,7 +123,7 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: prompt_warning_message(message=message, confirm=True) with tqdm( - total=len(dataset.asset_list), colour="green", desc="Uploading" + total=len(dataset.asset_list), colour="green", desc="Uploading", ) as pbar: for asset in dataset.asset_list: try: @@ -150,7 +144,7 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: if isinstance(res, SyftWarning): logger.debug(res.message) response = self.api.services.action.set( - twin, ignore_detached_objs=contains_empty + twin, ignore_detached_objs=contains_empty, ) if 
isinstance(response, SyftError): tqdm.write(f"Failed to upload asset: {asset.name}") @@ -242,10 +236,10 @@ def upload_files( elif path.is_dir(): if not allow_recursive and _contains_subdir(path): res = input( - f"Do you want to include all files recursively in {path.absolute()}? [y/n]: " + f"Do you want to include all files recursively in {path.absolute()}? [y/n]: ", ).lower() print( - f'{"Recursively uploading all files" if res == "y" else "Uploading files"} in {path.absolute()}' + f'{"Recursively uploading all files" if res == "y" else "Uploading files"} in {path.absolute()}', ) allow_recursive = res == "y" expanded_file_list.extend(_get_files_from_dir(path, allow_recursive)) @@ -256,7 +250,7 @@ def upload_files( return SyftError(message="No files to upload were found") print( - f"Uploading {len(expanded_file_list)} {'file' if len(expanded_file_list) == 1 else 'files'}:" + f"Uploading {len(expanded_file_list)} {'file' if len(expanded_file_list) == 1 else 'files'}:", ) if show_files: @@ -279,7 +273,7 @@ def upload_files( return ActionObject.from_obj(result).send(self) except Exception as err: return SyftError( - message=f"Failed to upload files: {err}.\n{traceback.format_exc()}" + message=f"Failed to upload files: {err}.\n{traceback.format_exc()}", ) def connect_to_gateway( @@ -287,7 +281,7 @@ def connect_to_gateway( via_client: SyftClient | None = None, url: str | None = None, port: int | None = None, - handle: ServerHandle | None = None, # noqa: F821 + handle: ServerHandle | None = None, email: str | None = None, password: str | None = None, protocol: str | SyftProtocol = SyftProtocol.HTTP, @@ -321,7 +315,7 @@ def connect_to_gateway( f"Connected {self.metadata.server_type} " f"'{self.metadata.name}' to gateway '{client.name}'. " f"{res.message}" - ) + ), ) else: return SyftSuccess(message=f"Connected to '{client.name}' gateway") @@ -334,10 +328,10 @@ def _get_service_by_name_if_exists(self, name: str) -> APIModule | None: return None def set_server_side_type_dangerous( - self, server_side_type: str + self, server_side_type: str, ) -> Result[SyftSuccess, SyftError]: return self.api.services.settings.set_server_side_type_dangerous( - server_side_type + server_side_type, ) @property @@ -405,7 +399,7 @@ def migration(self) -> APIModule | None: return self._get_service_by_name_if_exists("migration") def get_migration_data( - self, include_blobs: bool = True + self, include_blobs: bool = True, ) -> MigrationData | SyftError: res = self.api.services.migration.get_migration_data() if isinstance(res, SyftError): @@ -425,12 +419,12 @@ def load_migration_data(self, path: str | Path) -> SyftSuccess | SyftError: if self.id != migration_data.server_uid: return SyftError( message=f"This Migration data is not for this server. 
Expected server id {self.id}, " - f"got {migration_data.server_uid}" + f"got {migration_data.server_uid}", ) if migration_data.signing_key.verify_key != self.verify_key: return SyftError( - message="Root verify key in migration data does not match this client's verify key" + message="Root verify key in migration data does not match this client's verify key", ) res = migration_data.migrate_and_upload_blobs() @@ -439,7 +433,7 @@ def load_migration_data(self, path: str | Path) -> SyftSuccess | SyftError: migration_data_without_blobs = migration_data.copy_without_blobs() return self.api.services.migration.apply_migration_data( - migration_data_without_blobs + migration_data_without_blobs, ) def get_project( @@ -448,7 +442,6 @@ def get_project( uid: UID | None = None, ) -> Project | None: """Get project by name or UID""" - if not self.api.has_service("project"): return None @@ -465,7 +458,7 @@ def _repr_html_(self) -> str: if isinstance(obj, SyftError): return obj.message updated_template_str = Template(obj.text).safe_substitute( - server_url=getattr(self.connection, "url", None) + server_url=getattr(self.connection, "url", None), ) # If it's a markdown structured file if not isinstance(obj, HTMLObject): diff --git a/packages/syft/src/syft/client/enclave_client.py b/packages/syft/src/syft/client/enclave_client.py index dc5da06ed3c..3ee12144717 100644 --- a/packages/syft/src/syft/client/enclave_client.py +++ b/packages/syft/src/syft/client/enclave_client.py @@ -9,16 +9,12 @@ from ..serde.serializable import serializable from ..service.metadata.server_metadata import ServerMetadataJSON from ..service.network.routes import ServerRouteType -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftObject +from ..service.response import SyftError, SyftSuccess +from ..types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ..util.assets import load_png_base64 from ..util.notebook_ui.styles import FONT_CSS from .api import APIModule -from .client import SyftClient -from .client import login -from .client import login_as_guest +from .client import SyftClient, login, login_as_guest from .protocol import SyftProtocol if TYPE_CHECKING: @@ -64,7 +60,7 @@ def connect_to_gateway( via_client: SyftClient | None = None, url: str | None = None, port: int | None = None, - handle: ServerHandle | None = None, # noqa: F821 + handle: ServerHandle | None = None, email: str | None = None, password: str | None = None, protocol: str | SyftProtocol = SyftProtocol.HTTP, @@ -95,7 +91,7 @@ def connect_to_gateway( f"Connected {self.metadata.server_type} " f"'{self.metadata.name}' to gateway '{client.name}'. 
" f"{res.message}" - ) + ), ) else: return SyftSuccess(message=f"Connected to '{client.name}' gateway") diff --git a/packages/syft/src/syft/client/gateway_client.py b/packages/syft/src/syft/client/gateway_client.py index 87958f8405b..5626cc66a43 100644 --- a/packages/syft/src/syft/client/gateway_client.py +++ b/packages/syft/src/syft/client/gateway_client.py @@ -2,16 +2,13 @@ from typing import Any # relative -from ..abstract_server import ServerSideType -from ..abstract_server import ServerType +from ..abstract_server import ServerSideType, ServerType from ..serde.serializable import serializable from ..server.credentials import SyftSigningKey from ..service.metadata.server_metadata import ServerMetadataJSON from ..service.network.server_peer import ServerPeer -from ..service.response import SyftError -from ..service.response import SyftException -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftObject +from ..service.response import SyftError, SyftException +from ..types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ..util.assets import load_png_base64 from ..util.notebook_ui.styles import FONT_CSS from .client import SyftClient @@ -29,7 +26,7 @@ def proxy_to(self, peer: Any) -> SyftClient: connection: type[ServerConnection] = self.connection.with_proxy(peer.id) metadata: ServerMetadataJSON | SyftError = connection.get_server_metadata( - credentials=SyftSigningKey.generate() + credentials=SyftSigningKey.generate(), ) if isinstance(metadata, SyftError): return metadata @@ -39,7 +36,7 @@ def proxy_to(self, peer: Any) -> SyftClient: client_type = EnclaveClient else: raise SyftException( - f"Unknown server type {metadata.server_type} to create proxy client" + f"Unknown server type {metadata.server_type} to create proxy client", ) client = client_type( @@ -163,14 +160,14 @@ class ProxyClient(SyftObject): def retrieve_servers(self) -> list[ServerPeer]: if self.server_type in [ServerType.DATASITE, ServerType.ENCLAVE]: return self.routing_client.api.services.network.get_peers_by_type( - server_type=self.server_type + server_type=self.server_type, ) elif self.server_type is None: # if server type is None, return all servers return self.routing_client.api.services.network.get_all_peers() else: raise SyftException( - f"Unknown server type {self.server_type} to retrieve proxy client" + f"Unknown server type {self.server_type} to retrieve proxy client", ) def _repr_html_(self) -> str: diff --git a/packages/syft/src/syft/client/registry.py b/packages/syft/src/syft/client/registry.py index 48627172eeb..6cadaccfdf8 100644 --- a/packages/syft/src/syft/client/registry.py +++ b/packages/syft/src/syft/client/registry.py @@ -1,11 +1,12 @@ # future from __future__ import annotations -# stdlib -from concurrent import futures import json import logging import os + +# stdlib +from concurrent import futures from typing import Any # third party @@ -14,8 +15,7 @@ # relative from ..service.metadata.server_metadata import ServerMetadataJSON -from ..service.network.server_peer import ServerPeer -from ..service.network.server_peer import ServerPeerConnectionStatus +from ..service.network.server_peer import ServerPeer, ServerPeerConnectionStatus from ..service.response import SyftException from ..types.server_url import ServerURL from ..util.constants import DEFAULT_TIMEOUT @@ -40,11 +40,11 @@ def __init__(self) -> None: try: network_json = self.load_network_registry_json() self.all_networks = _get_all_networks( - network_json=network_json, version="2.0.0" + 
network_json=network_json, version="2.0.0", ) except Exception as e: logger.warning( - f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. Exception: {e}" + f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. Exception: {e}", ) @staticmethod @@ -65,7 +65,7 @@ def load_network_registry_json() -> dict: except Exception as e: logger.warning( - f"Failed to get Network Registry from {NETWORK_REGISTRY_REPO}. Exception: {e}" + f"Failed to get Network Registry from {NETWORK_REGISTRY_REPO}. Exception: {e}", ) return {} @@ -91,7 +91,7 @@ def check_network(network: dict) -> dict[Any, Any] | None: online = False if online: - version = network.get("version", None) + version = network.get("version") # Check if syft version was described in NetworkRegistry # If it's unknown, try to update it to an available version. if not version or version == "unknown": @@ -112,7 +112,7 @@ def check_network(network: dict) -> dict[Any, Any] | None: with futures.ThreadPoolExecutor(max_workers=20) as executor: # map _online_networks = list( - executor.map(lambda network: check_network(network), networks) + executor.map(lambda network: check_network(network), networks), ) return [network for network in _online_networks if network is not None] @@ -125,9 +125,9 @@ def _repr_html_(self) -> str: total_df = pd.DataFrame( [ [ - f"{len(on)} / {len(self.all_networks)} (online networks / all networks)" + f"{len(on)} / {len(self.all_networks)} (online networks / all networks)", ] - + [""] * (len(df.columns) - 1) + + [""] * (len(df.columns) - 1), ], columns=df.columns, index=["Total"], @@ -143,9 +143,9 @@ def __repr__(self) -> str: total_df = pd.DataFrame( [ [ - f"{len(on)} / {len(self.all_networks)} (online networks / all networks)" + f"{len(on)} / {len(self.all_networks)} (online networks / all networks)", ] - + [""] * (len(df.columns) - 1) + + [""] * (len(df.columns) - 1), ], columns=df.columns, index=["Total"], @@ -189,12 +189,12 @@ def __init__(self) -> None: try: network_json = NetworkRegistry.load_network_registry_json() self.all_networks = _get_all_networks( - network_json=network_json, version="2.0.0" + network_json=network_json, version="2.0.0", ) self._get_all_datasites() except Exception as e: logger.warning( - f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. {e}" + f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. {e}", ) def _get_all_datasites(self) -> None: @@ -226,7 +226,7 @@ def check_network(network: dict) -> dict[Any, Any] | None: online = False if online: - version = network.get("version", None) + version = network.get("version") # Check if syft version was described in NetworkRegistry # If it's unknown, try to update it to an available version. 
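            # Each registry entry is probed independently; the surrounding
            # property fans check_network out over a 20-worker thread pool
            # and keeps only the entries that come back non-None (online).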
if not version or version == "unknown": @@ -248,7 +248,7 @@ def check_network(network: dict) -> dict[Any, Any] | None: with futures.ThreadPoolExecutor(max_workers=20) as executor: # map _online_networks = list( - executor.map(lambda network: check_network(network), networks) + executor.map(lambda network: check_network(network), networks), ) return [network for network in _online_networks if network is not None] @@ -305,9 +305,9 @@ def _repr_html_(self) -> str: total_df = pd.DataFrame( [ [ - f"{len(on)} / {len(self.all_datasites)} (online datasites / all datasites)" + f"{len(on)} / {len(self.all_datasites)} (online datasites / all datasites)", ] - + [""] * (len(df.columns) - 1) + + [""] * (len(df.columns) - 1), ], columns=df.columns, index=["Total"], @@ -323,9 +323,9 @@ def __repr__(self) -> str: total_df = pd.DataFrame( [ [ - f"{len(on)} / {len(self.all_datasites)} (online datasites / all datasites)" + f"{len(on)} / {len(self.all_datasites)} (online datasites / all datasites)", ] - + [""] * (len(df.columns) - 1) + + [""] * (len(df.columns) - 1), ], columns=df.columns, index=["Total"], @@ -367,7 +367,7 @@ def __init__(self) -> None: self.all_enclaves = enclaves_json["2.0.0"]["enclaves"] except Exception as e: logger.warning( - f"Failed to get Enclave Registry, go checkout: {ENCLAVE_REGISTRY_REPO}. {e}" + f"Failed to get Enclave Registry, go checkout: {ENCLAVE_REGISTRY_REPO}. {e}", ) @property @@ -383,7 +383,7 @@ def check_enclave(enclave: dict) -> dict[Any, Any] | None: online = False if online: - version = enclave.get("version", None) + version = enclave.get("version") # Check if syft version was described in EnclaveRegistry # If it's unknown, try to update it to an available version. if not version or version == "unknown": @@ -404,7 +404,7 @@ def check_enclave(enclave: dict) -> dict[Any, Any] | None: with futures.ThreadPoolExecutor(max_workers=20) as executor: # map _online_enclaves = list( - executor.map(lambda enclave: check_enclave(enclave), enclaves) + executor.map(lambda enclave: check_enclave(enclave), enclaves), ) online_enclaves = [each for each in _online_enclaves if each is not None] diff --git a/packages/syft/src/syft/client/search.py b/packages/syft/src/syft/client/search.py index e4450987aff..b37bc59f31d 100644 --- a/packages/syft/src/syft/client/search.py +++ b/packages/syft/src/syft/client/search.py @@ -34,9 +34,7 @@ def __getitem__(self, key: int | str | UID) -> Dataset: if dataset.id == key: return dataset elif isinstance(key, str): - if dataset.name == key: - return dataset - elif str(dataset.id) == key: + if dataset.name == key or str(dataset.id) == key: return dataset raise KeyError @@ -65,7 +63,7 @@ def __init__(self, datasites: DatasiteRegistry) -> None: @staticmethod def __search_one_server( - peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str + peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str, ) -> tuple[SyftClient | None, list[Dataset]]: try: peer, server_metadata = peer_tuple @@ -74,7 +72,7 @@ def __search_one_server( return (client, results) except Exception as e: # noqa warning = SyftWarning( - message=f"Got exception {e} at server {server_metadata.name}" + message=f"Got exception {e} at server {server_metadata.name}", ) display(warning) return (None, []) @@ -88,7 +86,7 @@ def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]: executor.map( lambda peer_tuple: self.__search_one_server(peer_tuple, name), self.datasites, - ) + ), ) # filter out SyftError filtered = [(client, result) for client, result in results if client 
and result] @@ -96,13 +94,15 @@ def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]: return filtered def search(self, name: str) -> SearchResults: - """ - Searches for a specific dataset by name. + """Searches for a specific dataset by name. Args: + ---- name (str): The name of the dataset to search for. Returns: + ------- SearchResults: An object containing the search results. + """ return SearchResults(self.__search(name)) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 204d95642aa..65e0b180372 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -1,26 +1,21 @@ # stdlib # stdlib -from collections.abc import Collection import logging +from collections.abc import Collection # relative from ..abstract_server import ServerSideType from ..server.credentials import SyftVerifyKey -from ..service.response import SyftError -from ..service.response import SyftSuccess -from ..service.sync.diff_state import ObjectDiffBatch -from ..service.sync.diff_state import ServerDiff -from ..service.sync.diff_state import SyncInstruction -from ..service.sync.resolve_widget import PaginatedResolveWidget -from ..service.sync.resolve_widget import ResolveWidget +from ..service.response import SyftError, SyftSuccess +from ..service.sync.diff_state import ObjectDiffBatch, ServerDiff, SyncInstruction +from ..service.sync.resolve_widget import PaginatedResolveWidget, ResolveWidget from ..service.sync.sync_state import SyncState from ..types.uid import UID from ..util.decorators import deprecated from ..util.util import prompt_warning_message from .datasite_client import DatasiteClient -from .sync_decision import SyncDecision -from .sync_decision import SyncDirection +from .sync_decision import SyncDecision, SyncDirection logger = logging.getLogger(__name__) @@ -78,13 +73,13 @@ def compare_states( direction = SyncDirection.HIGH_TO_LOW else: return SyftError( - "Invalid server side types: can only compare a high and low server" + "Invalid server side types: can only compare a high and low server", ) if hide_usercode: prompt_warning_message( "UserCodes are hidden by default, and are part of the Requests." - " If you want to include them as separate objects, set `hide_usercode=False`" + " If you want to include them as separate objects, set `hide_usercode=False`", ) exclude_types = exclude_types or [] exclude_types.append("usercode") @@ -136,7 +131,7 @@ def resolve( ) -> ResolveWidget | PaginatedResolveWidget | SyftSuccess | SyftError: if not isinstance(obj, ObjectDiffBatch | ServerDiff): raise ValueError( - f"Invalid type: could not resolve object with type {type(obj).__qualname__}" + f"Invalid type: could not resolve object with type {type(obj).__qualname__}", ) return obj.resolve() @@ -157,7 +152,7 @@ def handle_sync_batch( sync_direction = obj_diff_batch.sync_direction if sync_direction is None: return SyftError( - message="Cannot sync an object without a specified sync direction." 
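# A minimal sketch of how compare_states (shown earlier in this file's diff)
# turns the two servers' side types into a sync direction before guards like
# this one run; the strings are stand-ins for the ServerSideType and
# SyncDirection enums:
def infer_direction(from_side: str, to_side: str) -> str:
    if (from_side, to_side) == ("low", "high"):
        return "LOW_TO_HIGH"
    if (from_side, to_side) == ("high", "low"):
        return "HIGH_TO_LOW"
    raise ValueError("Invalid server side types: can only compare a high and low server")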
+ message="Cannot sync an object without a specified sync direction.", ) decision = sync_direction.to_sync_decision() @@ -169,11 +164,11 @@ def handle_sync_batch( return SyftSuccess(message="No changes to sync") elif obj_diff_batch.decision is SyncDecision.IGNORE: return SyftError( - message="Attempted to sync an ignored object, please unignore first" + message="Attempted to sync an ignored object, please unignore first", ) elif obj_diff_batch.decision is not None: return SyftError( - message="Attempted to sync an object that has already been synced" + message="Attempted to sync an object that has already been synced", ) src_client = obj_diff_batch.source_client @@ -226,7 +221,7 @@ def handle_ignore_batch( return SyftSuccess(message="This batch is already ignored") elif obj_diff_batch.decision is not None: return SyftError( - message="Attempted to sync an object that has already been synced" + message="Attempted to sync an object that has already been synced", ) obj_diff_batch.decision = SyncDecision.IGNORE diff --git a/packages/syft/src/syft/custom_worker/builder.py b/packages/syft/src/syft/custom_worker/builder.py index e47f341a27f..e6990d22dca 100644 --- a/packages/syft/src/syft/custom_worker/builder.py +++ b/packages/syft/src/syft/custom_worker/builder.py @@ -1,18 +1,14 @@ # stdlib -from functools import cached_property import os.path +from functools import cached_property from pathlib import Path from typing import Any # relative from .builder_docker import DockerBuilder from .builder_k8s import KubernetesBuilder -from .builder_types import BuilderBase -from .builder_types import ImageBuildResult -from .builder_types import ImagePushResult -from .config import CustomWorkerConfig -from .config import DockerWorkerConfig -from .config import WorkerConfig +from .builder_types import BuilderBase, ImageBuildResult, ImagePushResult +from .config import CustomWorkerConfig, DockerWorkerConfig, WorkerConfig from .k8s import IN_KUBERNETES __all__ = ["CustomWorkerBuilder"] @@ -42,13 +38,14 @@ def build_image( tag: str | None = None, **kwargs: Any, ) -> ImageBuildResult: - """ - Builds a Docker image from the given configuration. + """Builds a Docker image from the given configuration. + Args: + ---- config (WorkerConfig): The configuration for building the Docker image. tag (str): The tag to use for the image. - """ + """ if isinstance(config, DockerWorkerConfig): return self._build_dockerfile(config, tag, **kwargs) elif isinstance(config, CustomWorkerConfig): @@ -64,13 +61,14 @@ def push_image( password: str, **kwargs: Any, ) -> ImagePushResult: - """ - Pushes a Docker image to the given repo. + """Pushes a Docker image to the given repo. + Args: + ---- repo (str): The repo to push the image to. tag (str): The tag to use for the image. - """ + """ return self.builder.push_image( tag=tag, username=username, @@ -120,16 +118,19 @@ def _build_template( ) def find_worker_image(self, type: str) -> Path: - """ - Find the Worker Dockerfile and it's context path + """Find the Worker Dockerfile and it's context path - PROD will be in `$APPDIR/grid/` - DEV will be in `packages/grid/backend/grid/images` - In both the cases context dir does not matter (unless we're calling COPY) Args: + ---- type (str): The type of worker. + Returns: + ------- Path: The path to the Dockerfile. 
+ """ filename = f"worker_{type}.dockerfile" lookup_paths = [ diff --git a/packages/syft/src/syft/custom_worker/builder_docker.py b/packages/syft/src/syft/custom_worker/builder_docker.py index 2e544f3842f..e155d995281 100644 --- a/packages/syft/src/syft/custom_worker/builder_docker.py +++ b/packages/syft/src/syft/custom_worker/builder_docker.py @@ -1,7 +1,7 @@ # stdlib -from collections.abc import Iterable import contextlib import io +from collections.abc import Iterable from pathlib import Path from typing import Any @@ -9,10 +9,12 @@ import docker # relative -from .builder_types import BUILD_IMAGE_TIMEOUT_SEC -from .builder_types import BuilderBase -from .builder_types import ImageBuildResult -from .builder_types import ImagePushResult +from .builder_types import ( + BUILD_IMAGE_TIMEOUT_SEC, + BuilderBase, + ImageBuildResult, + ImagePushResult, +) from .utils import iterator_to_string __all__ = ["DockerBuilder"] diff --git a/packages/syft/src/syft/custom_worker/builder_k8s.py b/packages/syft/src/syft/custom_worker/builder_k8s.py index b799b35f7b1..f4ca2eeda9b 100644 --- a/packages/syft/src/syft/custom_worker/builder_k8s.py +++ b/packages/syft/src/syft/custom_worker/builder_k8s.py @@ -5,21 +5,23 @@ from typing import Any # third party -from kr8s.objects import ConfigMap -from kr8s.objects import Job -from kr8s.objects import Secret +from kr8s.objects import ConfigMap, Job, Secret # relative -from .builder_types import BUILD_IMAGE_TIMEOUT_SEC -from .builder_types import BuilderBase -from .builder_types import ImageBuildResult -from .builder_types import ImagePushResult -from .builder_types import PUSH_IMAGE_TIMEOUT_SEC -from .k8s import INTERNAL_REGISTRY_HOST -from .k8s import KUBERNETES_NAMESPACE -from .k8s import KubeUtils -from .k8s import USE_INTERNAL_REGISTRY -from .k8s import get_kr8s_client +from .builder_types import ( + BUILD_IMAGE_TIMEOUT_SEC, + PUSH_IMAGE_TIMEOUT_SEC, + BuilderBase, + ImageBuildResult, + ImagePushResult, +) +from .k8s import ( + INTERNAL_REGISTRY_HOST, + KUBERNETES_NAMESPACE, + USE_INTERNAL_REGISTRY, + KubeUtils, + get_kr8s_client, +) from .utils import ImageUtils __all__ = ["KubernetesBuilder"] @@ -100,7 +102,7 @@ def build_image( raise BuildFailed( "Failed to build the image. " f"Kaniko exit code={exit_code}. 
" - f"Logs={logs}" + f"Logs={logs}", ) except Exception: @@ -201,7 +203,7 @@ def _create_build_config(self, job_id: str, dockerfile: str) -> ConfigMap: "data": { "Dockerfile": dockerfile, }, - } + }, ) return KubeUtils.create_or_get(config_map) @@ -268,7 +270,7 @@ def _create_kaniko_build_job( "mountPath": "/workspace", }, ], - } + }, ], "volumes": [ { @@ -278,10 +280,10 @@ def _create_kaniko_build_job( }, }, ], - } + }, }, }, - } + }, ) return KubeUtils.create_or_get(job) @@ -338,7 +340,7 @@ def _create_push_job( "readOnly": True, }, ], - } + }, ], "volumes": [ { @@ -354,10 +356,10 @@ def _create_push_job( }, }, ], - } + }, }, }, - } + }, ) return KubeUtils.create_or_get(job) diff --git a/packages/syft/src/syft/custom_worker/builder_types.py b/packages/syft/src/syft/custom_worker/builder_types.py index 2c6b1529adc..6e0f8089011 100644 --- a/packages/syft/src/syft/custom_worker/builder_types.py +++ b/packages/syft/src/syft/custom_worker/builder_types.py @@ -1,6 +1,5 @@ # stdlib -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from pathlib import Path from typing import Any diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index f01266221bd..20fe7d7005a 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -1,21 +1,20 @@ # stdlib import contextlib -from hashlib import sha256 import io +from hashlib import sha256 from pathlib import Path from typing import Any # third party import docker +import yaml from packaging import version from pydantic import field_validator from typing_extensions import Self -import yaml # relative from ..serde.serializable import serializable -from ..service.response import SyftError -from ..service.response import SyftSuccess +from ..service.response import SyftError, SyftSuccess from ..types.base import SyftBaseModel from .utils import iterator_to_string @@ -173,7 +172,7 @@ def test_image_build(self, tag: str, **kwargs: Any) -> SyftSuccess | SyftError: with contextlib.closing(docker.from_env()) as client: if not client.ping(): return SyftError( - "Cannot reach docker server. Please check if docker is running." + "Cannot reach docker server. 
Please check if docker is running.", ) kwargs["fileobj"] = io.BytesIO(self.dockerfile.encode("utf-8")) diff --git a/packages/syft/src/syft/custom_worker/k8s.py b/packages/syft/src/syft/custom_worker/k8s.py index c3d047a0daa..23386079de4 100644 --- a/packages/syft/src/syft/custom_worker/k8s.py +++ b/packages/syft/src/syft/custom_worker/k8s.py @@ -1,18 +1,14 @@ # stdlib import base64 +import json +import os from collections.abc import Iterable from enum import Enum from functools import cache -import json -import os # third party import kr8s -from kr8s.objects import APIObject -from kr8s.objects import ConfigMap -from kr8s.objects import Pod -from kr8s.objects import Secret -from kr8s.objects import Service +from kr8s.objects import APIObject, ConfigMap, Pod, Secret, Service from pydantic import BaseModel from typing_extensions import Self @@ -89,7 +85,7 @@ def from_status_dict(cls, status: dict) -> Self: phase=PodPhase(status.get("phase", "Unknown")), condition=PodCondition.from_conditions(status.get("conditions", [])), container=ContainerStatus.from_status( - status.get("containerStatuses", {})[0] + status.get("containerStatuses", {})[0], ), ) @@ -102,8 +98,7 @@ def get_kr8s_client() -> kr8s.Api: class KubeUtils: - """ - This class contains utility functions for interacting with kubernetes objects. + """This class contains utility functions for interacting with kubernetes objects. DO NOT call `get_kr8s_client()` inside this class, instead pass it as an argument to the functions. This is to avoid calling these functions on resources across namespaces! @@ -239,7 +234,7 @@ def create_secret( }, "type": type, "data": data, - } + }, ) return KubeUtils.create_or_get(secret) @@ -254,7 +249,6 @@ def create_or_get(obj: APIObject) -> APIObject: @staticmethod def patch_env_vars(env_list: list[dict], env_dict: dict) -> list[dict]: """Patch kubernetes pod environment variables in the list with the provided dictionary.""" - # update existing for item in env_list: k = item["name"] diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py index ddb9765042c..0e760af54fd 100644 --- a/packages/syft/src/syft/custom_worker/runner_k8s.py +++ b/packages/syft/src/syft/custom_worker/runner_k8s.py @@ -2,15 +2,10 @@ from typing import Any # third party -from kr8s.objects import Pod -from kr8s.objects import Secret -from kr8s.objects import StatefulSet +from kr8s.objects import Pod, Secret, StatefulSet # relative -from .k8s import KUBERNETES_NAMESPACE -from .k8s import KubeUtils -from .k8s import PodStatus -from .k8s import get_kr8s_client +from .k8s import KUBERNETES_NAMESPACE, KubeUtils, PodStatus, get_kr8s_client JSONPATH_AVAILABLE_REPLICAS = "{.status.availableReplicas}" CREATE_POOL_TIMEOUT_SEC = 180 @@ -156,7 +151,6 @@ def _create_stateful_set( **kwargs: Any, ) -> StatefulSet: """Create a stateful set for a pool""" - volumes = [] volume_mounts = [] pull_secret_obj = None @@ -170,7 +164,7 @@ def _create_stateful_set( "secret": { "secretName": secret_name, }, - } + }, ) volume_mounts.append( { @@ -178,14 +172,14 @@ def _create_stateful_set( "mountPath": mount_opts.get("mountPath"), "subPath": mount_opts.get("subPath"), "readOnly": True, - } + }, ) if pull_secret: pull_secret_obj = [ { "name": pull_secret.name, - } + }, ] default_pod_labels = { @@ -213,7 +207,7 @@ def _create_stateful_set( "selector": { "matchLabels": { "app.kubernetes.io/component": pool_name, - } + }, }, "template": { "metadata": { @@ -247,13 +241,13 @@ def _create_stateful_set( 
"failureThreshold": 30, "periodSeconds": 10, }, - } + }, ], "volumes": volumes, "imagePullSecrets": pull_secret_obj, }, }, }, - } + }, ) return KubeUtils.create_or_get(stateful_set) diff --git a/packages/syft/src/syft/custom_worker/utils.py b/packages/syft/src/syft/custom_worker/utils.py index 5c4a9768649..24ad8471d2b 100644 --- a/packages/syft/src/syft/custom_worker/utils.py +++ b/packages/syft/src/syft/custom_worker/utils.py @@ -1,6 +1,6 @@ # stdlib -from collections.abc import Iterable import json +from collections.abc import Iterable def iterator_to_string(iterator: Iterable) -> str: diff --git a/packages/syft/src/syft/dev/prof.py b/packages/syft/src/syft/dev/prof.py index a6aeffc5780..e9f4cb96f85 100644 --- a/packages/syft/src/syft/dev/prof.py +++ b/packages/syft/src/syft/dev/prof.py @@ -12,11 +12,13 @@ def pyspy() -> None: # type: ignore """Profile a block of code using py-spy. Intended for development purposes only. Example: + ------- ``` with pyspy(): # do some work a = [i for i in range(1000000)] ``` + """ fd, fname = tempfile.mkstemp(".svg") os.close(fd) diff --git a/packages/syft/src/syft/exceptions/user.py b/packages/syft/src/syft/exceptions/user.py index 59147e29522..ef2c78a53f6 100644 --- a/packages/syft/src/syft/exceptions/user.py +++ b/packages/syft/src/syft/exceptions/user.py @@ -5,5 +5,5 @@ from .exception import PySyftException UserAlreadyExistsException = PySyftException( - message="User already exists", roles=[ServiceRole.ADMIN] + message="User already exists", roles=[ServiceRole.ADMIN], ) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index ebe74b85d31..69f9c8c351f 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -3,22 +3,22 @@ # future from __future__ import annotations -# stdlib -from collections.abc import Callable -from enum import Enum import getpass import inspect import logging import os import sys + +# stdlib +from collections.abc import Callable +from enum import Enum from typing import Any # third party from IPython.display import display # relative -from .abstract_server import ServerSideType -from .abstract_server import ServerType +from .abstract_server import ServerSideType, ServerType from .client.client import login as sy_login from .client.client import login_as_guest as sy_login_as_guest from .protocol.data_protocol import stage_protocol_changes @@ -26,8 +26,7 @@ from .server.enclave import Enclave from .server.gateway import Gateway from .server.uvicorn import serve_server -from .service.response import SyftError -from .service.response import SyftInfo +from .service.response import SyftError, SyftInfo from .util.util import get_random_available_port logger = logging.getLogger(__name__) @@ -51,14 +50,14 @@ def get_server_type(server_type: str | ServerType | None) -> ServerType | None: def get_deployment_type(deployment_type: str | None) -> DeploymentType | None: if deployment_type is None: deployment_type = os.environ.get( - "ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON + "ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON, ) try: return DeploymentType(deployment_type) except ValueError: print( - f"deployment_type: {deployment_type} is not a valid DeploymentType: {DeploymentType}" + f"deployment_type: {deployment_type} is not a valid DeploymentType: {DeploymentType}", ) return None @@ -99,14 +98,14 @@ def client(self) -> Any: return self.python_server.get_guest_client(verbose=False) # type: ignore else: raise NotImplementedError( - f"client not implemented for 
the deployment type: {self.deployment_type}",
            )

    def login_as_guest(self, **kwargs: Any) -> ClientAlias:
        return self.client.login_as_guest(**kwargs)

    def login(
-        self, email: str | None = None, password: str | None = None, **kwargs: Any
+        self, email: str | None = None, password: str | None = None, **kwargs: Any,
    ) -> ClientAlias:
        if not email:
            email = input("Email: ")
@@ -116,14 +115,14 @@ def login(
        if self.port:
            return sy_login(
-                email=email, password=password, url=self.url, port=self.port
+                email=email, password=password, url=self.url, port=self.port,
            )  # type: ignore
        elif self.deployment_type == DeploymentType.PYTHON:
            guest_client = self.python_server.get_guest_client(verbose=False)  # type: ignore
            return guest_client.login(email=email, password=password, **kwargs)  # type: ignore
        else:
            raise NotImplementedError(
-                f"client not implemented for the deployment type:{self.deployment_type}"
+                f"client not implemented for the deployment type: {self.deployment_type}",
            )

    def register(
@@ -226,7 +225,7 @@ def deploy_to_python(
                port = int(port)
            except ValueError:
                raise ValueError(
-                    f"port must be either 'auto' or a valid int not: {port}"
+                    f"port must be either 'auto' or a valid int not: {port}",
                )
        kwargs["port"] = port
@@ -256,7 +255,7 @@ def deploy_to_python(
        worker = worker_class.named(**supported_kwargs)
    else:
        raise NotImplementedError(
-            f"server_type: {server_type_enum} is not supported"
+            f"server_type: {server_type_enum} is not supported",
        )

    def stop() -> None:
@@ -338,7 +337,7 @@ def launch(
    )

    deployment_type_enum: DeploymentType | None = get_deployment_type(
-        deployment_type=deploy_to
+        deployment_type=deploy_to,
    )

    if deployment_type_enum == DeploymentType.PYTHON:
@@ -367,8 +366,8 @@ def launch(
        display(
            SyftInfo(
                message=f"You have launched a development server at http://{host}:{server_handle.port}."
-                + "It is intended only for local use."
-            )
+                + " It is intended only for local use.",
+            ),
        )
        return server_handle
    elif deployment_type_enum == DeploymentType.REMOTE:
@@ -382,5 +381,5 @@ def launch(
        migrate=migrate,
    )
    raise NotImplementedError(
-        f"deployment_type: {deployment_type_enum} is not supported"
+        f"deployment_type: {deployment_type_enum} is not supported",
    )
diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py
index 0590db599bc..b62c6c3d798 100644
--- a/packages/syft/src/syft/protocol/data_protocol.py
+++ b/packages/syft/src/syft/protocol/data_protocol.py
@@ -1,30 +1,25 @@
 # stdlib
-from collections import defaultdict
-from collections.abc import Iterable
-from collections.abc import MutableMapping
-from collections.abc import MutableSequence
-from functools import cache
 import hashlib
 import json
-from operator import itemgetter
 import os
-from pathlib import Path
 import re
-from types import UnionType
 import typing
-from typing import Any
 import warnings
+from collections import defaultdict
+from collections.abc import Iterable, MutableMapping, MutableSequence
+from functools import cache
+from operator import itemgetter
+from pathlib import Path
+from types import UnionType
+from typing import Any

 # third party
 from packaging.version import parse
-from result import OkErr
-from result import Result
+from result import OkErr, Result

 # relative
 from ..
import __version__ -from ..service.response import SyftError -from ..service.response import SyftException -from ..service.response import SyftSuccess +from ..service.response import SyftError, SyftException, SyftSuccess from ..types.dicttuple import DictTuple from ..types.syft_object import SyftBaseObject from ..types.syft_object_registry import SyftObjectRegistry @@ -118,7 +113,7 @@ def _calculate_object_hash(klass: type[SyftBaseObject]) -> str: field_data = { field: handle_annotation_repr_(field_info.rebuild_annotation()) for field, field_info in sorted( - klass.model_fields.items(), key=itemgetter(0) + klass.model_fields.items(), key=itemgetter(0), ) } obj_meta_info = { @@ -156,7 +151,7 @@ def save_history(self, history: dict) -> None: for file_path in protocol_release_dir().iterdir(): for version in self.read_json(file_path): # Skip adding file if the version is not part of the history - if version not in history.keys(): + if version not in history: continue history[version] = {"release_name": file_path.name} self.file_path.write_text(json.dumps(history, indent=2) + "\n") @@ -189,14 +184,14 @@ def build_state(self, stop_key: str | None = None) -> dict: or hash_str in state_version_hashes ): raise Exception( - f"Can't add {object_metadata} already in state {versions}" + f"Can't add {object_metadata} already in state {versions}", ) if action == "remove" and ( str(version) not in state_versions.keys() and hash_str not in state_version_hashes ): raise Exception( - f"Can't remove {object_metadata} missing from state {versions} for object {canonical_name}." + f"Can't remove {object_metadata} missing from state {versions} for object {canonical_name}.", ) if action == "add": state_dict[canonical_name][str(version)] = ( @@ -231,7 +226,7 @@ def diff_state(self, state: dict) -> tuple[dict, dict]: if issubclass(cls, SyftBaseObject): canonical_name = cls.__canonical_name__ if canonical_name in IGNORE_TYPES or canonical_name.startswith( - "MockSyftObject_" + "MockSyftObject_", ): continue @@ -325,7 +320,7 @@ def stage_protocol_changes(self) -> Result[SyftSuccess, SyftError]: # Sort the version dict object_versions[canonical_name] = sort_dict_naturally( - object_versions.get(canonical_name, {}) + object_versions.get(canonical_name, {}), ) current_history["dev"]["object_versions"] = object_versions @@ -341,7 +336,7 @@ def stage_protocol_changes(self) -> Result[SyftSuccess, SyftError]: def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]: if len(self.diff): raise Exception( - "You can't bump the protocol version with unstaged changes." + "You can't bump the protocol version with unstaged changes.", ) keys = self.protocol_history.keys() @@ -349,7 +344,7 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]: self.validate_release() print("You can't bump the protocol if there are no staged changes.") return SyftError( - message="Failed to bump version as there are no staged changes." 
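# A minimal sketch of the schema-hashing idea behind _calculate_object_hash
# (shown earlier in this diff): serialize the sorted field layout
# deterministically, then digest it. The dict shape is illustrative, not
# Syft's exact obj_meta_info:
import hashlib
import json

def schema_hash(canonical_name: str, version: int, fields: dict[str, str]) -> str:
    meta = {
        "canonical_name": canonical_name,
        "version": version,
        "fields": dict(sorted(fields.items())),  # field order must not affect the hash
    }
    return hashlib.sha256(json.dumps(meta).encode()).hexdigest()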
+ message="Failed to bump version as there are no staged changes.", ) highest_protocol = 0 @@ -369,7 +364,6 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]: @staticmethod def freeze_release(protocol_history: dict, latest_protocol: str) -> None: """Freezes latest release as a separate release file.""" - # Get release history release_history = protocol_history[latest_protocol] @@ -380,7 +374,7 @@ def freeze_release(protocol_history: dict, latest_protocol: str) -> None: # Save the new released version release_file.write_text( - json.dumps({latest_protocol: release_history}, indent=2) + json.dumps({latest_protocol: release_history}, indent=2), ) def validate_release(self) -> None: @@ -411,7 +405,7 @@ def validate_release(self) -> None: # Update release name to latest beta, stable or post based on current syft version print( - f"Current release {release_name} will be updated to {current_syft_version}" + f"Current release {release_name} will be updated to {current_syft_version}", ) # Get latest protocol file path @@ -436,7 +430,6 @@ def validate_release(self) -> None: def revert_latest_protocol(self) -> Result[SyftSuccess, SyftError]: """Revert latest protocol changes to dev""" - # Get current protocol history protocol_history = self.read_json(self.file_path) diff --git a/packages/syft/src/syft/serde/__init__.py b/packages/syft/src/syft/serde/__init__.py index 00122b4769f..6fa4930286a 100644 --- a/packages/syft/src/syft/serde/__init__.py +++ b/packages/syft/src/syft/serde/__init__.py @@ -1,4 +1,4 @@ # relative -from .array import NOTHING # noqa: F811 +from .array import NOTHING from .recursive import NOTHING # noqa: F811 from .third_party import NOTHING # noqa: F811 diff --git a/packages/syft/src/syft/serde/array.py b/packages/syft/src/syft/serde/array.py index 3f19e575b97..bddad650967 100644 --- a/packages/syft/src/syft/serde/array.py +++ b/packages/syft/src/syft/serde/array.py @@ -4,8 +4,7 @@ # relative from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from .arrow import numpy_deserialize -from .arrow import numpy_serialize +from .arrow import numpy_deserialize, numpy_serialize from .recursive import recursive_serde_register SUPPORTED_BOOL_TYPES = [np.bool_] diff --git a/packages/syft/src/syft/serde/arrow.py b/packages/syft/src/syft/serde/arrow.py index 31a5ad1d27c..a1b3a7fd1ca 100644 --- a/packages/syft/src/syft/serde/arrow.py +++ b/packages/syft/src/syft/serde/arrow.py @@ -6,8 +6,7 @@ import pyarrow as pa # relative -from ..util.experimental_flags import ApacheArrowCompression -from ..util.experimental_flags import flags +from ..util.experimental_flags import ApacheArrowCompression, flags from .deserialize import _deserialize from .serialize import _serialize @@ -24,7 +23,7 @@ def inner(obj: np.ndarray) -> tuple: numpy_bytes = buffer.to_pybytes() else: numpy_bytes = pa.compress( - buffer, asbytes=True, codec=flags.APACHE_ARROW_COMPRESSION.value + buffer, asbytes=True, codec=flags.APACHE_ARROW_COMPRESSION.value, ) dtype = original_dtype.name return (numpy_bytes, buffer.size, dtype) @@ -34,7 +33,7 @@ def inner(obj: np.ndarray) -> tuple: def arrow_deserialize( - numpy_bytes: bytes, decompressed_size: int, dtype: str + numpy_bytes: bytes, decompressed_size: int, dtype: str, ) -> np.ndarray: original_dtype = np.dtype(dtype) if flags.APACHE_ARROW_COMPRESSION is ApacheArrowCompression.NONE: @@ -57,16 +56,19 @@ def numpyutf8toarray(input_index: np.ndarray) -> np.ndarray: """Decodes utf-8 encoded numpy array to string numpy array. 
Args: + ---- input_index (np.ndarray): utf-8 encoded array Returns: + ------- np.ndarray: decoded NumpyArray. + """ shape_length = int(input_index[-1]) - shape = tuple(input_index[-(shape_length + 1) : -1]) # noqa + shape = tuple(input_index[-(shape_length + 1) : -1]) string_index = input_index[: -(shape_length + 1)] index_length = int(string_index[-1]) - index_array = string_index[-(index_length + 1) : -1] # noqa + index_array = string_index[-(index_length + 1) : -1] string_array: np.ndarray = string_index[: -(index_length + 1)] output_bytes: bytes = string_array.astype(np.uint8).tobytes() output_list = [] @@ -83,10 +85,13 @@ def arraytonumpyutf8(string_list: str | np.ndarray) -> bytes: """Encodes string Numpyarray to utf-8 encoded numpy array. Args: + ---- string_list (np.ndarray): NumpyArray to be encoded Returns: + ------- bytes: serialized utf-8 encoded int Numpy array + """ array_shape = np.array(string_list).shape string_list = np.array(string_list).flatten() @@ -107,7 +112,7 @@ def arraytonumpyutf8(string_list: str | np.ndarray) -> bytes: shape = np.array(array_shape, dtype=np.uint64) shape_length = np.array([len(shape)], dtype=np.uint64) output_array = np.concatenate( - [np_bytes, np_indexes, index_length, shape, shape_length] + [np_bytes, np_indexes, index_length, shape, shape_length], ) return cast(bytes, _serialize(output_array, to_bytes=True)) diff --git a/packages/syft/src/syft/serde/capnp.py b/packages/syft/src/syft/serde/capnp.py index bfc77bfdf90..5995e4a6708 100644 --- a/packages/syft/src/syft/serde/capnp.py +++ b/packages/syft/src/syft/serde/capnp.py @@ -1,6 +1,5 @@ # stdlib -from importlib.resources import as_file -from importlib.resources import files +from importlib.resources import as_file, files # third party import capnp diff --git a/packages/syft/src/syft/serde/deserialize.py b/packages/syft/src/syft/serde/deserialize.py index 46f3564a1fa..c831a27d170 100644 --- a/packages/syft/src/syft/serde/deserialize.py +++ b/packages/syft/src/syft/serde/deserialize.py @@ -11,8 +11,7 @@ def _deserialize( from_bytes: bool = False, ) -> Any: # relative - from .recursive import rs_bytes2object - from .recursive import rs_proto2object + from .recursive import rs_bytes2object, rs_proto2object if ( (from_bytes and not isinstance(blob, bytes)) diff --git a/packages/syft/src/syft/serde/lib_permissions.py b/packages/syft/src/syft/serde/lib_permissions.py index 751f72df410..dacea998dcc 100644 --- a/packages/syft/src/syft/serde/lib_permissions.py +++ b/packages/syft/src/syft/serde/lib_permissions.py @@ -16,7 +16,7 @@ class CMPCRUDPermission(Enum): class CMPPermission: @property def permissions_string(self) -> str: - raise NotImplementedError() + raise NotImplementedError def __repr__(self) -> str: return self.permission_string diff --git a/packages/syft/src/syft/serde/lib_service_registry.py b/packages/syft/src/syft/serde/lib_service_registry.py index 517df6c643c..6feafe72284 100644 --- a/packages/syft/src/syft/serde/lib_service_registry.py +++ b/packages/syft/src/syft/serde/lib_service_registry.py @@ -1,10 +1,8 @@ # stdlib -from collections.abc import Callable -from collections.abc import Sequence import importlib import inspect -from inspect import Signature -from inspect import _signature_fromstr +from collections.abc import Callable, Sequence +from inspect import Signature, _signature_fromstr from types import BuiltinFunctionType from typing import Any @@ -13,9 +11,7 @@ from typing_extensions import Self # relative -from .lib_permissions import ALL_EXECUTE -from .lib_permissions 
import CMPPermission -from .lib_permissions import NONE_EXECUTE +from .lib_permissions import ALL_EXECUTE, NONE_EXECUTE, CMPPermission from .signature import get_signature LIB_IGNORE_ATTRIBUTES = { @@ -71,7 +67,7 @@ def __init__( if text_signature is not None: self.signature = _signature_fromstr( - inspect.Signature, obj, text_signature, True + inspect.Signature, obj, text_signature, True, ) self.is_built = False @@ -123,12 +119,15 @@ def init_child( """Get the child of parent as a CMPBase object Args: + ---- parent_obj (_type_): parent object child_path (_type_): _description_ child_obj (_type_): _description_ Returns: + ------- _type_: _description_ + """ parent_is_parent_module = CMPBase.parent_is_parent_module(parent_obj, child_obj) if CMPBase.isfunction(child_obj) and parent_is_parent_module: @@ -139,7 +138,7 @@ def init_child( absolute_path=absolute_path, ) # type: ignore elif inspect.ismodule(child_obj) and CMPBase.is_submodule( - parent_obj, child_obj + parent_obj, child_obj, ): ## TODO, we could register modules and functions in 2 ways: # A) as numpy.float32 (what we are doing now) @@ -204,7 +203,7 @@ def isfunction(obj: Callable) -> bool: ) def __repr__( - self, indent: int = 0, is_last: bool = False, parent_path: str = "" + self, indent: int = 0, is_last: bool = False, parent_path: str = "", ) -> str: """Visualize the tree, e.g.: ├───numpy (ALL_EXECUTE) @@ -222,12 +221,15 @@ def __repr__( │ │ ├───_clean_args (ALL_EXECUTE) Args: + ---- indent (int, optional): indentation level. Defaults to 0. is_last (bool, optional): is last item of collection. Defaults to False. parent_path (str, optional): path of the parent obj. Defaults to "". Returns: + ------- str: representation of the CMP + """ last_idx, c_indent = len(self.children) - 1, indent + 1 children_string = "".join( @@ -237,9 +239,9 @@ def __repr__( sorted( self.children.values(), key=lambda x: x.permissions.permission_string, # type: ignore - ) + ), ) - ] + ], ) tree_prefix = "└───" if is_last else "├───" indent_str = "│ " * indent + tree_prefix @@ -351,5 +353,5 @@ def __repr__(self) -> str: CMPModule("testing", permissions=NONE_EXECUTE), ], ), - ] + ], ).build() diff --git a/packages/syft/src/syft/serde/mock.py b/packages/syft/src/syft/serde/mock.py index 60334afb478..e71f496c7c5 100644 --- a/packages/syft/src/syft/serde/mock.py +++ b/packages/syft/src/syft/serde/mock.py @@ -1,6 +1,6 @@ # stdlib -from collections import defaultdict import secrets +from collections import defaultdict from typing import Any # third party diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 33bf94c8d4f..93ccdd3a031 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -1,19 +1,18 @@ # stdlib -from collections.abc import Callable -from enum import Enum -from enum import EnumMeta import os import tempfile import types +from collections.abc import Callable +from enum import Enum, EnumMeta from typing import Any +# syft absolute +import syft as sy + # third party from capnp.lib.capnp import _DynamicStructBuilder from pydantic import BaseModel -# syft absolute -import syft as sy - # relative from ..types.syft_object_registry import SyftObjectRegistry from .capnp import get_capnp_schema @@ -79,7 +78,7 @@ def check_fqn_alias(cls: object | type) -> tuple[str, ...] 
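
The reformatted __repr__ above draws the registry as a box-drawing tree (├───/└─── with │ rails). The same rendering scheme, reduced to a standalone helper over nested dicts; this is a hypothetical sketch, not Syft's implementation:

    def render_tree(name: str, children: dict, indent: int = 0, is_last: bool = False) -> str:
        prefix = "└───" if is_last else "├───"
        out = "│   " * indent + prefix + name + "\n"
        entries = sorted(children.items())
        for i, (child, sub) in enumerate(entries):
            out += render_tree(child, sub, indent + 1, is_last=(i == len(entries) - 1))
        return out

    print(render_tree("numpy", {"linalg": {"norm": {}}, "float32": {}}), end="")
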
| None: def has_canonical_name_version( - cls: type, cannonical_name: str | None, version: int | None + cls: type, cannonical_name: str | None, version: int | None, ) -> bool: cls_canonical_name = getattr(cls, "__canonical_name__", None) cls_version = getattr(cls, "__version__", None) @@ -87,25 +86,25 @@ def has_canonical_name_version( def validate_cannonical_name_version( - cls: type, canonical_name: str | None, version: int | None + cls: type, canonical_name: str | None, version: int | None, ) -> tuple[str, int]: cls_canonical_name = getattr(cls, "__canonical_name__", None) cls_version = getattr(cls, "__version__", None) if cls_canonical_name and canonical_name: raise ValueError( - "Cannot specify both __canonical_name__ attribute and cannonical_name argument." + "Cannot specify both __canonical_name__ attribute and cannonical_name argument.", ) if cls_version and version: raise ValueError( - "Cannot specify both __version__ attribute and version argument." + "Cannot specify both __version__ attribute and version argument.", ) if cls_canonical_name is None and canonical_name is None: raise ValueError( - "Must specify either __canonical_name__ attribute or cannonical_name argument." + "Must specify either __canonical_name__ attribute or cannonical_name argument.", ) if cls_version is None and version is None: raise ValueError( - "Must specify either __version__ attribute or version argument." + "Must specify either __version__ attribute or version argument.", ) canonical_name = canonical_name or cls_canonical_name @@ -114,14 +113,12 @@ def validate_cannonical_name_version( def skip_unregistered_class( - cls: type, canonical_name: str | None, version: str | None + cls: type, canonical_name: str | None, version: str | None, ) -> bool: - """ - Used to gather all classes that are missing canonical_name and version for development. + """Used to gather all classes that are missing canonical_name and version for development. Returns True if the class should be skipped, False otherwise. """ - search_unregistered_classes = ( os.getenv("SYFT_SEARCH_MISSING_CANONICAL_NAME", False) == "true" ) @@ -155,7 +152,7 @@ def recursive_serde_register( return canonical_name, version = validate_cannonical_name_version( - cls, canonical_name, version + cls, canonical_name, version, ) nonrecursive = bool(serialize and deserialize) @@ -280,13 +277,13 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild msg = recursive_scheme.new_message() - # todo: rewrite and make sure every object has a canonical name and version + # TODO: rewrite and make sure every object has a canonical name and version canonical_name, version = SyftObjectRegistry.get_canonical_name_version(self) if not SyftObjectRegistry.has_serde_class(canonical_name, version): # third party raise Exception( - f"obj2proto: {canonical_name} version {version} not in SyftObjectRegistry" + f"obj2proto: {canonical_name} version {version} not in SyftObjectRegistry", ) msg.canonicalName = canonical_name @@ -308,7 +305,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild if nonrecursive or is_type: if serialize is None: raise Exception( - f"Cant serialize {type(self)} nonrecursive without serialize." 
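
validate_cannonical_name_version above (the `cannonical` typo predates this patch and is left intact) resolves a canonical name and version from either class attributes or explicit arguments, rejecting both-given and neither-given. Its core logic, distilled:

    def resolve(cls: type, name: str | None = None, version: int | None = None) -> tuple[str, int]:
        cls_name = getattr(cls, "__canonical_name__", None)
        cls_version = getattr(cls, "__version__", None)
        if cls_name and name:
            raise ValueError("Cannot specify both the attribute and the argument.")
        if cls_version and version:
            raise ValueError("Cannot specify both the attribute and the argument.")
        if cls_name is None and name is None:
            raise ValueError("Must specify either the attribute or the argument.")
        if cls_version is None and version is None:
            raise ValueError("Must specify either the attribute or the argument.")
        return name or cls_name, version or cls_version
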
+ f"Cant serialize {type(self)} nonrecursive without serialize.", ) chunk_bytes(self, serialize, "nonrecursiveBlob", msg) return msg @@ -331,7 +328,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild for idx, attr_name in enumerate(sorted(attribute_list)): if not hasattr(self, attr_name): raise ValueError( - f"{attr_name} on {type(self)} does not exist, serialization aborted!" + f"{attr_name} on {type(self)} does not exist, serialization aborted!", ) field_obj = getattr(self, attr_name) @@ -358,14 +355,15 @@ def rs_bytes2object(blob: bytes) -> Any: MAX_TRAVERSAL_LIMIT = 2**64 - 1 with recursive_scheme.from_bytes( - blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT + blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT, ) as msg: return rs_proto2object(msg) def map_fqns_for_backward_compatibility(fqn: str) -> str: - """for backwards compatibility with 0.8.6. Sometimes classes where moved to another file. Which is - exactly why we are implementing it differently""" + """For backwards compatibility with 0.8.6. Sometimes classes where moved to another file. Which is + exactly why we are implementing it differently + """ mapping = { "syft.service.dataset.dataset.MarkdownDescription": "syft.util.misc_objs.MarkdownDescription", "syft.service.object_search.object_migration_state.SyftObjectMigrationState": "syft.service.migration.object_migration_state.SyftObjectMigrationState", # noqa: E501 @@ -394,7 +392,7 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: # third party if not SyftObjectRegistry.has_serde_class(canonical_name, version): raise Exception( - f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" + f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry", ) # TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the @@ -420,7 +418,7 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: if nonrecursive: if deserialize is None: raise Exception( - f"Cant serialize {type(proto)} nonrecursive without serialize." 
+ f"Cant serialize {type(proto)} nonrecursive without serialize.", ) return deserialize(combine_bytes(proto.nonrecursiveBlob)) diff --git a/packages/syft/src/syft/serde/recursive_primitives.py b/packages/syft/src/syft/serde/recursive_primitives.py index 38d8281434d..3d217d377e2 100644 --- a/packages/syft/src/syft/serde/recursive_primitives.py +++ b/packages/syft/src/syft/serde/recursive_primitives.py @@ -1,39 +1,34 @@ # stdlib -from abc import ABCMeta -from collections import OrderedDict -from collections import defaultdict -from collections.abc import Collection -from collections.abc import Iterable -from collections.abc import Mapping -from enum import Enum -from enum import EnumMeta import functools import inspect import pathlib -from pathlib import PurePath import sys import tempfile -from types import MappingProxyType -from types import UnionType import typing -from typing import Any -from typing import GenericAlias -from typing import Optional -from typing import TypeVar -from typing import Union -from typing import _GenericAlias -from typing import _SpecialForm -from typing import _SpecialGenericAlias -from typing import _UnionGenericAlias -from typing import cast import weakref +from abc import ABCMeta +from collections import OrderedDict, defaultdict +from collections.abc import Collection, Iterable, Mapping +from enum import Enum, EnumMeta +from pathlib import PurePath +from types import MappingProxyType, UnionType +from typing import ( + Any, + GenericAlias, + Optional, + TypeVar, + Union, + _GenericAlias, + _SpecialForm, + _SpecialGenericAlias, + _UnionGenericAlias, + cast, +) # relative from ..types.syft_object_registry import SyftObjectRegistry from .capnp import get_capnp_schema -from .recursive import chunk_bytes -from .recursive import combine_bytes -from .recursive import recursive_serde_register +from .recursive import chunk_bytes, combine_bytes, recursive_serde_register from .util import compatible_with_large_file_writes_capnp iterable_schema = get_capnp_schema("iterable.capnp").Iterable @@ -73,7 +68,7 @@ def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection: MAX_TRAVERSAL_LIMIT = 2**64 - 1 with iterable_schema.from_bytes( - blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT + blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT, ) as msg: values = [ _deserialize(combine_bytes(element), from_bytes=True) @@ -116,14 +111,14 @@ def get_deserialized_kv_pairs(blob: bytes) -> list[Any]: pairs = [] with kv_iterable_schema.from_bytes( - blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT + blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT, ) as msg: for key, value in zip(msg.keys, msg.values): pairs.append( ( _deserialize(key, from_bytes=True), _deserialize(combine_bytes(value), from_bytes=True), - ) + ), ) return pairs @@ -176,7 +171,7 @@ def serialize_type(_type_to_serialize: type) -> bytes: # relative type_to_serialize = typing.get_origin(_type_to_serialize) or _type_to_serialize canonical_name, version = SyftObjectRegistry.get_identifier_for_type( - type_to_serialize + type_to_serialize, ) return f"{canonical_name}:{version}".encode() @@ -449,7 +444,7 @@ def recursive_serde_register_type( # former case is for instance for _GerericAlias itself or UnionGenericAlias # Latter case is true for for instance List[str], which is currently not used if (isinstance(t, type) and issubclass(t, _GenericAlias)) or issubclass( - type(t), _GenericAlias + type(t), _GenericAlias, ): recursive_serde_register( t, @@ -557,7 +552,7 @@ def deserialize_any(type_blob: bytes) -> 
type: # type: ignore version=1, ) recursive_serde_register_type( - _SpecialGenericAlias, canonical_name="_SpecialGenericAlias", version=1 + _SpecialGenericAlias, canonical_name="_SpecialGenericAlias", version=1, ) recursive_serde_register_type(GenericAlias, canonical_name="GenericAlias", version=1) diff --git a/packages/syft/src/syft/serde/serializable.py b/packages/syft/src/syft/serde/serializable.py index 9a683dbcf57..c5969b08389 100644 --- a/packages/syft/src/syft/serde/serializable.py +++ b/packages/syft/src/syft/serde/serializable.py @@ -22,10 +22,10 @@ def serializable( canonical_name: str | None = None, version: int | None = None, ) -> Callable[[T], T]: - """ - Recursively serialize attributes of the class. + """Recursively serialize attributes of the class. Args: + ---- `attrs` : List of attributes to serialize `without` : List of attributes to exclude from serialization `inherit` : Whether to inherit serializable attribute list from base class @@ -44,7 +44,9 @@ def serializable( - `inherit`, `inheritable` will not work as pydantic inherits by default Returns: + ------- Decorated class + """ def rs_decorator(cls: T) -> T: diff --git a/packages/syft/src/syft/serde/signature.py b/packages/syft/src/syft/serde/signature.py index 23b0a556fca..d7a04d7875a 100644 --- a/packages/syft/src/syft/serde/signature.py +++ b/packages/syft/src/syft/serde/signature.py @@ -1,11 +1,8 @@ # stdlib -from collections.abc import Callable import inspect -from inspect import Parameter -from inspect import Signature -from inspect import _ParameterKind -from inspect import _signature_fromstr import re +from collections.abc import Callable +from inspect import Parameter, Signature, _ParameterKind, _signature_fromstr # relative from .deserialize import _deserialize @@ -74,7 +71,7 @@ def signature_remove_self(signature: Signature) -> Signature: params = dict(signature.parameters) params.pop("self", None) return Signature( - list(params.values()), return_annotation=signature.return_annotation + list(params.values()), return_annotation=signature.return_annotation, ) @@ -82,7 +79,7 @@ def signature_remove_context(signature: Signature) -> Signature: params = dict(signature.parameters) params.pop("context", None) return Signature( - list(params.values()), return_annotation=signature.return_annotation + list(params.values()), return_annotation=signature.return_annotation, ) @@ -109,7 +106,7 @@ def get_str_signature_from_docstring(doc: str, callable_name: str) -> str | None if re.search(rf"(?<={params[-1]})\],", signature_str): signature_str = signature_str.replace( - f"[{params[-1]}],", params[-1] + f"[{params[-1]}],", params[-1], ) else: signature_str = signature_str.replace( diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 6a46f789c26..aa7dc6bcb11 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -1,41 +1,37 @@ # stdlib -from datetime import date -from datetime import datetime -from datetime import time import functools +from datetime import date, datetime, time from importlib.util import find_spec from io import BytesIO -# third party -from dateutil import parser -from nacl.signing import SigningKey -from nacl.signing import VerifyKey import numpy as np -from pandas import DataFrame -from pandas import Series -from pandas._libs.tslibs.timestamps import Timestamp import pyarrow as pa import pyarrow.parquet as pq import pydantic + +# third party +from dateutil import parser +from nacl.signing import 
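
signature_remove_self and signature_remove_context above both rebuild an inspect.Signature without one parameter. The standard library supports this directly through Signature.replace; a sketch:

    import inspect

    def drop_param(fn, name: str) -> inspect.Signature:
        sig = inspect.signature(fn)
        params = [p for p in sig.parameters.values() if p.name != name]
        return sig.replace(parameters=params)

    class Service:
        def handler(self, context: dict, x: int) -> int: ...

    print(drop_param(Service.handler, "self"))  # (context: dict, x: int) -> int
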
SigningKey, VerifyKey +from pandas import DataFrame, Series +from pandas._libs.tslibs.timestamps import Timestamp from pydantic._internal._model_construction import ModelMetaclass from pymongo.collection import Collection -from result import Err -from result import Ok +from result import Err, Ok # relative from ..types.dicttuple import DictTuple from ..types.dicttuple import _Meta as _DictTupleMetaClass -from ..types.syft_metaclass import EmptyType -from ..types.syft_metaclass import PartialModelMetaclass -from .array import numpy_deserialize -from .array import numpy_serialize +from ..types.syft_metaclass import EmptyType, PartialModelMetaclass +from .array import numpy_deserialize, numpy_serialize from .deserialize import _deserialize as deserialize -from .recursive_primitives import _serialize_kv_pairs -from .recursive_primitives import deserialize_kv -from .recursive_primitives import deserialize_type -from .recursive_primitives import recursive_serde_register -from .recursive_primitives import recursive_serde_register_type -from .recursive_primitives import serialize_type +from .recursive_primitives import ( + _serialize_kv_pairs, + deserialize_kv, + deserialize_type, + recursive_serde_register, + recursive_serde_register_type, + serialize_type, +) from .serialize import _serialize as serialize recursive_serde_register( @@ -57,10 +53,10 @@ # result Ok and Err recursive_serde_register( - Ok, serialize_attrs=["_value"], canonical_name="result_Ok", version=1 + Ok, serialize_attrs=["_value"], canonical_name="result_Ok", version=1, ) recursive_serde_register( - Err, serialize_attrs=["_value"], canonical_name="result_Err", version=1 + Err, serialize_attrs=["_value"], canonical_name="result_Err", version=1, ) # exceptions @@ -68,7 +64,7 @@ # mongo collection recursive_serde_register_type( - Collection, canonical_name="pymongo_collection", version=1 + Collection, canonical_name="pymongo_collection", version=1, ) @@ -180,10 +176,10 @@ def _serialize_dicttuple(x: DictTuple) -> bytes: recursive_serde_register_type( - ModelMetaclass, canonical_name="pydantic_model_metaclass", version=1 + ModelMetaclass, canonical_name="pydantic_model_metaclass", version=1, ) recursive_serde_register_type( - PartialModelMetaclass, canonical_name="partial_model_metaclass", version=1 + PartialModelMetaclass, canonical_name="partial_model_metaclass", version=1, ) @@ -216,10 +212,10 @@ def serialize_bytes_io(io: BytesIO) -> bytes: from torch._C import _TensorMeta recursive_serde_register_type( - _TensorMeta, canonical_name="torch_tensor_meta", version=1 + _TensorMeta, canonical_name="torch_tensor_meta", version=1, ) recursive_serde_register_type( - torch.Tensor, canonical_name="torch_tensor", version=1 + torch.Tensor, canonical_name="torch_tensor", version=1, ) def torch_serialize(tensor: torch.Tensor) -> bytes: diff --git a/packages/syft/src/syft/server/credentials.py b/packages/syft/src/syft/server/credentials.py index 4b811aa0db5..79155d650db 100644 --- a/packages/syft/src/syft/server/credentials.py +++ b/packages/syft/src/syft/server/credentials.py @@ -6,8 +6,7 @@ # third party from nacl.encoding import HexEncoder -from nacl.signing import SigningKey -from nacl.signing import VerifyKey +from nacl.signing import SigningKey, VerifyKey from pydantic import field_validator # relative @@ -37,7 +36,7 @@ def from_string(key_str: str) -> SyftVerifyKey: def verify(self) -> str: return str(self) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SyftVerifyKey): return 
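
The credentials.py hunks retype `other` in __eq__ from Any to object, matching the signature object.__eq__ declares; the isinstance guard then narrows it before comparing. The idiom in miniature (VerifyKeyLike is a stand-in class, not Syft's):

    class VerifyKeyLike:
        def __init__(self, key: bytes) -> None:
            self.key = key

        def __eq__(self, other: object) -> bool:
            # object, not Any: matches the base class signature and keeps type checkers honest
            if not isinstance(other, VerifyKeyLike):
                return False
            return self.key == other.key

        def __hash__(self) -> int:
            return hash(self.key)

    assert VerifyKeyLike(b"k") == VerifyKeyLike(b"k")
    assert VerifyKeyLike(b"k") != "k"
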
False return self.verify_key == other.verify_key @@ -88,7 +87,7 @@ def verify(self) -> str: def __hash__(self) -> int: return hash(self.signing_key) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SyftSigningKey): return False return self.signing_key == other.signing_key diff --git a/packages/syft/src/syft/server/routes.py b/packages/syft/src/syft/server/routes.py index e4d6906ae7f..4c56908306c 100644 --- a/packages/syft/src/syft/server/routes.py +++ b/packages/syft/src/syft/server/routes.py @@ -1,21 +1,16 @@ # stdlib import base64 import binascii -from collections.abc import AsyncGenerator import logging +from collections.abc import AsyncGenerator from typing import Annotated +import requests + # third party -from fastapi import APIRouter -from fastapi import Body -from fastapi import Depends -from fastapi import HTTPException -from fastapi import Request -from fastapi import Response -from fastapi.responses import JSONResponse -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import ValidationError -import requests # relative from ..abstract_server import AbstractServer @@ -23,17 +18,14 @@ from ..protocol.data_protocol import PROTOCOL_TYPE from ..serde.deserialize import _deserialize as deserialize from ..serde.serialize import _serialize as serialize -from ..service.context import ServerServiceContext -from ..service.context import UnauthedServiceContext +from ..service.context import ServerServiceContext, UnauthedServiceContext from ..service.metadata.server_metadata import ServerMetadataJSON from ..service.response import SyftError -from ..service.user.user import UserCreate -from ..service.user.user import UserPrivateKey +from ..service.user.user import UserCreate, UserPrivateKey from ..service.user.user_service import UserService from ..types.uid import UID from ..util.telemetry import TRACE_MODE -from .credentials import SyftVerifyKey -from .credentials import UserLoginCredentials +from .credentials import SyftVerifyKey, UserLoginCredentials from .worker import Worker logger = logging.getLogger(__name__) @@ -121,8 +113,7 @@ async def stream_upload(peer_uid: str, url_path: str, request: Request) -> Respo response_class=JSONResponse, ) def root() -> dict[str, str]: - """ - Currently, all service backends must satisfy either of the following requirements to + """Currently, all service backends must satisfy either of the following requirements to pass the HTTP health checks sent to it from the GCE loadbalancer: 1. Respond with a 200 on '/'. The content does not matter. 2. Expose an arbitrary url as a readiness probe on the pods backing the Service. 
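
The root() docstring above states the contract plainly: the load balancer only needs a 200 on '/'. A minimal FastAPI sketch of such a health endpoint, as a standalone app rather than Syft's router wiring:

    from fastapi import FastAPI
    from fastapi.responses import JSONResponse

    app = FastAPI()

    @app.get("/", response_class=JSONResponse)
    def root() -> dict[str, str]:
        # body content is irrelevant to the health check; the 200 status is what matters
        return {"status": "ok"}
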
@@ -143,11 +134,11 @@ def syft_metadata_capnp() -> Response: ) def handle_syft_new_api( - user_verify_key: SyftVerifyKey, communication_protocol: PROTOCOL_TYPE + user_verify_key: SyftVerifyKey, communication_protocol: PROTOCOL_TYPE, ) -> Response: return Response( serialize( - worker.get_api(user_verify_key, communication_protocol), to_bytes=True + worker.get_api(user_verify_key, communication_protocol), to_bytes=True, ), media_type="application/octet-stream", ) @@ -155,7 +146,7 @@ def handle_syft_new_api( # get the SyftAPI object @router.get("/api") def syft_new_api( - request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE + request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE, ) -> Response: user_verify_key: SyftVerifyKey = SyftVerifyKey.from_string(verify_key) if TRACE_MODE: @@ -179,7 +170,7 @@ def handle_new_api_call(data: bytes) -> Response: # make a request to the SyftAPI @router.post("/api_call") def syft_new_api_call( - request: Request, data: Annotated[bytes, Depends(get_body)] + request: Request, data: Annotated[bytes, Depends(get_body)], ) -> Response: if TRACE_MODE: with trace.get_tracer(syft_new_api_call.__module__).start_as_current_span( @@ -199,7 +190,7 @@ def handle_login(email: str, password: str, server: AbstractServer) -> Response: method = server.get_service_method(UserService.exchange_credentials) context = UnauthedServiceContext( - server=server, login_credentials=login_credentials + server=server, login_credentials=login_credentials, ) result = method(context=context) @@ -230,7 +221,7 @@ def handle_register(data: bytes, server: AbstractServer) -> Response: if isinstance(result, SyftError): logger.error( - f"Register Error: {result.message}. user={user_create.model_dump()}" + f"Register Error: {result.message}. 
user={user_create.model_dump()}", ) response = SyftError(message=f"{result.message}") else: @@ -260,7 +251,7 @@ def login( @router.post("/register", name="register", status_code=200) def register( - request: Request, data: Annotated[bytes, Depends(get_body)] + request: Request, data: Annotated[bytes, Depends(get_body)], ) -> Response: if TRACE_MODE: with trace.get_tracer(register.__module__).start_as_current_span( diff --git a/packages/syft/src/syft/server/run.py b/packages/syft/src/syft/server/run.py index 9fc4878a983..b02e738c276 100644 --- a/packages/syft/src/syft/server/run.py +++ b/packages/syft/src/syft/server/run.py @@ -2,8 +2,7 @@ import argparse # relative -from ..orchestra import Orchestra -from ..orchestra import ServerHandle +from ..orchestra import Orchestra, ServerHandle def str_to_bool(bool_str: str | None) -> bool: @@ -18,7 +17,7 @@ def run() -> ServerHandle | None: parser = argparse.ArgumentParser() parser.add_argument("command", help="command: launch", type=str, default="none") parser.add_argument( - "--name", help="server name", type=str, default="syft-server", dest="name" + "--name", help="server name", type=str, default="syft-server", dest="name", ) parser.add_argument( "--server-type", @@ -36,7 +35,7 @@ def run() -> ServerHandle | None: ) parser.add_argument( - "--port", help="port for binding", type=int, default=8080, dest="port" + "--port", help="port for binding", type=int, default=8080, dest="port", ) parser.add_argument( "--dev-mode", diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 9284f9fb4c0..fb9092b449c 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -1,128 +1,133 @@ # future from __future__ import annotations -# stdlib -from collections import OrderedDict -from collections.abc import Callable -from datetime import MINYEAR -from datetime import datetime -from functools import partial import hashlib import json import logging import os -from pathlib import Path import subprocess # nosec import sys -from time import sleep import traceback -from typing import Any -from typing import cast + +# stdlib +from collections import OrderedDict +from collections.abc import Callable +from datetime import MINYEAR, datetime +from functools import partial +from pathlib import Path +from time import sleep +from typing import Any, cast # third party from nacl.signing import SigningKey -from result import Err -from result import Result +from result import Err, Result # relative from .. 
import __version__ -from ..abstract_server import AbstractServer -from ..abstract_server import ServerSideType -from ..abstract_server import ServerType -from ..client.api import SignedSyftAPICall -from ..client.api import SyftAPI -from ..client.api import SyftAPICall -from ..client.api import SyftAPIData -from ..client.api import debox_signed_syftapicall_response +from ..abstract_server import AbstractServer, ServerSideType, ServerType +from ..client.api import ( + SignedSyftAPICall, + SyftAPI, + SyftAPICall, + SyftAPIData, + debox_signed_syftapicall_response, +) from ..client.client import SyftClient from ..exceptions.exception import PySyftException -from ..protocol.data_protocol import PROTOCOL_TYPE -from ..protocol.data_protocol import get_data_protocol -from ..service.action.action_object import Action -from ..service.action.action_object import ActionObject -from ..service.action.action_store import ActionStore -from ..service.action.action_store import DictActionStore -from ..service.action.action_store import MongoActionStore -from ..service.action.action_store import SQLiteActionStore +from ..protocol.data_protocol import PROTOCOL_TYPE, get_data_protocol +from ..service.action.action_object import Action, ActionObject +from ..service.action.action_store import ( + ActionStore, + DictActionStore, + MongoActionStore, + SQLiteActionStore, +) from ..service.blob_storage.service import BlobStorageService from ..service.code.user_code_service import UserCodeService from ..service.code.user_code_stash import UserCodeStash -from ..service.context import AuthedServiceContext -from ..service.context import ServerServiceContext -from ..service.context import UnauthedServiceContext -from ..service.context import UserLoginCredentials -from ..service.job.job_stash import Job -from ..service.job.job_stash import JobStash -from ..service.job.job_stash import JobStatus -from ..service.job.job_stash import JobType +from ..service.context import ( + AuthedServiceContext, + ServerServiceContext, + UnauthedServiceContext, + UserLoginCredentials, +) +from ..service.job.job_stash import Job, JobStash, JobStatus, JobType from ..service.metadata.server_metadata import ServerMetadata from ..service.network.network_service import NetworkService from ..service.network.utils import PeerHealthCheckTask from ..service.notifier.notifier_service import NotifierService -from ..service.queue.base_queue import AbstractMessageHandler -from ..service.queue.base_queue import QueueConsumer -from ..service.queue.base_queue import QueueProducer -from ..service.queue.queue import APICallMessageHandler -from ..service.queue.queue import QueueManager -from ..service.queue.queue_stash import APIEndpointQueueItem -from ..service.queue.queue_stash import ActionQueueItem -from ..service.queue.queue_stash import QueueItem -from ..service.queue.queue_stash import QueueStash -from ..service.queue.zmq_queue import QueueConfig -from ..service.queue.zmq_queue import ZMQClientConfig -from ..service.queue.zmq_queue import ZMQQueueConfig +from ..service.queue.base_queue import ( + AbstractMessageHandler, + QueueConsumer, + QueueProducer, +) +from ..service.queue.queue import APICallMessageHandler, QueueManager +from ..service.queue.queue_stash import ( + ActionQueueItem, + APIEndpointQueueItem, + QueueItem, + QueueStash, +) +from ..service.queue.zmq_queue import QueueConfig, ZMQClientConfig, ZMQQueueConfig from ..service.response import SyftError -from ..service.service import AbstractService -from ..service.service import 
ServiceConfigRegistry -from ..service.service import UserServiceConfigRegistry -from ..service.settings.settings import ServerSettings -from ..service.settings.settings import ServerSettingsUpdate +from ..service.service import ( + AbstractService, + ServiceConfigRegistry, + UserServiceConfigRegistry, +) +from ..service.settings.settings import ServerSettings, ServerSettingsUpdate from ..service.settings.settings_stash import SettingsStash -from ..service.user.user import User -from ..service.user.user import UserCreate -from ..service.user.user import UserView +from ..service.user.user import User, UserCreate, UserView from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService from ..service.user.user_stash import UserStash -from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG -from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME -from ..service.worker.utils import create_default_image +from ..service.worker.utils import ( + DEFAULT_WORKER_IMAGE_TAG, + DEFAULT_WORKER_POOL_NAME, + create_default_image, +) from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool import WorkerPool from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_pool_stash import SyftWorkerPoolStash from ..service.worker.worker_stash import WorkerStash from ..store.blob_storage import BlobStorageConfig -from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig -from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig +from ..store.blob_storage.on_disk import ( + OnDiskBlobStorageClientConfig, + OnDiskBlobStorageConfig, +) from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig from ..store.linked_obj import LinkedObject from ..store.mongo_document_store import MongoStoreConfig -from ..store.sqlite_document_store import SQLiteStoreClientConfig -from ..store.sqlite_document_store import SQLiteStoreConfig +from ..store.sqlite_document_store import SQLiteStoreClientConfig, SQLiteStoreConfig from ..types.datetime import DATETIME_FORMAT from ..types.syft_metaclass import Empty -from ..types.syft_object import Context -from ..types.syft_object import PartialSyftObject -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftObject +from ..types.syft_object import ( + SYFT_OBJECT_VERSION_1, + Context, + PartialSyftObject, + SyftObject, +) from ..types.uid import UID from ..util.experimental_flags import flags from ..util.telemetry import instrument -from ..util.util import get_dev_mode -from ..util.util import get_env -from ..util.util import get_queue_address -from ..util.util import random_name -from ..util.util import str_to_bool -from ..util.util import thread_ident -from .credentials import SyftSigningKey -from .credentials import SyftVerifyKey +from ..util.util import ( + get_dev_mode, + get_env, + get_queue_address, + random_name, + str_to_bool, + thread_ident, +) +from .credentials import SyftSigningKey, SyftVerifyKey from .service_registry import ServiceRegistry -from .utils import get_named_server_uid -from .utils import get_temp_dir_for_server -from .utils import remove_temp_dir_for_server +from .utils import ( + get_named_server_uid, + get_temp_dir_for_server, + remove_temp_dir_for_server, +) from .worker_settings import WorkerSettings logger = logging.getLogger(__name__) @@ -200,8 +205,8 @@ def 
get_default_bucket_name() -> str: def get_default_worker_pool_count(server: Server) -> int: return int( get_env( - "DEFAULT_WORKER_POOL_COUNT", server.queue_config.client_config.n_consumers - ) + "DEFAULT_WORKER_POOL_COUNT", server.queue_config.client_config.n_consumers, + ), ) @@ -460,14 +465,14 @@ def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig: client_config=SQLiteStoreClientConfig( filename=file_name, path=path, - ) + ), ) return DictStoreConfig() def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: if config is None: client_config = OnDiskBlobStorageClientConfig( - base_directory=self.get_temp_dir("blob") + base_directory=self.get_temp_dir("blob"), ) config_ = OnDiskBlobStorageConfig( client_config=client_config, @@ -484,7 +489,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: if isinstance(config, SeaweedFSConfig) and self.signing_key: blob_storage_service = self.get_service(BlobStorageService) remote_profiles = blob_storage_service.remote_profile_stash.get_all( - credentials=self.signing_key.verify_key, has_permission=True + credentials=self.signing_key.verify_key, has_permission=True, ).ok() for remote_profile in remote_profiles: self.blob_store_config.client_config.remote_profiles[ @@ -499,7 +504,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None: ) logger.debug( f"Minimum object size to be saved to the blob storage: " - f"{self.blob_store_config.min_blob_size} (MB)." + f"{self.blob_store_config.min_blob_size} (MB).", ) def run_peer_health_checks(self, context: AuthedServiceContext) -> None: @@ -540,7 +545,7 @@ def create_queue_config( queue_config_ = queue_config elif queue_port is not None or n_consumers > 0 or create_producer: if not create_producer and queue_port is None: - logger.warn("No queue port defined to bind consumers.") + logger.warning("No queue port defined to bind consumers.") queue_config_ = ZMQQueueConfig( client_config=ZMQClientConfig( create_producer=create_producer, @@ -557,7 +562,7 @@ def create_queue_config( def init_queue_manager(self, queue_config: QueueConfig) -> None: MessageHandlers = [APICallMessageHandler] if self.is_subprocess: - return None + return self.queue_manager = QueueManager(config=queue_config) for message_handler in MessageHandlers: @@ -596,7 +601,7 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None: # Create consumer for given worker pool syft_worker_uid = get_syft_worker_uid() logger.info( - f"Running as consumer with uid={syft_worker_uid} service={service_name}" + f"Running as consumer with uid={syft_worker_uid} service={service_name}", ) if syft_worker_uid: @@ -704,7 +709,7 @@ def root_client(self) -> SyftClient: return root_client def _find_klasses_pending_for_migration( - self, object_types: list[SyftObject] + self, object_types: list[SyftObject], ) -> list[SyftObject]: context = AuthedServiceContext( server=self, @@ -722,7 +727,7 @@ def _find_klasses_pending_for_migration( migration_state = migration_state_service.get_state(context, canonical_name) if isinstance(migration_state, SyftError): raise Exception( - f"Failed to get migration state for {canonical_name}. Error: {migration_state}" + f"Failed to get migration state for {canonical_name}. 
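
Besides the comma fixes, this stretch swaps logger.warn for logger.warning (the former has long been a deprecated alias) and shows get_default_worker_pool_count coercing an environment override to int. Assuming get_env behaves like os.getenv with a default, that helper reduces to:

    import os

    def get_env_int(name: str, default: int) -> int:
        raw = os.getenv(name)
        return int(raw) if raw is not None else default

    # DEFAULT_WORKER_POOL_COUNT=4 in the environment would override a config default of 2
    print(get_env_int("DEFAULT_WORKER_POOL_COUNT", 2))
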
Error: {migration_state}", ) if ( migration_state is not None @@ -739,7 +744,7 @@ def _find_klasses_pending_for_migration( return klasses_to_be_migrated def find_and_migrate_data( - self, document_store_object_types: list[type[SyftObject]] | None = None + self, document_store_object_types: list[type[SyftObject]] | None = None, ) -> None: context = AuthedServiceContext( server=self, @@ -776,7 +781,7 @@ def get_guest_client(self, verbose: bool = True) -> SyftClient: return client_type guest_client = client_type( - connection=connection, credentials=SyftSigningKey.generate() + connection=connection, credentials=SyftSigningKey.generate(), ) if guest_client.api.refresh_api_callback is not None: guest_client.api.refresh_api_callback() @@ -792,10 +797,10 @@ def __repr__(self) -> str: def post_init(self) -> None: context = AuthedServiceContext( - server=self, credentials=self.verify_key, role=ServiceRole.ADMIN + server=self, credentials=self.verify_key, role=ServiceRole.ADMIN, ) AuthServerContextRegistry.set_server_context( - server_uid=self.id, user_verify_key=self.verify_key, context=context + server_uid=self.id, user_verify_key=self.verify_key, context=context, ) if "usercodeservice" in self.service_path_map: @@ -890,15 +895,13 @@ def _get_service_method_from_path(self, path: str) -> Callable: return getattr(service_obj, method_name) def get_temp_dir(self, dir_name: str = "") -> Path: - """ - Get a temporary directory unique to the server. + """Get a temporary directory unique to the server. Provide all dbs, blob dirs, and locks using this directory. """ return get_temp_dir_for_server(self.id, dir_name) def remove_temp_dir(self) -> None: - """ - Remove the temporary directory for this server. + """Remove the temporary directory for this server. """ remove_temp_dir_for_server(self.id) @@ -920,7 +923,7 @@ def settings(self) -> ServerSettings: settings = settings_stash.get_all(self.signing_key.verify_key) if settings.is_err(): raise ValueError( - f"Cannot get server settings for '{self.name}'. Error: {settings.err()}" + f"Cannot get server settings for '{self.name}'. 
Error: {settings.err()}", ) if settings.is_ok() and len(settings.ok()) > 0: settings = settings.ok()[0] @@ -974,7 +977,7 @@ def verify_key(self) -> SyftVerifyKey: def __hash__(self) -> int: return hash(self.id) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return False @@ -984,7 +987,7 @@ def __eq__(self, other: Any) -> bool: return True def await_future( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> QueueItem | None | SyftError: # stdlib @@ -1002,7 +1005,7 @@ def await_future( sleep(0.1) def resolve_future( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> QueueItem | None | SyftError: result = self.queue_stash.pop_on_complete(credentials, uid) @@ -1016,7 +1019,7 @@ def resolve_future( return result.err() def forward_message( - self, api_call: SyftAPICall | SignedSyftAPICall + self, api_call: SyftAPICall | SignedSyftAPICall, ) -> Result | QueueItem | SyftObject | SyftError | Any: server_uid = api_call.message.server_uid if "networkservice" not in self.service_path_map: @@ -1024,7 +1027,7 @@ def forward_message( message=( "Server has no network service so we can't " f"forward this message to {server_uid}" - ) + ), ) client = None @@ -1042,14 +1045,14 @@ def forward_message( client = self.peer_client_cache[peer_cache_key] else: context = AuthedServiceContext( - server=self, credentials=api_call.credentials + server=self, credentials=api_call.credentials, ) client = peer.client_with_context(context=context) if client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{peer.id}. Error: {client.err()}" + f"{peer.id}. Error: {client.err()}", ) client = client.ok() @@ -1081,7 +1084,7 @@ def forward_message( def get_role_for_credentials(self, credentials: SyftVerifyKey) -> ServiceRole: role = self.get_service("userservice").get_role_for_credentials( - credentials=credentials + credentials=credentials, ) return role @@ -1093,7 +1096,7 @@ def handle_api_call( ) -> Result[SignedSyftAPICall, Err]: # Get the result result = self.handle_api_call_with_unsigned_result( - api_call, job_id=job_id, check_call_location=check_call_location + api_call, job_id=job_id, check_call_location=check_call_location, ) # Sign the result signed_result = SyftAPIData(data=result).sign(self.signing_key) @@ -1108,18 +1111,17 @@ def handle_api_call_with_unsigned_result( ) -> Result | QueueItem | SyftObject | SyftError: if self.required_signed_calls and isinstance(api_call, SyftAPICall): return SyftError( - message=f"You sent a {type(api_call)}. This server requires SignedSyftAPICall." + message=f"You sent a {type(api_call)}. 
This server requires SignedSyftAPICall.", ) - else: - if not api_call.is_valid: - return SyftError(message="Your message signature is invalid") + elif not api_call.is_valid: + return SyftError(message="Your message signature is invalid") if api_call.message.server_uid != self.id and check_call_location: return self.forward_message(api_call=api_call) if api_call.message.path == "queue": return self.resolve_future( - credentials=api_call.credentials, uid=api_call.message.kwargs["uid"] + credentials=api_call.credentials, uid=api_call.message.kwargs["uid"], ) if api_call.message.path == "metadata": @@ -1148,11 +1150,11 @@ def handle_api_call_with_unsigned_result( if ServiceConfigRegistry.path_exists(api_call.path): return SyftError( message=f"As a `{role}`, " - f"you have no access to: {api_call.path}" + f"you have no access to: {api_call.path}", ) else: return SyftError( - message=f"API call not in registered services: {api_call.path}" + message=f"API call not in registered services: {api_call.path}", ) _private_api_path = user_config_registry.private_path_for(api_call.path) @@ -1164,7 +1166,7 @@ def handle_api_call_with_unsigned_result( return e.handle() except Exception: result = SyftError( - message=f"Exception calling {api_call.path}. {traceback.format_exc()}" + message=f"Exception calling {api_call.path}. {traceback.format_exc()}", ) else: return self.add_api_call_to_queue(api_call) @@ -1222,7 +1224,7 @@ def add_api_endpoint_execution_to_queue( ) def get_worker_pool_ref_by_name( - self, credentials: SyftVerifyKey, worker_pool_name: str | None = None + self, credentials: SyftVerifyKey, worker_pool_name: str | None = None, ) -> LinkedObject | SyftError: # If worker pool id is not set, then use default worker pool # Else, get the worker pool for given uid @@ -1257,7 +1259,7 @@ def add_action_to_queue( # Extract worker pool id from user code if action.user_code_id is not None: result = self.user_code_stash.get_by_uid( - credentials=credentials, uid=action.user_code_id + credentials=credentials, uid=action.user_code_id, ) # If result is Ok, then user code object exists @@ -1266,7 +1268,7 @@ def add_action_to_queue( worker_pool_name = user_code.worker_pool_name worker_pool_ref = self.get_worker_pool_ref_by_name( - credentials, worker_pool_name + credentials, worker_pool_name, ) if isinstance(worker_pool_ref, SyftError): return worker_pool_ref @@ -1283,7 +1285,7 @@ def add_action_to_queue( worker_pool=worker_pool_ref, # set worker pool reference as part of queue item ) user_id = self.get_service("UserService").get_user_id_for_credentials( - credentials + credentials, ) return self.add_queueitem_to_queue( @@ -1369,11 +1371,11 @@ def _sort_jobs(self, jobs: list[Job]) -> list[Job]: return jobs def _get_existing_user_code_jobs( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, user_code_id: UID, ) -> list[Job] | SyftError: job_service = self.get_service("jobservice") jobs = job_service.get_by_user_code_id( - context=context, user_code_id=user_code_id + context=context, user_code_id=user_code_id, ) if isinstance(jobs, SyftError): @@ -1391,11 +1393,11 @@ def _is_usercode_call_on_owned_kwargs( return False user_code_service = self.get_service("usercodeservice") return user_code_service.is_execution_on_owned_args( - context, user_code_id, api_call.kwargs + context, user_code_id, api_call.kwargs, ) def add_api_call_to_queue( - self, api_call: SyftAPICall, parent_job_id: UID | None = None + self, api_call: SyftAPICall, parent_job_id: UID | None = None, ) -> 
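
The else/if collapse in handle_api_call_with_unsigned_result above is ruff's collapsible-else-if fix: when an else branch holds nothing but another if, elif says the same thing one indent shallower, with behavior unchanged. A generic illustration:

    def gate(required_signed: bool, is_signed_call: bool, is_valid: bool) -> str:
        if required_signed and not is_signed_call:
            return "error: this server requires signed calls"
        elif not is_valid:  # was nested as `else:` followed by `if not is_valid:`
            return "error: invalid message signature"
        return "ok"

    assert gate(False, False, True) == "ok"
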
Job | SyftError: unsigned_call = api_call if isinstance(api_call, SignedSyftAPICall): @@ -1426,7 +1428,7 @@ def add_api_call_to_queue( user.mock_execution_permission or context.role == ServiceRole.ADMIN ) is_usercode_call_on_owned_kwargs = self._is_usercode_call_on_owned_kwargs( - context, unsigned_call, user_code_id + context, unsigned_call, user_code_id, ) # Low side does not execute jobs, unless this is a mock execution if ( @@ -1442,12 +1444,12 @@ def add_api_call_to_queue( from ..util.util import prompt_warning_message prompt_warning_message( - "There are existing jobs for this user code, returning the latest one" + "There are existing jobs for this user code, returning the latest one", ) return existing_jobs[-1] else: return SyftError( - message="Please wait for the admin to allow the execution of this code" + message="Please wait for the admin to allow the execution of this code", ) elif ( @@ -1455,11 +1457,11 @@ def add_api_call_to_queue( and not is_execution_on_owned_kwargs_allowed ): return SyftError( - message="You do not have the permissions for mock execution, please contact the admin" + message="You do not have the permissions for mock execution, please contact the admin", ) return self.add_action_to_queue( - action, api_call.credentials, parent_job_id=parent_job_id + action, api_call.credentials, parent_job_id=parent_job_id, ) else: @@ -1508,7 +1510,7 @@ def get_default_worker_pool(self) -> WorkerPool | None | SyftError: def get_worker_pool_by_name(self, name: str) -> WorkerPool | None | SyftError: result = self.pool_stash.get_by_name( - credentials=self.verify_key, pool_name=name + credentials=self.verify_key, pool_name=name, ) if result.is_err(): return SyftError(message=f"{result.err()}") @@ -1527,13 +1529,13 @@ def get_api( ) def get_method_with_context( - self, function: Callable, context: ServerServiceContext + self, function: Callable, context: ServerServiceContext, ) -> Callable: method = self.get_service_method(function) return partial(method, context=context) def get_unauthed_context( - self, login_credentials: UserLoginCredentials + self, login_credentials: UserLoginCredentials, ) -> ServerServiceContext: return UnauthedServiceContext(server=self, login_credentials=login_credentials) @@ -1542,7 +1544,7 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings | None: settings_stash = SettingsStash(store=self.document_store) if self.signing_key is None: logger.debug( - "create_initial_settings failed as there is no signing key" + "create_initial_settings failed as there is no signing key", ) return None settings_exists = settings_stash.get_all(self.signing_key.verify_key).ok() @@ -1551,15 +1553,15 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings | None: if server_settings.__version__ != ServerSettings.__version__: context = Context() server_settings = server_settings.migrate_to( - ServerSettings.__version__, context + ServerSettings.__version__, context, ) res = settings_stash.delete_by_uid( - self.signing_key.verify_key, server_settings.id + self.signing_key.verify_key, server_settings.id, ) if res.is_err(): raise Exception(res.value) res = settings_stash.set( - self.signing_key.verify_key, server_settings + self.signing_key.verify_key, server_settings, ) if res.is_err(): raise Exception(res.value) @@ -1588,7 +1590,7 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings | None: notifications_enabled=False, ) result = settings_stash.set( - credentials=self.signing_key.verify_key, settings=new_settings + 
credentials=self.signing_key.verify_key, settings=new_settings, ) if result.is_ok(): return result.ok() @@ -1607,7 +1609,7 @@ def create_admin_new( try: user_stash = UserStash(store=server.document_store) row_exists = user_stash.get_by_email( - credentials=server.signing_key.verify_key, email=email + credentials=server.signing_key.verify_key, email=email, ).ok() if row_exists: return None @@ -1695,7 +1697,7 @@ def create_default_worker_pool(server: Server) -> SyftError | None: if isinstance(default_worker_pool, SyftError): logger.error( f"Failed to get default worker pool {default_pool_name}. " - f"Error: {default_worker_pool.message}" + f"Error: {default_worker_pool.message}", ) return default_worker_pool @@ -1732,7 +1734,7 @@ def create_default_worker_pool(server: Server) -> SyftError | None: f"name={default_pool_name} " f"workers={worker_count} " f"image_uid={default_image.id} " - f"in_memory={server.in_memory_workers}. " + f"in_memory={server.in_memory_workers}. ", ) if default_worker_pool is None: worker_to_add_ = worker_count @@ -1748,11 +1750,11 @@ def create_default_worker_pool(server: Server) -> SyftError | None: else: # Else add a worker to existing worker pool worker_to_add_ = max(default_worker_pool.max_count, worker_count) - len( - default_worker_pool.worker_list + default_worker_pool.worker_list, ) if worker_to_add_ > 0: add_worker_method = server.get_service_method( - SyftWorkerPoolService.add_workers + SyftWorkerPoolService.add_workers, ) result = add_worker_method( context=context, @@ -1771,7 +1773,7 @@ def create_default_worker_pool(server: Server) -> SyftError | None: if container_status.error: logger.error( f"Failed to create container: Worker: {container_status.worker}," - f"Error: {container_status.error}" + f"Error: {container_status.error}", ) return None diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index d7c3555f10c..82e0da255ff 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -1,10 +1,8 @@ # stdlib -from collections.abc import Callable -from dataclasses import dataclass -from dataclasses import field import typing -from typing import Any -from typing import TYPE_CHECKING +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any # relative from ..serde.serializable import serializable @@ -82,7 +80,7 @@ class ServiceRegistry: services: list[AbstractService] = field(default_factory=list, init=False) service_path_map: dict[str, AbstractService] = field( - default_factory=dict, init=False + default_factory=dict, init=False, ) @classmethod diff --git a/packages/syft/src/syft/server/utils.py b/packages/syft/src/syft/server/utils.py index ac7425b143b..54d2116cf72 100644 --- a/packages/syft/src/syft/server/utils.py +++ b/packages/syft/src/syft/server/utils.py @@ -3,24 +3,22 @@ # stdlib import os -from pathlib import Path import shutil import tempfile +from pathlib import Path # relative from ..types.uid import UID def get_named_server_uid(name: str) -> UID: - """ - Get a unique identifier for a named server. + """Get a unique identifier for a named server. """ return UID.with_seed(name) def get_temp_dir_for_server(server_uid: UID, dir_name: str = "") -> Path: - """ - Get a temporary directory unique to the server. + """Get a temporary directory unique to the server. Provide all dbs, blob dirs, and locks using this directory. 
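
get_named_server_uid above turns a server name into a stable identifier via UID.with_seed. The closest standard-library analogue is a name-based UUID; uuid5 below is a stand-in to show the determinism property, not Syft's actual implementation:

    import uuid

    def named_uid(name: str) -> uuid.UUID:
        # deterministic: the same name always yields the same id, across runs and machines
        return uuid.uuid5(uuid.NAMESPACE_DNS, name)

    assert named_uid("syft-server") == named_uid("syft-server")
    assert named_uid("syft-server") != named_uid("other-server")
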
""" root = os.getenv("SYFT_TEMP_ROOT", "syft") @@ -30,8 +28,7 @@ def get_temp_dir_for_server(server_uid: UID, dir_name: str = "") -> Path: def remove_temp_dir_for_server(server_uid: UID) -> None: - """ - Remove the temporary directory for this server. + """Remove the temporary directory for this server. """ rootdir = get_temp_dir_for_server(server_uid) if rootdir.exists(): diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 953d19a4c2e..1fdfbe70e0a 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -1,26 +1,25 @@ # stdlib -from collections.abc import Callable import logging import multiprocessing import multiprocessing.synchronize import os -from pathlib import Path import platform import signal import subprocess # nosec import sys import time +from collections.abc import Callable +from pathlib import Path from typing import Any -# third party -from fastapi import APIRouter -from fastapi import FastAPI -from pydantic_settings import BaseSettings -from pydantic_settings import SettingsConfigDict import requests -from starlette.middleware.cors import CORSMiddleware import uvicorn +# third party +from fastapi import APIRouter, FastAPI +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.middleware.cors import CORSMiddleware + # relative from ..abstract_server import ServerSideType from ..client.client import API_PATH @@ -32,8 +31,7 @@ from .gateway import Gateway from .routes import make_routes from .server import ServerType -from .utils import get_named_server_uid -from .utils import remove_temp_dir_for_server +from .utils import get_named_server_uid, remove_temp_dir_for_server if os_name() == "macOS": # needed on MacOS to prevent [__NSCFConstantString initialize] may have been in @@ -71,7 +69,7 @@ def app_factory() -> FastAPI: } if settings.server_type not in worker_classes: raise NotImplementedError( - f"server_type: {settings.server_type} is not supported" + f"server_type: {settings.server_type} is not supported", ) worker_class = worker_classes[settings.server_type] @@ -79,7 +77,7 @@ def app_factory() -> FastAPI: if settings.dev_mode: print( f"WARN: private key is based on server name: {settings.name} in dev_mode. " - "Don't run this in production." 
+ "Don't run this in production.", ) worker = worker_class.named(**kwargs) else: diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index b1ea8a7389f..36e32bf0720 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -5,16 +5,13 @@ from typing_extensions import Self # relative -from ..abstract_server import AbstractServer -from ..abstract_server import ServerSideType -from ..abstract_server import ServerType +from ..abstract_server import AbstractServer, ServerSideType, ServerType from ..serde.serializable import serializable from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig from ..store.document_store import StoreConfig -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftObject +from ..types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ..types.uid import UID diff --git a/packages/syft/src/syft/service/action/action_data_empty.py b/packages/syft/src/syft/service/action/action_data_empty.py index 260c6f6d06b..ced798b6b3c 100644 --- a/packages/syft/src/syft/service/action/action_data_empty.py +++ b/packages/syft/src/syft/service/action/action_data_empty.py @@ -6,8 +6,7 @@ # relative from ...serde.serializable import serializable -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.uid import UID diff --git a/packages/syft/src/syft/service/action/action_endpoint.py b/packages/syft/src/syft/service/action/action_endpoint.py index f3be5322191..011aef67603 100644 --- a/packages/syft/src/syft/service/action/action_endpoint.py +++ b/packages/syft/src/syft/service/action/action_endpoint.py @@ -2,14 +2,12 @@ from __future__ import annotations # stdlib -from enum import Enum -from enum import auto +from enum import Enum, auto from typing import Any # relative from ...serde.serializable import serializable -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.uid import UID from ..context import AuthedServiceContext @@ -54,7 +52,7 @@ def private(self, *args: Any, **kwargs: Any) -> Any: ) def __call_function( - self, call_mode: EXECUTION_MODE, *args: Any, **kwargs: Any + self, call_mode: EXECUTION_MODE, *args: Any, **kwargs: Any, ) -> Any: self.context = self.__check_context() endpoint_service = self.context.server.get_service("apiservice") @@ -69,7 +67,7 @@ def __call_function( __endpoint_mode = endpoint_service.execute_server_side_endpoint_by_id return __endpoint_mode( - *args, context=self.context, endpoint_uid=self.endpoint_id, **kwargs + *args, context=self.context, endpoint_uid=self.endpoint_id, **kwargs, ) def __check_context(self) -> AuthedServiceContext: diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index bbad29396b9..1d7f99c528c 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1,65 +1,47 @@ # future from __future__ import annotations -# stdlib -from collections.abc import Callable -from collections.abc import Iterable -from enum import Enum import inspect -from io import BytesIO import logging -from 
pathlib import Path import sys import threading import time import traceback import types -from typing import Any -from typing import ClassVar -from typing import TYPE_CHECKING + +# stdlib +from collections.abc import Callable, Iterable +from enum import Enum +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar # third party -from pydantic import ConfigDict -from pydantic import Field -from pydantic import field_validator -from pydantic import model_validator -from result import Err -from result import Ok -from result import Result +from pydantic import ConfigDict, Field, field_validator, model_validator +from result import Err, Ok, Result from typing_extensions import Self # relative -from ...client.api import APIRegistry -from ...client.api import SyftAPI -from ...client.api import SyftAPICall +from ...client.api import APIRegistry, SyftAPI, SyftAPICall from ...client.client import SyftClient from ...serde.serializable import serializable from ...serde.serialize import _serialize as serialize from ...server.credentials import SyftVerifyKey from ...service.blob_storage.util import can_upload_to_blob_storage -from ...service.response import SyftError -from ...service.response import SyftSuccess -from ...service.response import SyftWarning +from ...service.response import SyftError, SyftSuccess, SyftWarning from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftBaseObject -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftBaseObject, SyftObject from ...types.syncable_object import SyncableSyftObject -from ...types.uid import LineageID -from ...types.uid import UID +from ...types.uid import UID, LineageID from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..response import SyftException from ..service import from_api_or_context -from .action_data_empty import ActionDataEmpty -from .action_data_empty import ActionDataLink -from .action_data_empty import ObjectNotReady +from .action_data_empty import ActionDataEmpty, ActionDataLink, ObjectNotReady from .action_permissions import ActionPermission -from .action_types import action_type_for_object -from .action_types import action_type_for_type -from .action_types import action_types +from .action_types import action_type_for_object, action_type_for_type, action_types logger = logging.getLogger(__name__) @@ -96,7 +78,8 @@ def repr_cls(c: Any) -> str: class Action(SyftObject): """Serializable Action object. - Parameters: + Parameters + ---------- path: str The path of the Type of the remote object. 
op: str @@ -109,6 +92,7 @@ class Action(SyftObject): `op` kwargs result_id: Optional[LineageID] Extended UID of the resulted SyftObject + """ __canonical_name__ = "Action" @@ -169,7 +153,7 @@ def syft_history_hash(self) -> int: @classmethod def syft_function_action_from_kwargs_and_id( - cls, kwargs: dict[str, Any], user_code_id: UID + cls, kwargs: dict[str, Any], user_code_id: UID, ) -> Self: kwarg_ids = {} for k, v in kwargs.items(): @@ -220,7 +204,7 @@ def repr_uid(_id: LineageID) -> str: arg_repr = ", ".join([repr_uid(x) for x in self.args]) kwargs_repr = ", ".join( - [f"{key}={repr_uid(value)}" for key, value in self.kwargs.items()] + [f"{key}={repr_uid(value)}" for key, value in self.kwargs.items()], ) _coll_repr_ = ( f"[{repr_uid(self.remote_self)}]" if self.remote_self is not None else "" @@ -428,11 +412,12 @@ class PreHookContext(SyftBaseObject): def make_action_side_effect( - context: PreHookContext, *args: Any, **kwargs: Any + context: PreHookContext, *args: Any, **kwargs: Any, ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: """Create a new action from context_op_name, and add it to the PreHookContext - Parameters: + Parameters + ---------- context: PreHookContext PreHookContext object *args: @@ -442,6 +427,7 @@ def make_action_side_effect( Returns: - Ok[[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success - Err[str] on failure + """ try: action = context.obj.syft_make_action_with_self( @@ -468,7 +454,7 @@ def set_trace_result_for_current_thread( client: SyftClient, ) -> None: cls.__result_registry__[threading.get_ident()] = TraceResult( - client=client, is_tracing=True + client=client, is_tracing=True, ) @classmethod @@ -496,7 +482,7 @@ class TraceResult(SyftBaseModel): def trace_action_side_effect( - context: PreHookContext, *args: Any, **kwargs: Any + context: PreHookContext, *args: Any, **kwargs: Any, ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: action = context.action if action is not None and TraceResultRegistry.current_thread_is_tracing(): @@ -540,7 +526,7 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: def send_action_side_effect( - context: PreHookContext, *args: Any, **kwargs: Any + context: PreHookContext, *args: Any, **kwargs: Any, ) -> Result[Ok[tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]], Err[str]]: """Create a new action from the context.op_name, and execute it on the remote server.""" try: @@ -561,17 +547,18 @@ def send_action_side_effect( context.result_twin_type = action_result.syft_twin_type except Exception as e: return Err( - f"send_action_side_effect failed with {e}\n {traceback.format_exc()}" + f"send_action_side_effect failed with {e}\n {traceback.format_exc()}", ) return Ok((context, args, kwargs)) def propagate_server_uid( - context: PreHookContext, op: str, result: Any + context: PreHookContext, op: str, result: Any, ) -> Result[Ok[Any], Err[str]]: """Patch the result to include the syft_server_uid - Parameters: + Parameters + ---------- context: PreHookContext PreHookContext object op: str @@ -581,9 +568,10 @@ def propagate_server_uid( Returns: - Ok[[result] on success - Err[str] on failure + """ if context.op_name in dont_make_side_effects or not hasattr( - context.obj, "syft_server_uid" + context.obj, "syft_server_uid", ): return Ok(result) @@ -591,7 +579,7 @@ def propagate_server_uid( syft_server_uid = getattr(context.obj, "syft_server_uid", None) if syft_server_uid is None: raise RuntimeError( - "Can't proagate server_uid because 
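# A minimal sketch of the side-effect hook contract visible above: a pre-hook takes
# the PreHookContext plus the call's args/kwargs and returns Ok((context, args,
# kwargs)) to continue, or Err(message) to abort. The hook body is hypothetical.
from result import Err, Ok, Result

def example_pre_hook(context: object, *args: object, **kwargs: object) -> Result:
    try:
        # a real hook might build an Action from context.op_name here,
        # as make_action_side_effect does above
        return Ok((context, args, kwargs))  # continue the call chain unchanged
    except Exception as e:
        return Err(f"example_pre_hook failed with {e}")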
parent doesnt have one" + "Can't proagate server_uid because parent doesnt have one", ) if op not in context.obj._syft_dont_wrap_attrs(): @@ -734,7 +722,7 @@ def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: if cmp: diff_attr = AttrDiff( - attr_name="syft_action_data", low_attr=low_data, high_attr=high_data + attr_name="syft_action_data", low_attr=low_data, high_attr=high_data, ) diff_attrs.append(diff_attr) return diff_attrs @@ -764,11 +752,11 @@ def reload_cache(self) -> SyftError | None: if blob_storage_read_method is not None: blob_retrieval_object = blob_storage_read_method( - uid=self.syft_blob_storage_entry_id + uid=self.syft_blob_storage_entry_id, ) if isinstance(blob_retrieval_object, SyftError): logger.error( - f"Could not fetch actionobject data: {blob_retrieval_object}" + f"Could not fetch actionobject data: {blob_retrieval_object}", ) return blob_retrieval_object # relative @@ -790,21 +778,20 @@ def reload_cache(self) -> SyftError | None: return None else: return SyftError( - message="Could not reload cache, could not get read method" + message="Could not reload cache, could not get read method", ) return None def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: # relative - from ...types.blob_storage import BlobFile - from ...types.blob_storage import CreateBlobStorageEntry + from ...types.blob_storage import BlobFile, CreateBlobStorageEntry if not isinstance(data, ActionDataEmpty): if isinstance(data, BlobFile): if not data.uploaded: api = APIRegistry.api_for( - self.syft_server_location, self.syft_client_verify_key + self.syft_server_location, self.syft_client_verify_key, ) data._upload_to_blobstorage_from_api(api) else: @@ -814,12 +801,12 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: syft_client_verify_key=self.syft_client_verify_key, ) if get_metadata is not None and not can_upload_to_blob_storage( - data, get_metadata() + data, get_metadata(), ): self.syft_action_saved_to_blob_store = False return SyftWarning( message=f"The action object {self.id} was not saved to " - f"the blob store but to memory cache since it is small." + f"the blob store but to memory cache since it is small.", ) serialized = serialize(data, to_bytes=True) size = sys.getsizeof(serialized) @@ -848,14 +835,14 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: blob_deposit_object.blob_storage_entry_id ) else: - logger.warn("cannot save to blob storage. allocate_method=None") + logger.warning("cannot save to blob storage. allocate_method=None") self.syft_action_data_type = type(data) self._set_reprs(data) self.syft_has_bool_attr = hasattr(data, "__bool__") else: logger.debug( - "skipping writing action object to store, passed data was empty." 
+ "skipping writing action object to store, passed data was empty.", ) self.syft_action_data_cache = data @@ -863,7 +850,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: return None def _save_to_blob_storage( - self, allow_empty: bool = False + self, allow_empty: bool = False, ) -> SyftError | SyftSuccess | SyftWarning: data = self.syft_action_data if isinstance(data, SyftError): @@ -871,7 +858,7 @@ def _save_to_blob_storage( if isinstance(data, ActionDataEmpty): return SyftError( - message=f"cannot store empty object {self.id} to the blob storage" + message=f"cannot store empty object {self.id} to the blob storage", ) try: @@ -881,7 +868,7 @@ def _save_to_blob_storage( if not TraceResultRegistry.current_thread_is_tracing(): self._clear_cache() return SyftSuccess( - message=f"Saved action object {self.id} to the blob store" + message=f"Saved action object {self.id} to the blob store", ) except Exception as e: raise e @@ -896,7 +883,7 @@ def _set_reprs(self, data: any) -> None: self.syft_action_data_repr_ = truncate_str( data._repr_markdown_() if hasattr(data, "_repr_markdown_") - else data.__repr__() + else data.__repr__(), ) self.syft_action_data_str_ = truncate_str(str(data)) @@ -915,7 +902,7 @@ def syft_lineage_id(self) -> LineageID: @classmethod def __check_action_data(cls, values: dict) -> dict: v = values.get("syft_action_data_cache") - if values.get("syft_action_data_type", None) is None: + if values.get("syft_action_data_type") is None: values["syft_action_data_type"] = type(v) if not isinstance(v, ActionDataEmpty): if inspect.isclass(v): @@ -924,7 +911,7 @@ def __check_action_data(cls, values: dict) -> dict: values["syft_action_data_repr_"] = truncate_str( v._repr_markdown_() if v is not None and hasattr(v, "_repr_markdown_") - else v.__repr__() + else v.__repr__(), ) values["syft_action_data_str_"] = truncate_str(str(v)) values["syft_has_bool_attr"] = hasattr(v, "__bool__") @@ -956,7 +943,7 @@ def syft_get_property(self, obj: Any, method: str) -> Any: def syft_is_property(self, obj: Any, method: str) -> bool: klass_method = getattr(type(obj), method, None) return isinstance(klass_method, property) or inspect.isdatadescriptor( - klass_method + klass_method, ) def syft_eq(self, ext_obj: Self | None) -> bool: @@ -965,25 +952,27 @@ def syft_eq(self, ext_obj: Self | None) -> bool: return self.id.id == ext_obj.id.id def syft_execute_action( - self, action: Action, sync: bool = True + self, action: Action, sync: bool = True, ) -> ActionObjectPointer: """Execute a remote action - Parameters: + Parameters + ---------- action: Action Which action to execute sync: bool Run sync/async - Returns: + Returns + ------- ActionObjectPointer + """ if self.syft_server_uid is None: raise SyftException("Pointers can't execute without a server_uid.") # relative - from ...client.api import APIRegistry - from ...client.api import SyftAPICall + from ...client.api import APIRegistry, SyftAPICall api = APIRegistry.api_for( server_uid=self.syft_server_uid, @@ -1002,14 +991,13 @@ def syft_execute_action( def request(self, client: SyftClient) -> Any | SyftError: # relative - from ..request.request import ActionStoreChange - from ..request.request import SubmitRequest + from ..request.request import ActionStoreChange, SubmitRequest action_object_link = LinkedObject.from_obj( - self, server_uid=self.syft_server_uid + self, server_uid=self.syft_server_uid, ) permission_change = ActionStoreChange( - linked_obj=action_object_link, apply_permission_type=ActionPermission.READ + 
linked_obj=action_object_link, apply_permission_type=ActionPermission.READ, ) if client.credentials is None: return SyftError(f"{client} has no signing key") @@ -1020,9 +1008,7 @@ def request(self, client: SyftClient) -> Any | SyftError: return client.api.services.request.submit(submit_request) def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: - if self.syft_server_uid is None or self.syft_client_verify_key is None: - return - elif obj.syft_server_uid is not None: + if self.syft_server_uid is None or self.syft_client_verify_key is None or obj.syft_server_uid is not None: return if obj.syft_blob_storage_entry_id is not None: @@ -1062,7 +1048,7 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: ) if api is None: print( - f"failed saving {obj} to blob storage, api is None. You must login to {self.syft_server_location}" + f"failed saving {obj} to blob storage, api is None. You must login to {self.syft_server_location}", ) return else: @@ -1112,7 +1098,8 @@ def syft_make_action( ) -> Action: """Generate new action from the information - Parameters: + Parameters + ---------- path: str The path of the Type of the remote object. op: str @@ -1126,9 +1113,11 @@ def syft_make_action( Returns: Action object - Raises: + Raises + ------ ValueError: For invalid args or kwargs PydanticValidationError: For args and kwargs + """ if args is None: args = [] @@ -1158,7 +1147,8 @@ def syft_make_action_with_self( ) -> Action: """Generate new method action from the current object. - Parameters: + Parameters + ---------- op: str The method to be executed from the remote object. args: List[LineageID] @@ -1168,9 +1158,11 @@ def syft_make_action_with_self( Returns: Action object - Raises: + Raises + ------ ValueError: For invalid args or kwargs PydanticValidationError: For args and kwargs + """ path = self.syft_get_path() return self.syft_make_action( @@ -1183,7 +1175,7 @@ def syft_make_action_with_self( ) def get_sync_dependencies( - self, context: AuthedServiceContext, **kwargs: dict + self, context: AuthedServiceContext, **kwargs: dict, ) -> list[UID]: # type: ignore # relative from ..job.job_stash import Job @@ -1214,12 +1206,15 @@ def syft_remote_method( ) -> Callable: """Generate a Callable object for remote calls. - Parameters: + Parameters + ---------- op: str he method to be executed from the remote object. - Returns: + Returns + ------- A function + """ def wrapper( @@ -1281,7 +1276,7 @@ def refresh_object(self, resolve_nested: bool = True) -> ActionObject | SyftErro ) if api is None: return SyftError( - message=f"api is None. You must login to {self.syft_server_location}" + message=f"api is None. You must login to {self.syft_server_location}", ) res = api.services.action.get(self.id, resolve_nested=resolve_nested) @@ -1313,7 +1308,7 @@ def get(self, block: bool = False) -> Any: else: if not self.has_storage_permission(): prompt_warning_message( - message="This is a placeholder object, the real data lives on a different server and is not synced." + message="This is a placeholder object, the real data lives on a different server and is not synced.", ) nested_res = res.syft_action_data if isinstance(nested_res, ActionObject): @@ -1399,13 +1394,15 @@ def from_obj( ) -> ActionObject: """Create an ActionObject from an existing object. - Parameters: + Parameters + ---------- syft_action_data: Any The object to be converted to a Syft ActionObject id: Optional[UID] Which ID to use for the ActionObject. 
Optional syft_lineage_id: Optional[LineageID] Which LineageID to use for the ActionObject. Optional + """ if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id: raise ValueError("UID and LineageID should match") @@ -1513,15 +1510,16 @@ def empty( ) -> Self: """Create an ActionObject from a type, using a ActionDataEmpty object - Parameters: + Parameters + ---------- syft_internal_type: Type The Type for which to create a ActionDataEmpty object id: Optional[UID] Which ID to use for the ActionObject. Optional syft_lineage_id: Optional[LineageID] Which LineageID to use for the ActionObject. Optional - """ + """ syft_internal_type = ( type(None) if syft_internal_type is None else syft_internal_type ) @@ -1552,7 +1550,7 @@ def __post_init__(self) -> None: self.syft_post_hooks__[HOOK_ON_POINTERS] = [] api = APIRegistry.api_for( - self.syft_server_location, self.syft_client_verify_key + self.syft_server_location, self.syft_client_verify_key, ) eager_execution_enabled = ( api is not None @@ -1569,14 +1567,14 @@ def __post_init__(self) -> None: self.syft_history_hash = hash(self.id) def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: - """ - Add pre-hooks + """Add pre-hooks Args: + ---- eager_execution: bool: If eager execution is enabled, hooks for tracing and executing the action on remote are added. - """ + """ # this should be a list as orders matters for side_effect in [make_action_side_effect]: if side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: @@ -1591,12 +1589,13 @@ def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: self.syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) def _syft_add_post_hooks__(self, eager_execution: bool) -> None: - """ - Add post-hooks + """Add post-hooks Args: + ---- eager_execution: bool: If eager execution is enabled, hooks for tracing and executing the action on remote are added. 
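# Sketch of the registration pattern in _syft_add_pre_hooks__ / _syft_add_post_hooks__
# above: hooks live in lists keyed by trigger, order matters, and a hook is appended
# only once. The HOOK_* values are stand-ins; only the shape matches the code above.
HOOK_ALWAYS = "ALWAYS"
HOOK_ON_POINTERS = "ON_POINTERS"
hooks: dict[str, list] = {HOOK_ALWAYS: [], HOOK_ON_POINTERS: []}

def register(trigger: str, hook: object) -> None:
    if hook not in hooks[trigger]:  # avoid double registration, as the code above does
        hooks[trigger].append(hook)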
+ """ if eager_execution: # this should be a list as orders matters @@ -1605,7 +1604,7 @@ def _syft_add_post_hooks__(self, eager_execution: bool) -> None: self.syft_post_hooks__[HOOK_ALWAYS].append(side_effect) def _syft_run_pre_hooks__( - self, context: PreHookContext, name: str, args: Any, kwargs: Any + self, context: PreHookContext, name: str, args: Any, kwargs: Any, ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: """Hooks executed before the actual call""" result_args, result_kwargs = args, kwargs @@ -1640,7 +1639,7 @@ def _syft_run_pre_hooks__( return context, result_args, result_kwargs def _syft_run_post_hooks__( - self, context: PreHookContext, name: str, result: Any + self, context: PreHookContext, name: str, result: Any, ) -> Any: """Hooks executed after the actual call""" new_result = result @@ -1674,7 +1673,7 @@ def _syft_run_post_hooks__( return new_result def _syft_output_action_object( - self, result: Any, context: PreHookContext | None = None + self, result: Any, context: PreHookContext | None = None, ) -> Any: """Wrap the result in an ActionObject""" if issubclass(type(result), ActionObject): @@ -1712,7 +1711,7 @@ def _syft_get_attr_context(self, name: str) -> Any: return context_self def _syft_attr_propagate_ids( - self, context: PreHookContext, name: str, result: Any + self, context: PreHookContext, name: str, result: Any, ) -> Any: """Patch the results with the syft_history_hash, server_uid, and result_id.""" if name in self._syft_dont_wrap_attrs(): @@ -1753,12 +1752,12 @@ def _syft_wrap_attribute_for_bool_on_nonbools(self, name: str) -> Any: """Handle `__getattribute__` for bool casting.""" if name != "__bool__": raise RuntimeError( - "[_wrap_attribute_for_bool_on_nonbools] Use this only for the __bool__ operator" + "[_wrap_attribute_for_bool_on_nonbools] Use this only for the __bool__ operator", ) if self.syft_has_bool_attr: raise RuntimeError( - "[_wrap_attribute_for_bool_on_nonbools] self.syft_action_data already implements the bool operator" + "[_wrap_attribute_for_bool_on_nonbools] self.syft_action_data already implements the bool operator", ) logger.debug("[__getattribute__] Handling bool on nonbools") @@ -1792,7 +1791,7 @@ def _syft_wrap_attribute_for_properties(self, name: str) -> Any: if not self.syft_is_property(context_self, name): raise RuntimeError( - "[_wrap_attribute_for_properties] Use this only on properties" + "[_wrap_attribute_for_properties] Use this only on properties", ) logger.debug(f"[__getattribute__] Handling property {name}") @@ -1806,7 +1805,7 @@ def _syft_wrap_attribute_for_properties(self, name: str) -> Any: context, _, _ = self._syft_run_pre_hooks__(context, name, (), {}) # no input needs to propagate result = self._syft_run_post_hooks__( - context, name, self.syft_get_property(context_self, name) + context, name, self.syft_get_property(context_self, name), ) return self._syft_attr_propagate_ids(context, name, result) @@ -1838,14 +1837,14 @@ def _base_wrapper(*args: Any, **kwargs: Any) -> Any: syft_client_verify_key=self.syft_client_verify_key, ) context, pre_hook_args, pre_hook_kwargs = self._syft_run_pre_hooks__( - context, name, args, kwargs + context, name, args, kwargs, ) if has_action_data_empty(args=args, kwargs=kwargs): result = fake_func(*args, **kwargs) else: original_args, original_kwargs = debox_args_and_kwargs( - pre_hook_args, pre_hook_kwargs + pre_hook_args, pre_hook_kwargs, ) result = original_func(*original_args, **original_kwargs) @@ -1874,7 +1873,7 @@ def wrapper(_self: Any, *args: Any, **kwargs: Any) -> 
Any: inspect.signature(original_func), ) wrapper.__ipython_inspector_signature_override__ = inspect.signature( - original_func + original_func, ) except Exception: logger.debug(f"name={name} has no signature") @@ -1891,7 +1890,7 @@ def fake_func(*args: Any, **kwargs: Any) -> Any: return ActionDataEmpty(syft_internal_type=self.syft_internal_type) if isinstance( - self.syft_action_data_type, ActionDataEmpty + self.syft_action_data_type, ActionDataEmpty, ) or has_action_data_empty(args=args, kwargs=kwargs): local_func = fake_func else: @@ -1905,7 +1904,7 @@ def fake_func(*args: Any, **kwargs: Any) -> Any: syft_client_verify_key=self.syft_client_verify_key, ) context, pre_hook_args, pre_hook_kwargs = self._syft_run_pre_hooks__( - context, "__setattr__", args, kwargs + context, "__setattr__", args, kwargs, ) original_args, _ = debox_args_and_kwargs(pre_hook_args, pre_hook_kwargs) @@ -1928,9 +1927,11 @@ def __getattribute__(self, name: str) -> Any: * use the syft/_syft prefix for internal methods. * add the method name to the passthrough_attrs. - Parameters: + Parameters + ---------- name: str The name of the attribute to access. + """ # bypass ipython canary verification if name == "_ipython_canary_method_should_not_exist_": @@ -2000,18 +2001,17 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: if isinstance(self.syft_action_data_cache, ActionDataEmpty): data_repr_ = self.syft_action_data_repr_ + elif inspect.isclass(self.syft_action_data_cache): + data_repr_ = repr_cls(self.syft_action_data_cache) else: - if inspect.isclass(self.syft_action_data_cache): - data_repr_ = repr_cls(self.syft_action_data_cache) - else: - data_repr_ = ( - self.syft_action_data_cache._repr_markdown_() - if ( - self.syft_action_data_cache is not None - and hasattr(self.syft_action_data_cache, "_repr_markdown_") - ) - else self.syft_action_data_cache.__repr__() + data_repr_ = ( + self.syft_action_data_cache._repr_markdown_() + if ( + self.syft_action_data_cache is not None + and hasattr(self.syft_action_data_cache, "_repr_markdown_") ) + else self.syft_action_data_cache.__repr__() + ) return f"\n**{res}**\n\n{data_repr_}\n" @@ -2074,10 +2074,10 @@ def __mul__(self, other: Any) -> Any: def __matmul__(self, other: Any) -> Any: return self._syft_output_action_object(self.__matmul__(other)) - def __eq__(self, other: Any) -> Any: + def __eq__(self, other: object) -> Any: return self._syft_output_action_object(self.__eq__(other)) - def __ne__(self, other: Any) -> Any: + def __ne__(self, other: object) -> Any: return self._syft_output_action_object(self.__ne__(other)) def __lt__(self, other: Any) -> Any: @@ -2196,8 +2196,7 @@ def __rrshift__(self, other: Any) -> Any: @serializable() class AnyActionObject(ActionObject): - """ - This is a catch-all class for all objects that are not + """This is a catch-all class for all objects that are not defined in the `action_types` dictionary. 
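# A generic sketch of the interception pattern used above (not Syft's exact logic):
# __getattribute__ lets internal and passthrough names fall through untouched and
# gives everything else a chance to be wrapped with pre/post hooks.
class Intercepting:
    _passthrough_attrs = frozenset({"plain"})

    def __init__(self) -> None:
        self.plain = 42

    def __getattribute__(self, name: str) -> object:
        if name.startswith("_") or name in Intercepting._passthrough_attrs:
            return object.__getattribute__(self, name)
        value = object.__getattribute__(self, name)
        # a wrapper would run pre-hooks here, invoke the attribute, then run post-hooks
        return value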
""" @@ -2231,7 +2230,7 @@ def debug_original_func(name: str, func: Callable) -> None: def is_action_data_empty(obj: Any) -> bool: return isinstance(obj, AnyActionObject) and issubclass( - obj.syft_action_data_type, ActionDataEmpty + obj.syft_action_data_type, ActionDataEmpty, ) diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 03992eeab07..bdac02f60cf 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -43,7 +43,7 @@ def __init__( @classmethod def from_permission_string( - cls, uid: UID, permission_string: str + cls, uid: UID, permission_string: str, ) -> "ActionObjectPermission": if permission_string.startswith("ALL_"): permission = ActionPermission[permission_string] diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index ab82c80a2b8..8dc59079924 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -1,14 +1,11 @@ # stdlib import importlib import logging -from typing import Any -from typing import cast +from typing import Any, cast # third party import numpy as np -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable @@ -18,37 +15,39 @@ from ...types.twin_object import TwinObject from ...types.uid import UID from ..blob_storage.service import BlobStorageService -from ..code.user_code import UserCode -from ..code.user_code import execute_byte_code +from ..code.user_code import UserCode, execute_byte_code from ..context import AuthedServiceContext -from ..policy.policy import OutputPolicy -from ..policy.policy import retrieve_from_db -from ..response import SyftError -from ..response import SyftSuccess -from ..response import SyftWarning -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import UserLibConfigRegistry -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from ..user.user_roles import ServiceRole +from ..policy.policy import OutputPolicy, retrieve_from_db +from ..response import SyftError, SyftSuccess, SyftWarning +from ..service import ( + SERVICE_TO_TYPES, + TYPE_TO_SERVICE, + AbstractService, + UserLibConfigRegistry, + service_method, +) +from ..user.user_roles import ADMIN_ROLE_LEVEL, GUEST_ROLE_LEVEL, ServiceRole from .action_endpoint import CustomEndpointActionObject -from .action_object import Action -from .action_object import ActionObject -from .action_object import ActionObjectPointer -from .action_object import ActionType -from .action_object import AnyActionObject -from .action_object import TwinMode -from .action_permissions import ActionObjectPermission -from .action_permissions import ActionObjectREAD -from .action_permissions import ActionPermission +from .action_object import ( + Action, + ActionObject, + ActionObjectPointer, + ActionType, + AnyActionObject, + TwinMode, +) +from .action_permissions import ( + ActionObjectPermission, + ActionObjectREAD, + ActionPermission, +) from .action_store import ActionStore from .action_types import action_type_for_type from .numpy import NumpyArrayObject -from .pandas import PandasDataFrameObject # noqa: F401 -from 
.pandas import PandasSeriesObject # noqa: F401 +from .pandas import ( + PandasDataFrameObject, # noqa: F401 + PandasSeriesObject, # noqa: F401 +) logger = logging.getLogger(__name__) @@ -111,8 +110,7 @@ def is_detached_obj( action_object: ActionObject | TwinObject, ignore_detached_obj: bool = False, ) -> bool: - """ - A detached object is an object that is not yet saved to the blob storage. + """A detached object is an object that is not yet saved to the blob storage. """ if ( isinstance(action_object, TwinObject) @@ -147,7 +145,7 @@ def _set( ) -> Result[ActionObject, str]: if self.is_detached_obj(action_object, ignore_detached_objs): return Err( - "You uploaded an ActionObject that is not yet in the blob storage" + "You uploaded an ActionObject that is not yet in the blob storage", ) """Save an object to the action store""" # 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable @@ -195,10 +193,10 @@ def _set( if action_object.mock_obj.syft_action_saved_to_blob_store: blob_id = action_object.mock_obj.syft_blob_storage_entry_id permission = ActionObjectPermission( - blob_id, ActionPermission.ALL_READ + blob_id, ActionPermission.ALL_READ, ) blob_storage_service: AbstractService = context.server.get_service( - BlobStorageService + BlobStorageService, ) blob_storage_service.stash.add_permission(permission) if has_result_read_permission: @@ -211,7 +209,7 @@ def _set( return result.err() @service_method( - path="action.is_resolved", name="is_resolved", roles=GUEST_ROLE_LEVEL + path="action.is_resolved", name="is_resolved", roles=GUEST_ROLE_LEVEL, ) def is_resolved( self, @@ -224,7 +222,7 @@ def is_resolved( obj = result.ok() if obj.is_link: result = self.resolve_links( - context, obj.syft_action_data.action_object_id.id + context, obj.syft_action_data.action_object_id.id, ) # Checking in case any error occurred if result.is_err(): @@ -242,7 +240,7 @@ def is_resolved( return result @service_method( - path="action.resolve_links", name="resolve_links", roles=GUEST_ROLE_LEVEL + path="action.resolve_links", name="resolve_links", roles=GUEST_ROLE_LEVEL, ) def resolve_links( self, @@ -261,7 +259,7 @@ def resolve_links( # If it's not a leaf if obj.is_link: nested_result = self.resolve_links( - context, obj.syft_action_data.action_object_id.id, twin_mode + context, obj.syft_action_data.action_object_id.id, twin_mode, ) return nested_result @@ -291,7 +289,7 @@ def _get( ) -> Result[ActionObject, str]: """Get an object from the action store""" result = self.store.get( - uid=uid, credentials=context.credentials, has_permission=has_permission + uid=uid, credentials=context.credentials, has_permission=has_permission, ) if result.is_ok() and context.server is not None: obj: TwinObject | ActionObject = result.ok() @@ -306,11 +304,11 @@ def _get( and obj.is_link ): if not self.is_resolved( # type: ignore[unreachable] - context, obj.syft_action_data.action_object_id.id + context, obj.syft_action_data.action_object_id.id, ).ok(): return SyftError(message="This object is not resolved yet.") result = self.resolve_links( - context, obj.syft_action_data.action_object_id.id, twin_mode + context, obj.syft_action_data.action_object_id.id, twin_mode, ) return result if isinstance(obj, TwinObject): @@ -328,15 +326,14 @@ def _get( return result @service_method( - path="action.get_pointer", name="get_pointer", roles=GUEST_ROLE_LEVEL + path="action.get_pointer", name="get_pointer", roles=GUEST_ROLE_LEVEL, ) def get_pointer( - self, context: AuthedServiceContext, uid: UID + self, context: 
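# Sketch of the "detached object" rule enforced by _set above: an object whose payload
# was never saved to blob storage is rejected unless the caller explicitly ignores
# detached objects. Simplified stand-in, not the Syft implementation.
def guard_detached(saved_to_blob_store: bool, ignore_detached: bool) -> str | None:
    if not saved_to_blob_store and not ignore_detached:
        return "You uploaded an ActionObject that is not yet in the blob storage"
    return None  # safe to persist in the action store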
AuthedServiceContext, uid: UID, ) -> Result[ActionObjectPointer, str]: """Get a pointer from the action store""" - result = self.store.get_pointer( - uid=uid, credentials=context.credentials, server_uid=context.server.id + uid=uid, credentials=context.credentials, server_uid=context.server.id, ) if result.is_ok(): obj = result.ok() @@ -349,7 +346,7 @@ def get_pointer( @service_method(path="action.get_mock", name="get_mock", roles=GUEST_ROLE_LEVEL) def get_mock( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> Result[SyftError, SyftObject]: """Get a pointer from the action store""" result = self.store.get_mock(uid=uid) @@ -367,7 +364,7 @@ def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> boo def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool: return self.store.has_permissions( - [ActionObjectREAD(uid=uid, credentials=context.credentials)] + [ActionObjectREAD(uid=uid, credentials=context.credentials)], ) # not a public service endpoint @@ -399,7 +396,7 @@ def _user_code_execute( # Filter input kwargs based on policy filtered_kwargs = input_policy.filter_kwargs( - kwargs=kwargs, context=context, code_item_id=code_item.id + kwargs=kwargs, context=context, code_item_id=code_item.id, ) if filtered_kwargs.is_err(): return filtered_kwargs @@ -446,7 +443,7 @@ def _user_code_execute( # no twins # allow python types from inputpolicy filtered_kwargs = filter_twin_kwargs( - real_kwargs, twin_mode=TwinMode.NONE, allow_python_types=True + real_kwargs, twin_mode=TwinMode.NONE, allow_python_types=True, ) exec_result = execute_byte_code(code_item, filtered_kwargs, context) if output_policy: @@ -459,17 +456,17 @@ def _user_code_execute( user_code_service.update_code_state(context, code_item) if isinstance(exec_result.result, ActionObject): result_action_object = ActionObject.link( - result_id=result_id, pointer_id=exec_result.result.id + result_id=result_id, pointer_id=exec_result.result.id, ) else: result_action_object = wrap_result(result_id, exec_result.result) else: # twins private_kwargs = filter_twin_kwargs( - real_kwargs, twin_mode=TwinMode.PRIVATE, allow_python_types=True + real_kwargs, twin_mode=TwinMode.PRIVATE, allow_python_types=True, ) private_exec_result = execute_byte_code( - code_item, private_kwargs, context + code_item, private_kwargs, context, ) if output_policy: private_exec_result.result = output_policy.apply_to_output( @@ -480,11 +477,11 @@ def _user_code_execute( code_item.output_policy = output_policy # type: ignore user_code_service.update_code_state(context, code_item) result_action_object_private = wrap_result( - result_id, private_exec_result.result + result_id, private_exec_result.result, ) mock_kwargs = filter_twin_kwargs( - real_kwargs, twin_mode=TwinMode.MOCK, allow_python_types=True + real_kwargs, twin_mode=TwinMode.MOCK, allow_python_types=True, ) # relative from .action_data_empty import ActionDataEmpty @@ -493,11 +490,11 @@ def _user_code_execute( mock_exec_result_obj = ActionDataEmpty() else: mock_exec_result = execute_byte_code( - code_item, mock_kwargs, context + code_item, mock_kwargs, context, ) if output_policy: mock_exec_result.result = output_policy.apply_to_output( - context, mock_exec_result.result, update_policy=False + context, mock_exec_result.result, update_policy=False, ) mock_exec_result_obj = mock_exec_result.result @@ -567,7 +564,7 @@ def set_result_to_store( return set_result blob_storage_service: AbstractService = context.server.get_service( - 
BlobStorageService + BlobStorageService, ) def store_permission( @@ -622,7 +619,7 @@ def execute_plan( return self._get(context, result_id, TwinMode.MOCK, has_permission=True) def call_function( - self, context: AuthedServiceContext, action: Action + self, context: AuthedServiceContext, action: Action, ) -> Result[ActionObject, str] | Err: # run function/class init _user_lib_config_registry = UserLibConfigRegistry.from_user(context.credentials) @@ -633,7 +630,7 @@ def call_function( return execute_callable(self, context, action) else: return Err( - f"Failed executing {action}. You have no permission for {absolute_path}" + f"Failed executing {action}. You have no permission for {absolute_path}", ) def set_attribute( @@ -645,24 +642,24 @@ def set_attribute( args, _ = resolve_action_args(action, context, self) if args.is_err(): return Err( - f"Failed executing action {action}, could not resolve args: {args.err()}" + f"Failed executing action {action}, could not resolve args: {args.err()}", ) else: args = args.ok() if not isinstance(args[0], ActionObject): return Err( - f"Failed executing action {action} setattribute requires a non-twin string as first argument" + f"Failed executing action {action} setattribute requires a non-twin string as first argument", ) name = args[0].syft_action_data # dont do the whole filtering dance with the name args = [args[1]] if isinstance(resolved_self, TwinObject): - # todo, create copy? + # TODO, create copy? private_args = filter_twin_args(args, twin_mode=TwinMode.PRIVATE) private_val = private_args[0] setattr(resolved_self.private.syft_action_data, name, private_val) - # todo: what do we use as data for the mock here? + # TODO: what do we use as data for the mock here? # depending on permisisons? public_args = filter_twin_args(args, twin_mode=TwinMode.MOCK) public_val = public_args[0] @@ -671,12 +668,12 @@ def set_attribute( TwinObject( id=action.result_id, private_obj=ActionObject.from_obj( - resolved_self.private.syft_action_data + resolved_self.private.syft_action_data, ), private_obj_id=action.result_id, mock_obj=ActionObject.from_obj(resolved_self.mock.syft_action_data), mock_obj_id=action.result_id, - ) + ), ) else: # TODO: Implement for twinobject args @@ -686,13 +683,13 @@ def set_attribute( return Ok( ActionObject.from_obj(resolved_self.syft_action_data), ) - # todo: permissions + # TODO: permissions # setattr(resolved_self.syft_action_data, name, val) # val = resolved_self.syft_action_data # result_action_object = Ok(wrap_result(action.result_id, val)) def get_attribute( - self, action: Action, resolved_self: ActionObject | TwinObject + self, action: Action, resolved_self: ActionObject | TwinObject, ) -> Ok[TwinObject | ActionObject]: if isinstance(resolved_self, TwinObject): private_result = getattr(resolved_self.private.syft_action_data, action.op) @@ -704,7 +701,7 @@ def get_attribute( private_obj_id=action.result_id, mock_obj=ActionObject.from_obj(mock_result), mock_obj_id=action.result_id, - ) + ), ) else: val = getattr(resolved_self.syft_action_data, action.op) # type: ignore[unreachable] @@ -727,14 +724,14 @@ def call_method( ) if private_result.is_err(): return Err( - f"Failed executing action {action}, result is an error: {private_result.err()}" + f"Failed executing action {action}, result is an error: {private_result.err()}", ) mock_result = execute_object( - self, context, resolved_self.mock, action, twin_mode=TwinMode.MOCK + self, context, resolved_self.mock, action, twin_mode=TwinMode.MOCK, ) if mock_result.is_err(): return Err( - 
f"Failed executing action {action}, result is an error: {mock_result.err()}" + f"Failed executing action {action}, result is an error: {mock_result.err()}", ) private_result = private_result.ok() @@ -747,16 +744,15 @@ def call_method( private_obj_id=action.result_id, mock_obj=mock_result, mock_obj_id=action.result_id, - ) + ), ) else: return execute_object(self, context, resolved_self, action) # type:ignore[unreachable] def unwrap_nested_actionobjects( - self, context: AuthedServiceContext, data: Any + self, context: AuthedServiceContext, data: Any, ) -> Any: - """recursively unwraps nested action objects""" - + """Recursively unwraps nested action objects""" if isinstance(data, list): return [self.unwrap_nested_actionobjects(context, obj) for obj in data] if isinstance(data, dict): @@ -773,14 +769,13 @@ def unwrap_nested_actionobjects( nested_res = res.syft_action_data if isinstance(nested_res, ActionObject): raise ValueError( - "More than double nesting of ActionObjects is currently not supported" + "More than double nesting of ActionObjects is currently not supported", ) return nested_res return data def contains_nested_actionobjects(self, data: Any) -> bool: - """ - returns if this is a list/set/dict that contains ActionObjects + """Returns if this is a list/set/dict that contains ActionObjects """ def unwrap_collection(col: set | dict | list) -> [Any]: # type: ignore @@ -831,7 +826,7 @@ def flatten_action_arg(self, context: AuthedServiceContext, arg: UID) -> UID | N @service_method(path="action.execute", name="execute", roles=GUEST_ROLE_LEVEL) def execute( - self, context: AuthedServiceContext, action: Action + self, context: AuthedServiceContext, action: Action, ) -> Result[ActionObject, Err]: """Execute an operation on objects in the action store""" # relative @@ -846,7 +841,7 @@ def execute( # transform lineage ids into ids kwarg_ids[k] = v.id result_action_object = usercode_service._call( - context, action.user_code_id, action.result_id, **kwarg_ids + context, action.user_code_id, action.result_id, **kwarg_ids, ) return result_action_object elif action.action_type == ActionType.FUNCTION: @@ -860,7 +855,7 @@ def execute( ) if resolved_self.is_err(): return Err( - f"Failed executing action {action}, could not resolve self: {resolved_self.err()}" + f"Failed executing action {action}, could not resolve self: {resolved_self.err()}", ) resolved_self = resolved_self.ok() if action.op == "__call__" and resolved_self.syft_action_data_type == Plan: @@ -872,7 +867,7 @@ def execute( return result_action_object elif action.action_type == ActionType.SETATTRIBUTE: result_action_object = self.set_attribute( - context, action, resolved_self + context, action, resolved_self, ) elif action.action_type == ActionType.GETATTRIBUTE: result_action_object = self.get_attribute(action, resolved_self) @@ -883,14 +878,14 @@ def execute( if result_action_object.is_err(): return Err( - f"Failed executing action {action}, result is an error: {result_action_object.err()}" + f"Failed executing action {action}, result is an error: {result_action_object.err()}", ) else: result_action_object = result_action_object.ok() # check if we have read permissions on the result has_result_read_permission = self.has_read_permission_for_action_result( - context, action + context, action, ) result_action_object._set_obj_location_( @@ -903,7 +898,7 @@ def execute( # pass permission information to the action store as extra kwargs context.extra_kwargs = { - "has_result_read_permission": has_result_read_permission + 
"has_result_read_permission": has_result_read_permission, } if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) @@ -913,13 +908,13 @@ def execute( ) if set_result.is_err(): return Err( - f"Failed executing action {action}, set result is an error: {set_result.err()}" + f"Failed executing action {action}, set result is an error: {set_result.err()}", ) return set_result def has_read_permission_for_action_result( - self, context: AuthedServiceContext, action: Action + self, context: AuthedServiceContext, action: Action, ) -> bool: action_obj_ids = ( action.args + list(action.kwargs.values()) + [action.remote_self] @@ -932,7 +927,7 @@ def has_read_permission_for_action_result( @service_method(path="action.exists", name="exists", roles=GUEST_ROLE_LEVEL) def exists( - self, context: AuthedServiceContext, obj_id: UID + self, context: AuthedServiceContext, obj_id: UID, ) -> Result[SyftSuccess, SyftError]: """Checks if the given object id exists in the Action Store""" if self.store.exists(obj_id): @@ -942,7 +937,7 @@ def exists( @service_method(path="action.delete", name="delete", roles=ADMIN_ROLE_LEVEL) def delete( - self, context: AuthedServiceContext, uid: UID, soft_delete: bool = False + self, context: AuthedServiceContext, uid: UID, soft_delete: bool = False, ) -> SyftSuccess | SyftError: get_res = self.store.get(uid=uid, credentials=context.credentials) if get_res.is_err(): @@ -958,7 +953,7 @@ def delete( # delete the action object from the action store store_del_res = self._delete_from_action_store( - context=context, uid=obj.id, soft_delete=soft_delete + context=context, uid=obj.id, soft_delete=soft_delete, ) if isinstance(store_del_res, SyftError): return SyftError(message=store_del_res.message) @@ -973,12 +968,12 @@ def _delete_blob_storage_entry( ) -> SyftSuccess | SyftError: deleted_blob_ids = [] blob_store_service = cast( - BlobStorageService, context.server.get_service(BlobStorageService) + BlobStorageService, context.server.get_service(BlobStorageService), ) if isinstance(obj, ActionObject) and obj.syft_blob_storage_entry_id: blob_del_res = blob_store_service.delete( - context=context, uid=obj.syft_blob_storage_entry_id + context=context, uid=obj.syft_blob_storage_entry_id, ) if isinstance(blob_del_res, SyftError): return SyftError(message=blob_del_res.message) @@ -987,7 +982,7 @@ def _delete_blob_storage_entry( if isinstance(obj, TwinObject): if obj.private.syft_blob_storage_entry_id: blob_del_res = blob_store_service.delete( - context=context, uid=obj.private.syft_blob_storage_entry_id + context=context, uid=obj.private.syft_blob_storage_entry_id, ) if isinstance(blob_del_res, SyftError): return SyftError(message=blob_del_res.message) @@ -995,7 +990,7 @@ def _delete_blob_storage_entry( if obj.mock.syft_blob_storage_entry_id: blob_del_res = blob_store_service.delete( - context=context, uid=obj.mock.syft_blob_storage_entry_id + context=context, uid=obj.mock.syft_blob_storage_entry_id, ) if isinstance(blob_del_res, SyftError): return SyftError(message=blob_del_res.message) @@ -1019,7 +1014,7 @@ def _delete_from_action_store( if isinstance(obj, TwinObject): res = self._soft_delete_action_obj( - context=context, action_obj=obj.private + context=context, action_obj=obj.private, ) if res.is_err(): return SyftError(message=res.err()) @@ -1039,7 +1034,7 @@ def _delete_from_action_store( return SyftSuccess(message=f"Action object with uid '{uid}' deleted.") def _soft_delete_action_obj( - self, context: AuthedServiceContext, action_obj: ActionObject + self, 
context: AuthedServiceContext, action_obj: ActionObject, ) -> Result[ActionObject, str]: action_obj.syft_action_data_cache = None res = action_obj._save_to_blob_storage() @@ -1053,13 +1048,13 @@ def _soft_delete_action_obj( def resolve_action_args( - action: Action, context: AuthedServiceContext, service: ActionService + action: Action, context: AuthedServiceContext, service: ActionService, ) -> tuple[Ok[dict], bool]: has_twin_inputs = False args = [] for arg_id in action.args: arg_value = service._get( - context=context, uid=arg_id, twin_mode=TwinMode.NONE, has_permission=True + context=context, uid=arg_id, twin_mode=TwinMode.NONE, has_permission=True, ) if arg_value.is_err(): return arg_value, False @@ -1070,13 +1065,13 @@ def resolve_action_args( def resolve_action_kwargs( - action: Action, context: AuthedServiceContext, service: ActionService + action: Action, context: AuthedServiceContext, service: ActionService, ) -> tuple[Ok[dict], bool]: has_twin_inputs = False kwargs = {} for key, arg_id in action.kwargs.items(): kwarg_value = service._get( - context=context, uid=arg_id, twin_mode=TwinMode.NONE, has_permission=True + context=context, uid=arg_id, twin_mode=TwinMode.NONE, has_permission=True, ) if kwarg_value.is_err(): return kwarg_value, False @@ -1135,7 +1130,7 @@ def _get_target_callable(path: str, op: str) -> Any: private_kwargs = filter_twin_kwargs(kwargs, twin_mode=twin_mode) private_result = target_callable(*private_args, **private_kwargs) result_action_object_private = wrap_result( - action.result_id, private_result + action.result_id, private_result, ) twin_mode = TwinMode.MOCK @@ -1195,7 +1190,7 @@ def execute_object( private_kwargs = filter_twin_kwargs(kwargs, twin_mode=TwinMode.PRIVATE) private_result = target_method(*private_args, **private_kwargs) result_action_object_private = wrap_result( - action.result_id, private_result + action.result_id, private_result, ) mock_args = filter_twin_args(args, twin_mode=TwinMode.MOCK) @@ -1223,7 +1218,7 @@ def execute_object( result_action_object = wrap_result(action.result_id, result) else: raise Exception( - f"Bad combination of: twin_mode: {twin_mode} and has_twin_inputs: {has_twin_inputs}" + f"Bad combination of: twin_mode: {twin_mode} and has_twin_inputs: {has_twin_inputs}", ) else: return Err("Missing target method") @@ -1251,7 +1246,7 @@ def filter_twin_args(args: list[Any], twin_mode: TwinMode) -> Any: filtered.append(arg.mock.syft_action_data) else: raise Exception( - f"Filter can only use {TwinMode.PRIVATE} or {TwinMode.MOCK}" + f"Filter can only use {TwinMode.PRIVATE} or {TwinMode.MOCK}", ) else: filtered.append(arg.syft_action_data) @@ -1259,7 +1254,7 @@ def filter_twin_args(args: list[Any], twin_mode: TwinMode) -> Any: def filter_twin_kwargs( - kwargs: dict, twin_mode: TwinMode, allow_python_types: bool = False + kwargs: dict, twin_mode: TwinMode, allow_python_types: bool = False, ) -> Any: filtered = {} for k, v in kwargs.items(): @@ -1270,21 +1265,20 @@ def filter_twin_kwargs( filtered[k] = v.mock.syft_action_data else: raise Exception( - f"Filter can only use {TwinMode.PRIVATE} or {TwinMode.MOCK}" + f"Filter can only use {TwinMode.PRIVATE} or {TwinMode.MOCK}", ) + elif isinstance(v, ActionObject): + filtered[k] = v.syft_action_data + elif ( + isinstance(v, str | int | float | dict | CustomEndpointActionObject) + and allow_python_types + ): + filtered[k] = v else: - if isinstance(v, ActionObject): - filtered[k] = v.syft_action_data - elif ( - isinstance(v, str | int | float | dict | CustomEndpointActionObject) - and 
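# Sketch of the twin-filtering rule in filter_twin_args / filter_twin_kwargs above:
# a twin carries a private and a mock payload, filtering selects one side, and any
# mode other than PRIVATE/MOCK is rejected. Types are simplified stand-ins.
from enum import Enum

class Mode(Enum):
    PRIVATE = 1
    MOCK = 2

def pick(twin: dict, mode: Mode) -> object:
    if mode is Mode.PRIVATE:
        return twin["private"]
    if mode is Mode.MOCK:
        return twin["mock"]
    raise Exception(f"Filter can only use {Mode.PRIVATE} or {Mode.MOCK}")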
allow_python_types - ): - filtered[k] = v - else: - # third party - raise ValueError( - f"unexepected value {v} passed to filtered twin kwargs" - ) + # third party + raise ValueError( + f"unexpected value {v} passed to filtered twin kwargs", + ) return filtered diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 250b3c5e9b5..8c740598c6e 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -5,31 +5,27 @@ import threading # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable -from ...server.credentials import SyftSigningKey -from ...server.credentials import SyftVerifyKey +from ...server.credentials import SyftSigningKey, SyftVerifyKey from ...store.dict_document_store import DictStoreConfig -from ...store.document_store import BasePartitionSettings -from ...store.document_store import DocumentStore -from ...store.document_store import StoreConfig +from ...store.document_store import BasePartitionSettings, DocumentStore, StoreConfig from ...types.syft_object import SyftObject from ...types.twin_object import TwinObject -from ...types.uid import LineageID -from ...types.uid import UID +from ...types.uid import UID, LineageID from ..response import SyftSuccess from .action_object import is_action_data_empty -from .action_permissions import ActionObjectEXECUTE -from .action_permissions import ActionObjectOWNER -from .action_permissions import ActionObjectPermission -from .action_permissions import ActionObjectREAD -from .action_permissions import ActionObjectWRITE -from .action_permissions import ActionPermission -from .action_permissions import StoragePermission +from .action_permissions import ( + ActionObjectEXECUTE, + ActionObjectOWNER, + ActionObjectPermission, + ActionObjectREAD, + ActionObjectWRITE, + ActionPermission, + StoragePermission, +) lock = threading.RLock() @@ -42,11 +38,13 @@ class ActionStore: class KeyValueActionStore(ActionStore): """Generic Key-Value Action store. - Parameters: + Parameters + ---------- store_config: StoreConfig Backend specific configuration, including connection configuration, database name, or client class type. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions.
+ """ def __init__( @@ -60,13 +58,13 @@ def __init__( self.store_config = store_config self.settings = BasePartitionSettings(name="Action") self.data = self.store_config.backing_store( - "data", self.settings, self.store_config + "data", self.settings, self.store_config, ) self.permissions = self.store_config.backing_store( - "permissions", self.settings, self.store_config, ddtype=set + "permissions", self.settings, self.store_config, ddtype=set, ) self.storage_permissions = self.store_config.backing_store( - "storage_permissions", self.settings, self.store_config, ddtype=set + "storage_permissions", self.settings, self.store_config, ddtype=set, ) if root_verify_key is None: @@ -81,7 +79,7 @@ def __init__( self.__user_stash = UserStash(store=document_store) def get( - self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False + self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False, ) -> Result[SyftObject, str]: uid = uid.id # We only need the UID from LineageID or UID @@ -106,7 +104,7 @@ def get_mock(self, uid: UID) -> Result[SyftObject, str]: try: syft_object = self.data[uid] if isinstance(syft_object, TwinObject) and not is_action_data_empty( - syft_object.mock + syft_object.mock, ): return Ok(syft_object.mock) return Err("No mock") @@ -170,7 +168,7 @@ def set( else: # root takes owneship, but you can still write ownership_result = self.take_ownership( - uid=uid, credentials=self.root_verify_key + uid=uid, credentials=self.root_verify_key, ) can_write = True if ownership_result.is_ok() else False @@ -186,7 +184,7 @@ def set( [ ActionObjectWRITE(uid=uid, credentials=credentials), ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] + ], ) if uid not in self.storage_permissions: @@ -194,14 +192,14 @@ def set( self.storage_permissions[uid] = set() if add_storage_permission: self.add_storage_permission( - StoragePermission(uid=uid, server_uid=self.server_uid) + StoragePermission(uid=uid, server_uid=self.server_uid), ) return Ok(SyftSuccess(message=f"Set for ID: {uid}")) return Err(f"Permission: {write_permission} denied") def take_ownership( - self, uid: UID, credentials: SyftVerifyKey + self, uid: UID, credentials: SyftVerifyKey, ) -> Result[SyftSuccess, str]: uid = uid.id # We only need the UID from LineageID or UID @@ -213,7 +211,7 @@ def take_ownership( ActionObjectWRITE(uid=uid, credentials=credentials), ActionObjectREAD(uid=uid, credentials=credentials), ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] + ], ) return Ok(SyftSuccess(message=f"Ownership of ID: {uid} taken.")) return Err(f"UID: {uid} already owned.") @@ -266,13 +264,7 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: return True # 🟡 TODO 14: add ALL_READ, ALL_EXECUTE etc - if permission.permission == ActionPermission.OWNER: - pass - elif permission.permission == ActionPermission.READ: - pass - elif permission.permission == ActionPermission.WRITE: - pass - elif permission.permission == ActionPermission.EXECUTE: + if permission.permission == ActionPermission.OWNER or permission.permission == ActionPermission.READ or (permission.permission == ActionPermission.WRITE or permission.permission == ActionPermission.EXECUTE): pass return False @@ -343,7 +335,7 @@ def _all( return Ok(result) def migrate_data( - self, to_klass: SyftObject, credentials: SyftVerifyKey + self, to_klass: SyftObject, credentials: SyftVerifyKey, ) -> Result[bool, str]: has_root_permission = credentials == self.root_verify_key @@ -355,7 +347,7 @@ def migrate_data( migrated_value = 
value.migrate_to(to_klass.__version__) except Exception as e: return Err( - f"Failed to migrate data to {to_klass} {to_klass.__version__} for qk: {key}. Exception: {e}" + f"Failed to migrate data to {to_klass} {to_klass.__version__} for qk: {key}. Exception: {e}", ) result = self.set( uid=key, @@ -375,11 +367,13 @@ def migrate_data( class DictActionStore(KeyValueActionStore): """Dictionary-Based Key-Value Action store. - Parameters: + Parameters + ---------- store_config: StoreConfig Backend specific configuration, including client class type. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + """ def __init__( @@ -402,25 +396,27 @@ def __init__( class SQLiteActionStore(KeyValueActionStore): """SQLite-Based Key-Value Action store. - Parameters: + Parameters + ---------- store_config: StoreConfig SQLite specific configuration, including connection settings or client class type. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + """ - pass @serializable(canonical_name="MongoActionStore", version=1) class MongoActionStore(KeyValueActionStore): """Mongo-Based Action store. - Parameters: + Parameters + ---------- store_config: StoreConfig Mongo specific configuration. root_verify_key: Optional[SyftVerifyKey] Signature verification key, used for checking access permissions. + """ - pass diff --git a/packages/syft/src/syft/service/action/action_types.py b/packages/syft/src/syft/service/action/action_types.py index c7bd730d557..8dd6083e780 100644 --- a/packages/syft/src/syft/service/action/action_types.py +++ b/packages/syft/src/syft/service/action/action_types.py @@ -13,9 +13,11 @@ def action_type_for_type(obj_or_type: Any) -> type: """Convert standard type to Syft types - Parameters: + Parameters + ---------- obj_or_type: Union[object, type] Can be an object or a class + """ if isinstance(obj_or_type, ActionDataEmpty): obj_or_type = obj_or_type.syft_internal_type @@ -24,7 +26,7 @@ def action_type_for_type(obj_or_type: Any) -> type: if obj_or_type not in action_types: logger.debug( - f"WARNING: No Type for {obj_or_type}, returning {action_types[Any]}" + f"WARNING: No Type for {obj_or_type}, returning {action_types[Any]}", ) return action_types.get(obj_or_type, action_types[Any]) @@ -33,9 +35,11 @@ def action_type_for_type(obj_or_type: Any) -> type: def action_type_for_object(obj: Any) -> type: """Convert standard type to Syft types - Parameters: + Parameters + ---------- obj_or_type: Union[object, type] Can be an object or a class + """ _type = type(obj) diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py index 1949eeb0575..9b5d0ce80a7 100644 --- a/packages/syft/src/syft/service/action/numpy.py +++ b/packages/syft/src/syft/service/action/numpy.py @@ -1,6 +1,5 @@ # stdlib -from typing import Any -from typing import ClassVar +from typing import Any, ClassVar # third party import numpy as np @@ -9,9 +8,7 @@ # relative from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from .action_object import ActionObject -from .action_object import ActionObjectPointer -from .action_object import BASE_PASSTHROUGH_ATTRS +from .action_object import BASE_PASSTHROUGH_ATTRS, ActionObject, ActionObjectPointer from .action_types import action_types # @serializable(attrs=["id", "server_uid", "parent_id"]) @@ -62,7 +59,7 @@ class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): # 
return self == other def __array_ufunc__( - self, ufunc: Any, method: str, *inputs: Any, **kwargs: Any + self, ufunc: Any, method: str, *inputs: Any, **kwargs: Any, ) -> Self | tuple[Self, ...]: inputs = tuple( ( @@ -81,7 +78,7 @@ def __array_ufunc__( ) else: return NumpyArrayObject( - syft_action_data_cache=result, dtype=result.dtype, shape=result.shape + syft_action_data_cache=result, dtype=result.dtype, shape=result.shape, ) diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py index 9de480ddd0f..9e1c55966c4 100644 --- a/packages/syft/src/syft/service/action/pandas.py +++ b/packages/syft/src/syft/service/action/pandas.py @@ -1,16 +1,13 @@ # stdlib -from typing import Any -from typing import ClassVar +from typing import Any, ClassVar # third party -from pandas import DataFrame -from pandas import Series +from pandas import DataFrame, Series # relative from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from .action_object import ActionObject -from .action_object import BASE_PASSTHROUGH_ATTRS +from .action_object import BASE_PASSTHROUGH_ATTRS, ActionObject from .action_types import action_types diff --git a/packages/syft/src/syft/service/action/plan.py b/packages/syft/src/syft/service/action/plan.py index cca7437f869..0c1cc3f481b 100644 --- a/packages/syft/src/syft/service/action/plan.py +++ b/packages/syft/src/syft/service/action/plan.py @@ -1,17 +1,14 @@ # stdlib -from collections.abc import Callable import inspect +from collections.abc import Callable from typing import Any # relative -from ... import ActionObject -from ... import Worker +from ... import ActionObject, Worker from ...client.client import SyftClient from ...serde.recursive import recursive_serde_register -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from .action_object import Action -from .action_object import TraceResultRegistry +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject +from .action_object import Action, TraceResultRegistry class Plan(SyftObject): @@ -37,7 +34,7 @@ def __repr__(self) -> str: inp_str = "Inputs:\n" inp_str += "\n".join( - [f"\t\t{k}: {v.__class__.__name__}" for k, v in self.inputs.items()] + [f"\t\t{k}: {v.__class__.__name__}" for k, v in self.inputs.items()], ) act_str = f"Actions:\n\t\t{len(self.actions)} Actions" @@ -92,7 +89,7 @@ def planify(func: Callable) -> ActionObject: def build_plan_inputs( - forward_func: Callable, client: SyftClient + forward_func: Callable, client: SyftClient, ) -> dict[str, ActionObject]: signature = inspect.signature(forward_func) res = {} @@ -104,7 +101,7 @@ def build_plan_inputs( res[k] = default_value.send(client) else: raise ValueError( - f"arg {k} has no placeholder as default value (required for @make_plan functions)" + f"arg {k} has no placeholder as default value (required for @make_plan functions)", ) return res diff --git a/packages/syft/src/syft/service/action/verification.py b/packages/syft/src/syft/service/action/verification.py index 063634e993c..03b7d578baf 100644 --- a/packages/syft/src/syft/service/action/verification.py +++ b/packages/syft/src/syft/service/action/verification.py @@ -7,9 +7,7 @@ import pandas as pd # relative -from ..response import SyftError -from ..response import SyftResponseMessage -from ..response import SyftSuccess +from ..response import SyftError, SyftResponseMessage, SyftSuccess from .action_object import ActionObject @@ -29,15 +27,15 @@ def 
verify_result( # Manual type casting for now, to automate later if isinstance(asset.syft_action_data, np.ndarray): trace_assets.append( - ActionObject(id=asset.id, syft_result_obj=np.ndarray([])) + ActionObject(id=asset.id, syft_result_obj=np.ndarray([])), ) elif isinstance(asset.syft_action_data, pd.DataFrame): trace_assets.append( - ActionObject(id=asset.id, syft_result_obj=pd.DataFrame()) + ActionObject(id=asset.id, syft_result_obj=pd.DataFrame()), ) else: raise NotImplementedError( - f"Trace mode not yet automated for type: {type(asset.syft_action_data)}" + f"Trace mode not yet automated for type: {type(asset.syft_action_data)}", ) print("Code Verification in progress.") @@ -86,20 +84,20 @@ def wrapper(*args: Any, **kwargs: Any) -> SyftSuccess | SyftError: for asset in args: if not isinstance(asset, ActionObject): raise Exception( - f"ActionObject expected, instead received: {type(asset)}" + f"ActionObject expected, instead received: {type(asset)}", ) # Manual type casting for now, to automate later if isinstance(asset.syft_action_data, np.ndarray): trace_assets.append( - ActionObject(id=asset.id, syft_result_obj=np.ndarray([])) + ActionObject(id=asset.id, syft_result_obj=np.ndarray([])), ) elif isinstance(asset.syft_action_data, pd.DataFrame): trace_assets.append( - ActionObject(id=asset.id, syft_result_obj=pd.DataFrame()) + ActionObject(id=asset.id, syft_result_obj=pd.DataFrame()), ) else: raise NotImplementedError( - f"Trace mode not yet automated for type: {type(asset.syft_action_data)}" + f"Trace mode not yet automated for type: {type(asset.syft_action_data)}", ) print("Evaluating function normally to obtain history hash") diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 89e19146e99..6762cacae19 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -1,37 +1,32 @@ # stdlib import ast -from collections.abc import Callable import inspect -from inspect import Signature import keyword import linecache import re import textwrap -from typing import Any -from typing import cast +from collections.abc import Callable +from inspect import Signature +from typing import Any, cast # third party -from pydantic import ValidationError -from pydantic import field_validator -from pydantic import model_validator -from result import Err -from result import Ok -from result import Result +from pydantic import ValidationError, field_validator, model_validator +from result import Err, Ok, Result # relative from ...abstract_server import AbstractServer from ...client.client import SyftClient from ...serde.serializable import serializable from ...serde.signature import signature_remove_context -from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, PartialSyftObject, SyftObject from ...types.syncable_object import SyncableSyftObject -from ...types.transforms import TransformContext -from ...types.transforms import generate_action_object_id -from ...types.transforms import generate_id -from ...types.transforms import keep -from ...types.transforms import transform +from ...types.transforms import ( + TransformContext, + generate_action_object_id, + generate_id, + keep, + transform, +) from ...types.uid import UID from ...util.misc_objs import MarkdownDescription from ..context import AuthedServiceContext @@ -79,8 +74,9 @@ def get_signature(func: 
Callable) -> Signature: def register_fn_in_linecache(fname: str, src: str) -> None: - """adds a function to linecache, such that inspect.getsource works for functions nested in this function. - This only works if the same function is compiled under the same filename""" + """Adds a function to linecache, such that inspect.getsource works for functions nested in this function. + This only works if the same function is compiled under the same filename + """ lines = [ line + "\n" for line in src.splitlines() ] # use same splitting method same as linecache 112 (py3.12) @@ -197,7 +193,7 @@ def validate_func_name(cls, func_name: str) -> str: @field_validator("settings", check_fields=False) @classmethod def validate_settings( - cls, settings: dict[str, Any] | None + cls, settings: dict[str, Any] | None, ) -> dict[str, Any] | None: return settings @@ -247,12 +243,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: from ..context import AuthedServiceContext mock_context = AuthedServiceContext( - server=AbstractServer(), credentials=SyftSigningKey.generate().verify_key + server=AbstractServer(), credentials=SyftSigningKey.generate().verify_key, ) return self.call_locally(mock_context, *args, **kwargs) def call_locally( - self, context: AuthedServiceContext, *args: Any, **kwargs: Any + self, context: AuthedServiceContext, *args: Any, **kwargs: Any, ) -> Any: inner_function = ast.parse(self.api_code).body[0] inner_function.decorator_list = [] @@ -304,7 +300,7 @@ def validate_signature(cls, data: dict[str, Any]) -> dict[str, Any]: # Add none check if private_function and private_function.signature != mock_function.signature: raise ValueError( - "Mock and Private API Endpoints must have the same signature." + "Mock and Private API Endpoints must have the same signature.", ) return data @@ -328,7 +324,7 @@ def validate_path(cls, path: str) -> str: @field_validator("private_function", check_fields=False) @classmethod def validate_private_function( - cls, private_function: PrivateAPIEndpoint | None + cls, private_function: PrivateAPIEndpoint | None, ) -> PrivateAPIEndpoint | None: # TODO: what kind of validation should we do here? @@ -337,7 +333,7 @@ def validate_private_function( @field_validator("mock_function", check_fields=False) @classmethod def validate_mock_function( - cls, mock_function: PublicAPIEndpoint + cls, mock_function: PublicAPIEndpoint, ) -> PublicAPIEndpoint: # TODO: what kind of validation should we do here? return mock_function @@ -371,7 +367,7 @@ class CreateTwinAPIEndpoint(BaseTwinAPIEndpoint): endpoint_timeout: int = 60 def __init__( - self, description: str | MarkdownDescription | None = "", **kwargs: Any + self, description: str | MarkdownDescription | None = "", **kwargs: Any, ) -> None: if isinstance(description, str): description = MarkdownDescription(text=description) @@ -418,9 +414,13 @@ def has_permission(self, context: AuthedServiceContext) -> bool: """Check if the user has permission to access the endpoint. Args: + ---- context: The context of the user requesting the code. + Returns: + ------- bool: True if the user has permission to access the endpoint, False otherwise. + """ if context.role.value == 128: return True @@ -430,9 +430,13 @@ def select_code(self, context: AuthedServiceContext) -> Result[Ok, Err]: """Select the code to execute based on the user's permissions and public code availability. Args: + ---- context: The context of the user requesting the code. + Returns: + ------- Result[Ok, Err]: The selected code to execute. 
+
         """
         if self.has_permission(context) and self.private_function:
             return Ok(self.private_function)
@@ -442,11 +446,13 @@ def exec(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> Any:
         """Execute the code based on the user's permissions and public code availability.

         Args:
+        ----
         context: The context of the user requesting the code.
         *args: Any
         **kwargs: Any
         Returns:
         Any: The result of the executed code.
+
         """
         result = self.select_code(context)
         if result.is_err():
@@ -456,7 +462,7 @@ def exec(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> Any:
         return self.exec_code(selected_code, context, *args, **kwargs)

     def exec_mock_function(
-        self, context: AuthedServiceContext, *args: Any, **kwargs: Any
+        self, context: AuthedServiceContext, *args: Any, **kwargs: Any,
     ) -> Any:
         """Execute the public code if it exists."""
         if self.mock_function:
@@ -465,16 +471,18 @@ def exec_mock_function(
         return SyftError(message="No public code available")

     def exec_private_function(
-        self, context: AuthedServiceContext, *args: Any, **kwargs: Any
+        self, context: AuthedServiceContext, *args: Any, **kwargs: Any,
     ) -> Any:
         """Execute the private code if user is has the proper permissions.

         Args:
+        ----
         context: The context of the user requesting the code.
         *args: Any
         **kwargs: Any
         Returns:
         Any: The result of the executed code.
+
         """
         if self.private_function is None:
             return SyftError(message="No private code available")
@@ -489,10 +497,10 @@ def get_user_client_from_server(self, context: AuthedServiceContext) -> SyftClie
         guest_client = context.server.get_guest_client()
         user_client = guest_client
         signing_key_for_verify_key = context.server.get_service_method(
-            UserService.signing_key_for_verify_key
+            UserService.signing_key_for_verify_key,
         )
         private_key = signing_key_for_verify_key(
-            context=context, verify_key=context.credentials
+            context=context, verify_key=context.credentials,
         )
         signing_key = private_key.signing_key
         user_client.credentials = signing_key
@@ -524,7 +532,7 @@ def exec_code(
             exec(raw_byte_code)  # nosec

             internal_context = code.build_internal_context(
-                context=context, admin_client=admin_client, user_client=user_client
+                context=context, admin_client=admin_client, user_client=user_client,
             )

             # execute it
@@ -541,7 +549,7 @@ def exec_code(
             api_service = context.server.get_service("apiservice")

             upsert_result = api_service.stash.upsert(
-                context.server.get_service("userservice").admin_verify_key(), self
+                context.server.get_service("userservice").admin_verify_key(), self,
             )

             if upsert_result.is_err():
@@ -554,11 +562,11 @@ def exec_code(
             # TODO: cleanup typeerrors
             if context.role.value == 128 or isinstance(e, TypeError):
                 return SyftError(
-                    message=f"An error was raised during the execution of the API endpoint call: \n {str(e)}"
+                    message=f"An error was raised during the execution of the API endpoint call: \n {e!s}",
                 )
             else:
                 return SyftError(
-                    message="Ops something went wrong during this endpoint execution, please contact your admin."
+                    message="Oops, something went wrong during this endpoint execution, please contact your admin.",
                 )


@@ -576,7 +584,7 @@ def check_and_cleanup_signature(context: TransformContext) -> TransformContext:
     params = dict(context.obj.signature.parameters)
     if "context" not in params:
         raise ValueError(
-            "Function Signature must include 'context' [AuthedContext] parameters."
+ "Function Signature must include 'context' [AuthedContext] parameters.", ) params.pop("context", None) context.output["signature"] = Signature( @@ -659,8 +667,8 @@ def endpoint_to_private_endpoint() -> list[Callable]: "helper_functions", "state", "signature", - ] - ) + ], + ), ] @@ -676,8 +684,8 @@ def endpoint_to_public_endpoint() -> list[Callable]: "helper_functions", "state", "signature", - ] - ) + ], + ), ] diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 87051d9b4b4..71b6313ce17 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -1,12 +1,10 @@ # stdlib import time -from typing import Any -from typing import cast +from typing import Any, cast # third party from pydantic import ValidationError -from result import Err -from result import Ok +from result import Err, Ok # relative from ...serde.serializable import serializable @@ -17,22 +15,23 @@ from ...util.telemetry import instrument from ..action.action_service import ActionService from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from .api import CreateTwinAPIEndpoint -from .api import Endpoint -from .api import PrivateAPIEndpoint -from .api import PublicAPIEndpoint -from .api import TwinAPIContextView -from .api import TwinAPIEndpoint -from .api import TwinAPIEndpointView -from .api import UpdateTwinAPIEndpoint +from ..response import SyftError, SyftSuccess +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + ADMIN_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, +) +from .api import ( + CreateTwinAPIEndpoint, + Endpoint, + PrivateAPIEndpoint, + PublicAPIEndpoint, + TwinAPIContextView, + TwinAPIEndpoint, + TwinAPIEndpointView, + UpdateTwinAPIEndpoint, +) from .api_stash import TwinAPIEndpointStash @@ -71,7 +70,7 @@ def set( if isinstance(endpoint, CreateTwinAPIEndpoint): endpoint_exists = self.stash.path_exists( - context.credentials, new_endpoint.path + context.credentials, new_endpoint.path, ) if endpoint_exists.is_err(): @@ -79,7 +78,7 @@ def set( if endpoint_exists.is_ok() and endpoint_exists.ok(): return SyftError( - message="An API endpoint already exists at the given path." + message="An API endpoint already exists at the given path.", ) result = self.stash.upsert(context.credentials, endpoint=new_endpoint) @@ -119,7 +118,6 @@ def update( endpoint_timeout: int | None = None, ) -> SyftSuccess | SyftError: """Updates an specific API endpoint.""" - # if any of these are supplied e.g. truthy then keep going otherwise return # an error # TODO: change to an Update object with autosplat @@ -131,7 +129,7 @@ def update( ): return SyftError( message='At least one of "mock_function", "private_function", ' - '"hide_mock_definition" or "endpoint_timeout" is required.' 
+ '"hide_mock_definition" or "endpoint_timeout" is required.', ) endpoint_result = self.stash.get_by_path(context.credentials, endpoint_path) @@ -193,10 +191,9 @@ def update( roles=ADMIN_ROLE_LEVEL, ) def delete( - self, context: AuthedServiceContext, endpoint_path: str + self, context: AuthedServiceContext, endpoint_path: str, ) -> SyftSuccess | SyftError: """Deletes an specific API endpoint.""" - result = self.stash.get_by_path(context.credentials, endpoint_path) if result.is_err(): @@ -219,7 +216,7 @@ def delete( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def view( - self, context: AuthedServiceContext, path: str + self, context: AuthedServiceContext, path: str, ) -> TwinAPIEndpointView | SyftError: """Retrieves an specific API endpoint.""" result = self.stash.get_by_path(context.server.verify_key, path) @@ -235,7 +232,7 @@ def view( roles=ADMIN_ROLE_LEVEL, ) def get( - self, context: AuthedServiceContext, api_path: str + self, context: AuthedServiceContext, api_path: str, ) -> TwinAPIEndpoint | SyftError: """Retrieves an specific API endpoint.""" result = self.stash.get_by_path(context.server.verify_key, api_path) @@ -335,7 +332,7 @@ def api_endpoints( return api_endpoint_view @service_method( - path="api.call_in_jobs", name="call_in_jobs", roles=GUEST_ROLE_LEVEL + path="api.call_in_jobs", name="call_in_jobs", roles=GUEST_ROLE_LEVEL, ) def call_in_jobs( self, @@ -426,7 +423,7 @@ def _call_in_jobs( message=( f"Function timed out in {custom_endpoint.endpoint_timeout} seconds. " + f"Get the Job with id: {job_id} to check results." - ) + ), ) if job.status == JobStatus.COMPLETED: @@ -435,10 +432,10 @@ def _call_in_jobs( return SyftError(message="Function failed to complete.") @service_method( - path="api.get_public_context", name="get_public_context", roles=ADMIN_ROLE_LEVEL + path="api.get_public_context", name="get_public_context", roles=ADMIN_ROLE_LEVEL, ) def get_public_context( - self, context: AuthedServiceContext, path: str + self, context: AuthedServiceContext, path: str, ) -> dict[str, Any] | SyftError: """Get specific public api context.""" custom_endpoint = self.get_code( @@ -449,7 +446,7 @@ def get_public_context( return custom_endpoint return custom_endpoint.mock_function.build_internal_context(context=context).to( - TwinAPIContextView + TwinAPIContextView, ) @service_method( @@ -458,7 +455,7 @@ def get_public_context( roles=ADMIN_ROLE_LEVEL, ) def get_private_context( - self, context: AuthedServiceContext, path: str + self, context: AuthedServiceContext, path: str, ) -> dict[str, Any] | SyftError: """Get specific private api context.""" custom_endpoint = self.get_code( @@ -469,11 +466,11 @@ def get_private_context( return custom_endpoint custom_endpoint.private_function = cast( - PrivateAPIEndpoint, custom_endpoint.private_function + PrivateAPIEndpoint, custom_endpoint.private_function, ) return custom_endpoint.private_function.build_internal_context( - context=context + context=context, ).to(TwinAPIContextView) @service_method(path="api.get_all", name="get_all", roles=ADMIN_ROLE_LEVEL) @@ -518,7 +515,7 @@ def call( ) if result.is_err(): return SyftError( - message=f"Failed to set result to store: {result.err()}" + message=f"Failed to set result to store: {result.err()}", ) return Ok(result.ok()) @@ -558,7 +555,7 @@ def call_public( ) if result.is_err(): return SyftError( - message=f"Failed to set result to store: {result.err()}" + message=f"Failed to set result to store: {result.err()}", ) return Ok(result.ok()) @@ -569,7 +566,7 @@ def call_public( return Err(value=f"Failed to run. 
{e}, {traceback.format_exc()}") @service_method( - path="api.call_private", name="call_private", roles=GUEST_ROLE_LEVEL + path="api.call_private", name="call_private", roles=GUEST_ROLE_LEVEL, ) def call_private( self, @@ -596,11 +593,11 @@ def call_private( action_service = cast(ActionService, context.server.get_service(ActionService)) try: result = action_service.set_result_to_store( - context=context, result_action_object=action_obj + context=context, result_action_object=action_obj, ) if result.is_err(): return SyftError( - message=f"Failed to set result to store: {result.err()}" + message=f"Failed to set result to store: {result.err()}", ) return Ok(result.ok()) @@ -616,7 +613,7 @@ def call_private( name="exists", ) def exists( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: """Check if an endpoint exists""" endpoint = self.get_endpoint_by_uid(context, uid) @@ -672,7 +669,7 @@ def execute_server_side_endpoint_mock_by_id( return endpoint.exec_code(endpoint.mock_function, context, *args, **kwargs) def get_endpoint_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> TwinAPIEndpoint | SyftError: admin_key = context.server.get_service("userservice").admin_verify_key() result = self.stash.get_by_uid(admin_key, uid) @@ -681,7 +678,7 @@ def get_endpoint_by_uid( return result.ok() def get_endpoints( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[TwinAPIEndpoint] | SyftError: # TODO: Add ability to specify which roles see which endpoints # for now skip auth @@ -691,12 +688,12 @@ def get_endpoints( return SyftError(messages="Unable to get CustomAPIEndpoint") def get_code( - self, context: AuthedServiceContext, endpoint_path: str + self, context: AuthedServiceContext, endpoint_path: str, ) -> TwinAPIEndpoint | SyftError: result = self.stash.get_by_path(context.server.verify_key, path=endpoint_path) if result.is_err(): return SyftError( - message=f"CustomAPIEndpoint: {endpoint_path} does not exist" + message=f"CustomAPIEndpoint: {endpoint_path} does not exist", ) if result.is_ok(): diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 3b6daac1422..7dc26c32901 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -1,16 +1,12 @@ # stdlib # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings +from ...store.document_store import BaseUIDStoreStash, DocumentStore, PartitionSettings from .api import TwinAPIEndpoint MISSING_PATH_STRING = "Endpoint path: {path} does not exist." 
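The hunks above show the two rewrites that dominate this patch: repeated single-name imports are merged into one statement (isort-style sorting, likely ruff rule I001), and a trailing comma is appended to any call or signature that is already split across multiple lines (flake8-commas, likely rule COM812). The ruff configuration itself is not part of the patch, so those rule names are an inference. A minimal before/after sketch, using a hypothetical module rather than Syft code:

# Hypothetical module for illustration only; not part of the Syft codebase.
# It shows the shape of the two rewrites applied throughout this patch.

# Before ruff (one import per name):
#     from result import Err
#     from result import Ok
#     from result import Result
# After ruff (merged into a single statement):
from result import Err, Ok, Result


def check_threshold(
    value: int,
    threshold: int = 0,  # trailing comma kept because the signature spans lines
) -> Result[int, str]:
    """Return Ok(value) when value exceeds threshold, otherwise an Err."""
    if value > threshold:
        return Ok(value)
    return Err(
        f"value {value} is not above {threshold}",  # trailing comma on the split call
    )


print(check_threshold(3).unwrap())  # prints 3

The trailing comma keeps later one-line additions to such argument lists out of unrelated diff hunks, which is the usual rationale for enforcing it mechanically.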
@@ -20,14 +16,14 @@ class TwinAPIEndpointStash(BaseUIDStoreStash): object_type = TwinAPIEndpoint settings: PartitionSettings = PartitionSettings( - name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint + name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_by_path( - self, credentials: SyftVerifyKey, path: str + self, credentials: SyftVerifyKey, path: str, ) -> Result[TwinAPIEndpoint, str]: endpoint_results = self.get_all(credentials=credentials) if endpoint_results.is_err(): @@ -69,6 +65,6 @@ def upsert( super().delete_by_uid(credentials=credentials, uid=endpoint.id) result = super().set( - credentials=credentials, obj=endpoint, ignore_duplicates=False + credentials=credentials, obj=endpoint, ignore_duplicates=False, ) return result diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index 87289278bf3..f2ff094b39c 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -9,14 +9,14 @@ from ...store.document_store import DocumentStore from ...util.util import str_to_bool from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import service_method +from ..response import SyftError, SyftSuccess +from ..service import AbstractService, service_method from ..user.user_roles import GUEST_ROLE_LEVEL -from .attestation_constants import ATTESTATION_SERVICE_URL -from .attestation_constants import ATTEST_CPU_ENDPOINT -from .attestation_constants import ATTEST_GPU_ENDPOINT +from .attestation_constants import ( + ATTEST_CPU_ENDPOINT, + ATTEST_GPU_ENDPOINT, + ATTESTATION_SERVICE_URL, +) @serializable(canonical_name="AttestationService", version=1) @@ -27,7 +27,7 @@ def __init__(self, store: DocumentStore) -> None: self.store = store def perform_request( - self, method: Callable, endpoint: str, raw: bool = False + self, method: Callable, endpoint: str, raw: bool = False, ) -> SyftSuccess | SyftError | str: try: response = method(f"{ATTESTATION_SERVICE_URL}{endpoint}") @@ -51,7 +51,7 @@ def perform_request( roles=GUEST_ROLE_LEVEL, ) def get_cpu_attestation( - self, context: AuthedServiceContext, raw_token: bool = False + self, context: AuthedServiceContext, raw_token: bool = False, ) -> str | SyftError | SyftSuccess: return self.perform_request(requests.get, ATTEST_CPU_ENDPOINT, raw_token) @@ -61,6 +61,6 @@ def get_cpu_attestation( roles=GUEST_ROLE_LEVEL, ) def get_gpu_attestation( - self, context: AuthedServiceContext, raw_token: bool = False + self, context: AuthedServiceContext, raw_token: bool = False, ) -> str | SyftError | SyftSuccess: return self.perform_request(requests.get, ATTEST_GPU_ENDPOINT, raw_token) diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py index d3e275625ae..92361f15d82 100644 --- a/packages/syft/src/syft/service/blob_storage/remote_profile.py +++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py @@ -1,10 +1,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...types.syft_object import 
SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...store.document_store import BaseUIDStoreStash, DocumentStore, PartitionSettings +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject @serializable() @@ -28,7 +25,7 @@ class AzureRemoteProfile(RemoteProfile): class RemoteProfileStash(BaseUIDStoreStash): object_type = RemoteProfile settings: PartitionSettings = PartitionSettings( - name=RemoteProfile.__canonical_name__, object_type=RemoteProfile + name=RemoteProfile.__canonical_name__, object_type=RemoteProfile, ) def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 3254e2918da..93f0e390fb4 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,25 +11,21 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.document_store import DocumentStore -from ...store.document_store import UIDPartitionKey -from ...types.blob_storage import AzureSecureFilePathLocation -from ...types.blob_storage import BlobFileType -from ...types.blob_storage import BlobStorageEntry -from ...types.blob_storage import BlobStorageMetadata -from ...types.blob_storage import CreateBlobStorageEntry -from ...types.blob_storage import SeaweedSecureFilePathLocation +from ...store.document_store import DocumentStore, UIDPartitionKey +from ...types.blob_storage import ( + AzureSecureFilePathLocation, + BlobFileType, + BlobStorageEntry, + BlobStorageMetadata, + CreateBlobStorageEntry, + SeaweedSecureFilePathLocation, +) from ...types.uid import UID from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from .remote_profile import AzureRemoteProfile -from .remote_profile import RemoteProfileStash +from ..response import SyftError, SyftSuccess +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL, GUEST_ROLE_LEVEL +from .remote_profile import AzureRemoteProfile, RemoteProfileStash from .stash import BlobStorageStash BlobDepositType = OnDiskBlobDeposit | SeaweedFSBlobDeposit @@ -48,7 +44,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="blob_storage.get_all", name="get_all") def get_all_blob_storage_entries( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[BlobStorageEntry] | SyftError: result = self.stash.get_all(context.credentials) if result.is_ok(): @@ -99,7 +95,7 @@ def mount_azure( print(init_request.content) # TODO check return code res = context.server.blob_storage_client.connect().client.list_objects( - Bucket=bucket_name + Bucket=bucket_name, ) # stdlib objects = res["Contents"] @@ -135,10 +131,10 @@ def mount_azure( return SyftSuccess(message="Mounting Azure Successful!") @service_method( - path="blob_storage.get_files_from_bucket", name="get_files_from_bucket" + path="blob_storage.get_files_from_bucket", name="get_files_from_bucket", ) def get_files_from_bucket( - self, context: AuthedServiceContext, bucket_name: str + self, context: 
AuthedServiceContext, bucket_name: str, ) -> list | SyftError: result = self.stash.find_all(context.credentials, bucket_name=bucket_name) if result.is_err(): @@ -168,7 +164,7 @@ def get_files_from_bucket( @service_method(path="blob_storage.get_by_uid", name="get_by_uid") def get_blob_storage_entry_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> BlobStorageEntry | SyftError: result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): @@ -177,7 +173,7 @@ def get_blob_storage_entry_by_uid( @service_method(path="blob_storage.get_metadata", name="get_metadata") def get_blob_storage_metadata_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> BlobStorageEntry | SyftError: result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): @@ -192,19 +188,19 @@ def get_blob_storage_metadata_by_uid( roles=GUEST_ROLE_LEVEL, ) def read( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> BlobRetrieval | SyftError: result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): obj: BlobStorageEntry | None = result.ok() if obj is None: return SyftError( - message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it", ) with context.server.blob_storage_client.connect() as conn: res: BlobRetrieval = conn.read( - obj.location, obj.type_, bucket_name=obj.bucket_name + obj.location, obj.type_, bucket_name=obj.bucket_name, ) res.syft_blob_storage_entry_id = uid res.file_size = obj.file_size @@ -217,12 +213,12 @@ def _allocate( obj: CreateBlobStorageEntry, uploaded_by: SyftVerifyKey | None = None, ) -> BlobDepositType | SyftError: - """ - Allocate a secure location for the blob storage entry. + """Allocate a secure location for the blob storage entry. If uploaded_by is None, the credentials of the context will be used. Args: + ---- context (AuthedServiceContext): context obj (CreateBlobStorageEntry): create blob parameters uploaded_by (SyftVerifyKey | None, optional): Uploader credentials. @@ -230,7 +226,9 @@ def _allocate( Defaults to None. 
Returns: + ------- BlobDepositType | SyftError: Blob deposit + """ upload_credentials = uploaded_by or context.credentials @@ -264,7 +262,7 @@ def _allocate( roles=GUEST_ROLE_LEVEL, ) def allocate( - self, context: AuthedServiceContext, obj: CreateBlobStorageEntry + self, context: AuthedServiceContext, obj: CreateBlobStorageEntry, ) -> BlobDepositType | SyftError: return self._allocate(context, obj) @@ -287,7 +285,7 @@ def allocate_for_user( roles=GUEST_ROLE_LEVEL, ) def write_to_disk( - self, context: AuthedServiceContext, uid: UID, data: bytes + self, context: AuthedServiceContext, uid: UID, data: bytes, ) -> SyftSuccess | SyftError: result = self.stash.get_by_uid( credentials=context.credentials, @@ -300,7 +298,7 @@ def write_to_disk( if obj is None: return SyftError( - message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it", ) try: @@ -332,7 +330,7 @@ def mark_write_complete( if obj is None: return SyftError( - message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it", ) obj.no_lines = no_lines @@ -350,7 +348,7 @@ def mark_write_complete( @service_method(path="blob_storage.delete", name="delete") def delete( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: get_res = self.stash.get_by_uid(context.credentials, uid=uid) if get_res.is_err(): @@ -360,7 +358,7 @@ def delete( if obj is None: return SyftError( message=f"No blob storage entry exists for uid: {uid}, " - f"or you have no permissions to read it" + f"or you have no permissions to read it", ) try: @@ -370,17 +368,17 @@ def delete( return file_unlinked_result except Exception as e: return SyftError( - message=f"Failed to delete blob file with id '{uid}'. Error: {e}" + message=f"Failed to delete blob file with id '{uid}'. Error: {e}", ) blob_entry_delete_res = self.stash.delete( - context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True + context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True, ) if blob_entry_delete_res.is_err(): return SyftError(message=blob_entry_delete_res.err()) return SyftSuccess( - message=f"Blob storage entry with id '{uid}' deleted successfully." 
+ message=f"Blob storage entry with id '{uid}' deleted successfully.", ) diff --git a/packages/syft/src/syft/service/blob_storage/stash.py b/packages/syft/src/syft/service/blob_storage/stash.py index 8fc93a4f034..7cd419a1493 100644 --- a/packages/syft/src/syft/service/blob_storage/stash.py +++ b/packages/syft/src/syft/service/blob_storage/stash.py @@ -1,8 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings +from ...store.document_store import BaseUIDStoreStash, DocumentStore, PartitionSettings from ...types.blob_storage import BlobStorageEntry @@ -10,7 +8,7 @@ class BlobStorageStash(BaseUIDStoreStash): object_type = BlobStorageEntry settings: PartitionSettings = PartitionSettings( - name=BlobStorageEntry.__canonical_name__, object_type=BlobStorageEntry + name=BlobStorageEntry.__canonical_name__, object_type=BlobStorageEntry, ) def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/blob_storage/util.py b/packages/syft/src/syft/service/blob_storage/util.py index df795c86b87..f2fc56ad10a 100644 --- a/packages/syft/src/syft/service/blob_storage/util.py +++ b/packages/syft/src/syft/service/blob_storage/util.py @@ -3,8 +3,7 @@ # relative from ...util.util import get_mb_serialized_size -from ..metadata.server_metadata import ServerMetadata -from ..metadata.server_metadata import ServerMetadataJSON +from ..metadata.server_metadata import ServerMetadata, ServerMetadataJSON def min_size_for_blob_storage_upload( @@ -14,6 +13,6 @@ def min_size_for_blob_storage_upload( def can_upload_to_blob_storage( - data: Any, metadata: ServerMetadata | ServerMetadataJSON + data: Any, metadata: ServerMetadata | ServerMetadataJSON, ) -> bool: return get_mb_serialized_size(data) >= min_size_for_blob_storage_upload(metadata) diff --git a/packages/syft/src/syft/service/code/code_parse.py b/packages/syft/src/syft/service/code/code_parse.py index 4cde893520d..3753d99a6b3 100644 --- a/packages/syft/src/syft/service/code/code_parse.py +++ b/packages/syft/src/syft/service/code/code_parse.py @@ -1,6 +1,6 @@ # stdlib -from _ast import Module import ast +from _ast import Module from typing import Any diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 2baafb2ea57..4959c127e50 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -6,21 +6,19 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionSettings, + QueryKeys, + UIDPartitionKey, +) from ...types.uid import UID from ...util.telemetry import instrument from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL +from ..response import SyftError, SyftSuccess 
+from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL, GUEST_ROLE_LEVEL from .user_code import UserCodeStatusCollection @@ -40,7 +38,7 @@ def __init__(self, store: DocumentStore) -> None: self._object_type = self.object_type def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[UserCodeStatusCollection, str]: qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) return self.query_one(credentials=credentials, qks=qks) @@ -71,10 +69,10 @@ def create( return SyftError(message=result.err()) @service_method( - path="code_status.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL + path="code_status.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL, ) def get_status( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> UserCodeStatusCollection | SyftError: """Get the status of a user code item""" result = self.stash.get_by_uid(context.credentials, uid=uid) @@ -84,7 +82,7 @@ def get_status( @service_method(path="code_status.get_all", name="get_all", roles=ADMIN_ROLE_LEVEL) def get_all( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[UserCodeStatusCollection] | SyftError: """Get all user code item statuses""" result = self.stash.get_all(context.credentials) @@ -94,7 +92,7 @@ def get_all( @service_method(path="code_status.remove", name="remove", roles=ADMIN_ROLE_LEVEL) def remove( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: """Remove a user code item status""" result = self.stash.delete_by_uid(context.credentials, uid=uid) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 1d4615e2e54..5ac08bbecb4 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -3,108 +3,94 @@ # stdlib import ast -from collections.abc import Callable -from copy import deepcopy import datetime -from enum import Enum import hashlib import inspect -from io import StringIO import json import keyword import random import re import sys -from textwrap import dedent -from threading import Thread import time import traceback -from typing import Any -from typing import ClassVar -from typing import TYPE_CHECKING -from typing import cast -from typing import final +from collections.abc import Callable +from copy import deepcopy +from enum import Enum +from io import StringIO +from textwrap import dedent +from threading import Thread +from typing import TYPE_CHECKING, Any, ClassVar, cast, final # third party -from IPython.display import HTML -from IPython.display import Markdown -from IPython.display import display -from pydantic import ValidationError -from pydantic import field_validator -from result import Err -from result import Ok -from result import Result +from IPython.display import HTML, Markdown, display +from pydantic import ValidationError, field_validator +from result import Err, Ok, Result from typing_extensions import Self # relative -from ...abstract_server import ServerSideType -from ...abstract_server import ServerType -from ...client.api import APIRegistry -from ...client.api import ServerIdentity -from ...client.api import generate_remote_function +from ...abstract_server import ServerSideType, ServerType +from ...client.api import APIRegistry, ServerIdentity, generate_remote_function 
from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize -from ...serde.signature import signature_remove_context -from ...serde.signature import signature_remove_self +from ...serde.signature import signature_remove_context, signature_remove_self from ...server.credentials import SyftVerifyKey from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.dicttuple import DictTuple -from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, PartialSyftObject, SyftObject from ...types.syncable_object import SyncableSyftObject -from ...types.transforms import TransformContext -from ...types.transforms import add_server_uid_for_key -from ...types.transforms import generate_id -from ...types.transforms import transform +from ...types.transforms import ( + TransformContext, + add_server_uid_for_key, + generate_id, + transform, +) from ...types.uid import UID from ...util import options from ...util.colors import SURFACE from ...util.decorators import deprecated -from ...util.markdown import CodeMarkdown -from ...util.markdown import as_markdown_code +from ...util.markdown import CodeMarkdown, as_markdown_code from ...util.notebook_ui.styles import FONT_CSS from ...util.util import prompt_warning_message from ..action.action_endpoint import CustomEndpointActionObject -from ..action.action_object import Action -from ..action.action_object import ActionObject +from ..action.action_object import Action, ActionObject from ..context import AuthedServiceContext from ..dataset.dataset import Asset from ..job.job_stash import Job -from ..output.output_service import ExecutionOutput -from ..output.output_service import OutputService -from ..policy.policy import Constant -from ..policy.policy import CustomInputPolicy -from ..policy.policy import CustomOutputPolicy -from ..policy.policy import EmpyInputPolicy -from ..policy.policy import ExactMatch -from ..policy.policy import InputPolicy -from ..policy.policy import OutputPolicy -from ..policy.policy import SingleExecutionExactOutput -from ..policy.policy import SubmitUserPolicy -from ..policy.policy import UserPolicy -from ..policy.policy import filter_only_uids -from ..policy.policy import init_policy -from ..policy.policy import load_policy_code -from ..policy.policy import partition_by_server +from ..output.output_service import ExecutionOutput, OutputService +from ..policy.policy import ( + Constant, + CustomInputPolicy, + CustomOutputPolicy, + EmpyInputPolicy, + ExactMatch, + InputPolicy, + OutputPolicy, + SingleExecutionExactOutput, + SubmitUserPolicy, + UserPolicy, + filter_only_uids, + init_policy, + load_policy_code, + partition_by_server, +) from ..policy.policy_service import PolicyService -from ..response import SyftError -from ..response import SyftException -from ..response import SyftInfo -from ..response import SyftNotReady -from ..response import SyftSuccess -from ..response import SyftWarning +from ..response import ( + SyftError, + SyftException, + SyftInfo, + SyftNotReady, + SyftSuccess, + SyftWarning, +) from ..service import ServiceConfigRegistry from ..user.user import UserView from ..user.user_roles import ServiceRole from .code_parse import LaunchJobVisitor from .unparse import unparse -from .utils import 
check_for_global_vars -from .utils import parse_code -from .utils import submit_subjobs_code +from .utils import check_for_global_vars, parse_code, submit_subjobs_code if TYPE_CHECKING: # relative @@ -200,11 +186,11 @@ def get_status_message(self) -> SyftSuccess | SyftNotReady | SyftError: ) if self.denied: return SyftError( - message=f"{type(self)} Your code cannot be run: {denial_string}" + message=f"{type(self)} Your code cannot be run: {denial_string}", ) else: return SyftNotReady( - message=f"{type(self)} Your code is waiting for approval. {string}" + message=f"{type(self)} Your code is waiting for approval. {string}", ) @property @@ -240,11 +226,11 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus: return self.status_dict[server_identity][0] else: raise Exception( - f"Code Object does not contain {context.server.name} Datasite's data" + f"Code Object does not contain {context.server.name} Datasite's data", ) else: raise Exception( - f"Invalid Server Type for Code Submission:{context.server.server_type}" + f"Invalid Server Type for Code Submission:{context.server.server_type}", ) def mutate( @@ -255,7 +241,7 @@ def mutate( verify_key: SyftVerifyKey, ) -> SyftError | Self: server_identity = ServerIdentity( - server_name=server_name, server_id=server_id, verify_key=verify_key + server_name=server_name, server_id=server_id, verify_key=verify_key, ) status_dict = self.status_dict if server_identity in status_dict: @@ -264,7 +250,7 @@ def mutate( return self else: return SyftError( - message="Cannot Modify Status as the Datasite's data is not included in the request" + message="Cannot Modify Status as the Datasite's data is not included in the request", ) def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: @@ -399,12 +385,12 @@ def user(self) -> UserView | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.services.user.get_by_verify_key(self.user_verify_key) def _compute_status_l0( - self, context: AuthedServiceContext | None = None + self, context: AuthedServiceContext | None = None, ) -> UserCodeStatusCollection | SyftError: if context is None: # Clientside @@ -415,7 +401,7 @@ def _compute_status_l0( if self._has_output_read_permissions_cache is None: is_approved = api.output.has_output_read_permissions( - self.id, self.user_verify_key + self.id, self.user_verify_key, ) self._has_output_read_permissions_cache = is_approved else: @@ -425,7 +411,7 @@ def _compute_status_l0( server_identity = ServerIdentity.from_server(context.server) output_service = context.server.get_service("outputservice") is_approved = output_service.has_output_read_permissions( - context, self.id, self.user_verify_key + context, self.id, self.user_verify_key, ) if isinstance(is_approved, SyftError): @@ -436,7 +422,7 @@ def _compute_status_l0( if is_approved: prompt_warning_message( "This request already has results published to the data scientist. " - "They will still be able to access those results." + "They will still be able to access those results.", ) message = self.l0_deny_reason status = (UserCodeStatus.DENIED, message) @@ -458,29 +444,29 @@ def status(self) -> UserCodeStatusCollection | SyftError: if self.is_l0_deployment: if self.status_link is not None: return SyftError( - message="Encountered a low side UserCode object with a status_link." 
+ message="Encountered a low side UserCode object with a status_link.", ) return self._compute_status_l0() if self.status_link is None: return SyftError( - message="This UserCode does not have a status. Please contact the Admin." + message="This UserCode does not have a status. Please contact the Admin.", ) res = self.status_link.resolve return res def get_status( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> UserCodeStatusCollection | SyftError: if self.is_l0_deployment: if self.status_link is not None: return SyftError( - message="Encountered a low side UserCode object with a status_link." + message="Encountered a low side UserCode object with a status_link.", ) return self._compute_status_l0(context) if self.status_link is None: return SyftError( - message="This UserCode does not have a status. Please contact the Admin." + message="This UserCode does not have a status. Please contact the Admin.", ) status = self.status_link.resolve_with_context(context) @@ -556,15 +542,15 @@ def _get_input_policy(self) -> InputPolicy | None: if server_view_workaround: input_policy = self.input_policy_type( - init_kwargs=self.input_policy_init_kwargs + init_kwargs=self.input_policy_init_kwargs, ) else: input_policy = self.input_policy_type( - **self.input_policy_init_kwargs + **self.input_policy_init_kwargs, ) elif isinstance(self.input_policy_type, UserPolicy): input_policy = init_policy( - self.input_policy_type, self.input_policy_init_kwargs + self.input_policy_type, self.input_policy_init_kwargs, ) else: raise Exception(f"Invalid output_policy_type: {self.input_policy_type}") @@ -609,18 +595,18 @@ def _get_output_policy(self) -> OutputPolicy | None: if len(self.output_policy_state) == 0: output_policy = None if isinstance(self.output_policy_type, type) and issubclass( - self.output_policy_type, OutputPolicy + self.output_policy_type, OutputPolicy, ): output_policy = self.output_policy_type( - **self.output_policy_init_kwargs + **self.output_policy_init_kwargs, ) elif isinstance(self.output_policy_type, UserPolicy): output_policy = init_policy( - self.output_policy_type, self.output_policy_init_kwargs + self.output_policy_type, self.output_policy_init_kwargs, ) else: raise Exception( - f"Invalid output_policy_type: {self.output_policy_type}" + f"Invalid output_policy_type: {self.output_policy_type}", ) if output_policy is not None: @@ -665,19 +651,19 @@ def output_policy(self, value: Any) -> None: # type: ignore @property def output_history(self) -> list[ExecutionOutput] | SyftError: api = APIRegistry.api_for( - self.syft_server_location, self.syft_client_verify_key + self.syft_server_location, self.syft_client_verify_key, ) if api is None: return SyftError( - message=f"Can't access the api. You must login to {self.syft_server_location}" + message=f"Can't access the api. 
You must login to {self.syft_server_location}", ) return api.services.output.get_by_user_code_id(self.id) def get_output_history( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[ExecutionOutput] | SyftError: output_service = cast( - OutputService, context.server.get_service("outputservice") + OutputService, context.server.get_service("outputservice"), ) return output_service.get_by_user_code_id(context, self.id) @@ -692,7 +678,7 @@ def store_execution_output( output_policy = self.get_output_policy(context) if output_policy is None and not is_admin: return SyftError( - message="You must wait for the output policy to be approved" + message="You must wait for the output policy to be approved", ) output_ids = filter_only_uids(outputs) @@ -781,7 +767,7 @@ def _inputs_json(self) -> str | SyftError: return input_str def get_sync_dependencies( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[UID] | SyftError: dependencies = [] @@ -799,7 +785,7 @@ def get_sync_dependencies( @property def run(self) -> Callable | None: warning = SyftWarning( - message="This code was submitted by a User and could be UNSAFE." + message="This code was submitted by a User and could be UNSAFE.", ) display(warning) @@ -818,12 +804,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable | SyftError: if on_private_data: display( SyftInfo( - message="The result you see is computed on PRIVATE data." - ) + message="The result you see is computed on PRIVATE data.", + ), ) if on_mock_data: display( - SyftInfo(message="The result you see is computed on MOCK data.") + SyftInfo(message="The result you see is computed on MOCK data."), ) # remove the decorator @@ -868,7 +854,7 @@ def _inner_repr(self, level: int = 0) -> str: # indent all lines except the first one inputs_str = "\n".join( - [f" {line}" for line in self._inputs_json.split("\n")] + [f" {line}" for line in self._inputs_json.split("\n")], ).lstrip() md = f"""class UserCode @@ -890,7 +876,7 @@ def _inner_repr(self, level: int = 0) -> str: """ md = "\n".join( - [f"{' '*level}{substring}" for substring in md.split("\n")[:-1]] + [f"{' '*level}{substring}" for substring in md.split("\n")[:-1]], ) if self.nested_codes is not None: for obj, _ in self.nested_codes.values(): @@ -941,7 +927,7 @@ def _ipython_display_(self, level: int = 0) -> None: """ md = "\n".join( - [f"{' '*level}{substring}" for substring in self.raw_code.split("\n")[:-1]] + [f"{' '*level}{substring}" for substring in self.raw_code.split("\n")[:-1]], ) display(HTML(repr_str), Markdown(as_markdown_code(md))) if self.nested_codes is not None and self.nested_codes != {}: @@ -1092,8 +1078,8 @@ def _ephemeral_server_call( if time_alive is None and not blocking: print( SyftInfo( - message="Closing the server after time_alive=300 (the default value)" - ) + message="Closing the server after time_alive=300 (the default value)", + ), ) time_alive = 300 @@ -1117,7 +1103,7 @@ def _ephemeral_server_call( api = APIRegistry.get_by_recent_server_uid(server_uid=server_id.server_id) if api is None: return SyftError( - f"Can't access the api. You must login to {server_id.server_id}" + f"Can't access the api. You must login to {server_id.server_id}", ) # Creating TwinObject from the ids of the kwargs # Maybe there are some corner cases where this is not enough @@ -1132,7 +1118,7 @@ def _ephemeral_server_call( return SyftError( message="You do not have access to object you want \ to use, or the private object does not have mock \ - data. 
Contact the Server Admin." + data. Contact the Server Admin.", ) else: data_obj = mock_obj @@ -1198,7 +1184,7 @@ def is_valid_usercode_name(func_name: str) -> Result[Any, str]: service_method_path = f"code.{func_name}" if ServiceConfigRegistry.path_exists(service_method_path): return Err( - f"Could not create syft function with name {func_name}: a service with the same name already exists" + f"Could not create syft function with name {func_name}: a service with the same name already exists", ) return Ok(None) @@ -1323,7 +1309,7 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: success_message = SyftSuccess( message=f"Syft function '{f.__name__}' successfully created. " f"To add a code request, please create a project using `project = syft.Project(...)`, " - f"then use command `project.create_code_request`." + f"then use command `project.create_code_request`.", ) display(success_message) @@ -1367,7 +1353,7 @@ def parse_user_code( call_stmt = ast.Assign( targets=[ast.Name(id="result")], value=ast.Call( - func=ast.Name(id=original_func_name), args=[], keywords=call_stmt_keywords + func=ast.Name(id=original_func_name), args=[], keywords=call_stmt_keywords, ), lineno=0, ) @@ -1456,7 +1442,7 @@ def locate_launch_jobs(context: TransformContext) -> TransformContext: # TODO: Not great user_code = user_codes[-1] user_code_link = LinkedObject.from_obj( - user_code, server_uid=context.server.id + user_code, server_uid=context.server.id, ) nested_codes[call] = (user_code_link, user_code.nested_codes) context.output["nested_codes"] = nested_codes @@ -1479,7 +1465,7 @@ def compile_code(context: TransformContext) -> TransformContext: if byte_code is None: raise ValueError( "Unable to compile byte code from parsed code. " - + context.output["parsed_code"] + + context.output["parsed_code"], ) return context @@ -1578,7 +1564,7 @@ def create_code_status(context: TransformContext) -> TransformContext: ) else: raise NotImplementedError( - f"Invalid server type:{context.server.server_type} for code submission" + f"Invalid server type:{context.server.server_type} for code submission", ) res = context.server.get_service("usercodestatusservice").create(context, status) @@ -1717,11 +1703,11 @@ def launch_job(func: UserCode, **kwargs: Any) -> Job | None: for k, v in kwargs.items(): value = ActionObject.from_obj(v) ptr = action_service.set_result_to_store( - value, context, has_result_read_permission=False + value, context, has_result_read_permission=False, ) if ptr.is_err(): raise ValueError( - f"failed to create argument {k} for launch job using value {v}" + f"failed to create argument {k} for launch job using value {v}", ) ptr = ptr.ok() kw2id[k] = ptr.id @@ -1752,7 +1738,7 @@ def launch_job(func: UserCode, **kwargs: Any) -> Job | None: def execute_byte_code( - code_item: UserCode, kwargs: dict[str, Any], context: AuthedServiceContext + code_item: UserCode, kwargs: dict[str, Any], context: AuthedServiceContext, ) -> Any: stdout_ = sys.stdout stderr_ = sys.stderr @@ -1778,7 +1764,7 @@ def increment_progress(self, n: int = 1) -> None: self._set_progress(by=n) def _set_progress( - self, to: int | None = None, by: int | None = None + self, to: int | None = None, by: int | None = None, ) -> None: if safe_context.is_async is not None: if by is None and to is None: @@ -1861,7 +1847,7 @@ def to_str(arg: Any) -> str: if context.job is not None: time = datetime.datetime.now().strftime("%d/%m/%y %H:%M:%S") original_print( - f"{time} EXCEPTION LOG ({job_id}):\n{error_msg}", file=sys.stderr + f"{time} EXCEPTION LOG 
({job_id}):\n{error_msg}", file=sys.stderr, ) else: # for local execution @@ -1913,7 +1899,8 @@ def to_str(arg: Any) -> str: def traceback_from_error(e: Exception, code: UserCode) -> str: """We do this because the normal traceback.format_exc() does not work well for exec, - it missed the references to the actual code""" + it missed the references to the actual code + """ line_nr = 0 tb = e.__traceback__ while tb is not None: @@ -1942,7 +1929,7 @@ def traceback_from_error(e: Exception, code: UserCode) -> str: def load_approved_policy_code( - user_code_items: list[UserCode], context: AuthedServiceContext | None + user_code_items: list[UserCode], context: AuthedServiceContext | None, ) -> Any: """Reload the policy code in memory for user code that is approved.""" try: diff --git a/packages/syft/src/syft/service/code/user_code_parse.py b/packages/syft/src/syft/service/code/user_code_parse.py index 5a17a7ba7f5..dcc682c095f 100644 --- a/packages/syft/src/syft/service/code/user_code_parse.py +++ b/packages/syft/src/syft/service/code/user_code_parse.py @@ -33,13 +33,13 @@ def make_ast_args(args: list[str]) -> ast.arguments: def make_ast_func( - name: str, input_kwargs: list[str], output_arg: str, body: list[ast.AST] + name: str, input_kwargs: list[str], output_arg: str, body: list[ast.AST], ) -> ast.FunctionDef: args = make_ast_args(input_kwargs) r = make_return(output_arg) new_body = body + [r] f = ast.FunctionDef( - name=name, args=args, body=new_body, decorator_list=[], lineno=0 + name=name, args=args, body=new_body, decorator_list=[], lineno=0, ) return f diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 4d200bdbf22..64e04c24410 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -1,12 +1,8 @@ # stdlib -from typing import Any -from typing import TypeVar -from typing import cast +from typing import Any, TypeVar, cast # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable @@ -18,34 +14,34 @@ from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_object import ActionObject -from ..action.action_permissions import ActionObjectPermission -from ..action.action_permissions import ActionPermission +from ..action.action_permissions import ActionObjectPermission, ActionPermission from ..action.action_service import ActionService from ..context import AuthedServiceContext from ..output.output_service import ExecutionOutput from ..policy.policy import OutputPolicy -from ..request.request import Request -from ..request.request import SubmitRequest -from ..request.request import SyncedUserCodeStatusChange -from ..request.request import UserCodeStatusChange +from ..request.request import ( + Request, + SubmitRequest, + SyncedUserCodeStatusChange, + UserCodeStatusChange, +) from ..request.request_service import RequestService -from ..response import SyftError -from ..response import SyftNotReady -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from ..user.user_roles import ServiceRole -from 
.user_code import SubmitUserCode -from .user_code import UserCode -from .user_code import UserCodeStatus -from .user_code import UserCodeUpdate -from .user_code import get_code_hash -from .user_code import load_approved_policy_code +from ..response import SyftError, SyftNotReady, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + ADMIN_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, + ServiceRole, +) +from .user_code import ( + SubmitUserCode, + UserCode, + UserCodeStatus, + UserCodeUpdate, + get_code_hash, + load_approved_policy_code, +) from .user_code_stash import UserCodeStash @@ -61,7 +57,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="code.submit", name="submit", roles=GUEST_ROLE_LEVEL) def submit( - self, context: AuthedServiceContext, code: SubmitUserCode + self, context: AuthedServiceContext, code: SubmitUserCode, ) -> UserCode | SyftError: """Add User Code""" result = self._submit(context, code, exists_ok=False) @@ -75,19 +71,21 @@ def _submit( submit_code: SubmitUserCode, exists_ok: bool = False, ) -> Result[UserCode, str]: - """ - Submit a UserCode. + """Submit a UserCode. If exists_ok is True, the function will return the existing code if it exists. Args: + ---- context (AuthedServiceContext): context submit_code (SubmitUserCode): UserCode to submit exists_ok (bool, optional): If True, return the existing code if it exists. If false, existing codes returns Err. Defaults to False. Returns: + ------- Result[UserCode, str]: New UserCode or error + """ existing_code_or_err = self.stash.get_by_code_hash( context.credentials, @@ -109,12 +107,12 @@ def _submit( # if the validation fails, we should remove the user code status # and code version to prevent dangling status root_context = AuthedServiceContext( - credentials=context.server.verify_key, server=context.server + credentials=context.server.verify_key, server=context.server, ) if code.status_link is not None: _ = context.server.get_service("usercodestatusservice").remove( - root_context, code.status_link.object_uid + root_context, code.status_link.object_uid, ) return result @@ -152,7 +150,7 @@ def update( @service_method(path="code.delete", name="delete", roles=ADMIN_ROLE_LEVEL) def delete( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: """Delete User Code""" result = self.stash.delete_by_uid(context.credentials, uid) @@ -166,10 +164,10 @@ def delete( roles=GUEST_ROLE_LEVEL, ) def get_by_service_name( - self, context: AuthedServiceContext, service_func_name: str + self, context: AuthedServiceContext, service_func_name: str, ) -> list[UserCode] | SyftError: result = self.stash.get_by_service_func_name( - context.credentials, service_func_name=service_func_name + context.credentials, service_func_name=service_func_name, ) if result.is_err(): return SyftError(message=str(result.err())) @@ -214,7 +212,7 @@ def _request_code_execution( ) -> Request | SyftError: # Cannot make multiple requests for the same code get_by_usercode_id = context.server.get_service_method( - RequestService.get_by_usercode_id + RequestService.get_by_usercode_id, ) existing_requests = get_by_usercode_id(context, user_code.id) if isinstance(existing_requests, SyftError): @@ -222,7 +220,7 @@ def _request_code_execution( if len(existing_requests) > 0: return SyftError( message=f"Request {existing_requests[0].id} already exists for this UserCode. 
" - f"Please use the existing request, or submit a new UserCode to create a new request." + f"Please use the existing request, or submit a new UserCode to create a new request.", ) # Users that have access to the output also have access to the code item @@ -231,7 +229,7 @@ def _request_code_execution( [ ActionObjectPermission(user_code.id, ActionPermission.READ, x) for x in user_code.output_readers - ] + ], ) code_link = LinkedObject.from_obj(user_code, server_uid=context.server.id) @@ -263,8 +261,7 @@ def _get_or_submit_user_code( context: AuthedServiceContext, code: SubmitUserCode | UserCode, ) -> Result[UserCode, str]: - """ - - If the code is a UserCode, check if it exists and return + """- If the code is a UserCode, check if it exists and return - If the code is a SubmitUserCode and the same code hash exists, return the existing code - If the code is a SubmitUserCode and the code hash does not exist, submit the code """ @@ -294,7 +291,6 @@ def request_code_execution( reason: str | None = "", ) -> Request | SyftError: """Request Code execution on user code""" - user_code_or_err = self._get_or_submit_user_code(context, code) if user_code_or_err.is_err(): return SyftError(message=user_code_or_err.err()) @@ -315,10 +311,10 @@ def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError: return SyftError(message=result.err()) @service_method( - path="code.get_by_id", name="get_by_id", roles=DATA_SCIENTIST_ROLE_LEVEL + path="code.get_by_id", name="get_by_id", roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> UserCode | SyftError: """Get a User Code Item""" result = self.stash.get_by_uid(context.credentials, uid=uid) @@ -340,7 +336,7 @@ def get_by_uid( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_all_for_user( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> SyftSuccess | SyftError: """Get All User Code Items for User's VerifyKey""" # TODO: replace with incoming user context and key @@ -350,7 +346,7 @@ def get_all_for_user( return SyftError(message=result.err()) def update_code_state( - self, context: AuthedServiceContext, code_item: UserCode + self, context: AuthedServiceContext, code_item: UserCode, ) -> SyftSuccess | SyftError: context = context.as_root_context() result = self.stash.update(context.credentials, code_item) @@ -388,7 +384,7 @@ def is_execution_allowed( return True def is_execution_on_owned_args_allowed( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> bool | SyftError: if context.role == ServiceRole.ADMIN: return True @@ -398,10 +394,9 @@ def is_execution_on_owned_args_allowed( return current_user.mock_execution_permission def keep_owned_kwargs( - self, kwargs: dict[str, Any], context: AuthedServiceContext + self, kwargs: dict[str, Any], context: AuthedServiceContext, ) -> dict[str, Any] | SyftError: """Return only the kwargs that are owned by the user""" - action_service = context.server.get_service("actionservice") mock_kwargs = {} @@ -426,7 +421,7 @@ def is_execution_on_owned_args( ) -> bool: # Check if all kwargs are owned by the user all_kwargs_are_owned = len( - self.keep_owned_kwargs(passed_kwargs, context) + self.keep_owned_kwargs(passed_kwargs, context), ) == len(passed_kwargs) if not all_kwargs_are_owned: return False @@ -445,7 +440,7 @@ def is_execution_on_owned_args( @service_method(path="code.call", name="call", roles=GUEST_ROLE_LEVEL) def call( - self, context: 
AuthedServiceContext, uid: UID, **kwargs: Any + self, context: AuthedServiceContext, uid: UID, **kwargs: Any, ) -> CachedSyftObject | ActionObject | SyftSuccess | SyftError: """Call a User Code Function""" kwargs.pop("result_id", None) @@ -456,7 +451,7 @@ def call( return result.ok() def valid_worker_pool_for_context( - self, context: AuthedServiceContext, user_code: UserCode + self, context: AuthedServiceContext, user_code: UserCode, ) -> bool: """This is a temporary fix that is needed until every function is always just ran as job""" # relative @@ -495,7 +490,7 @@ def _call( pass else: return Err( - "You do not have the permissions for mock execution, please contact the admin" + "You do not have the permissions for mock execution, please contact the admin", ) override_execution_permission = ( context.has_execute_permissions or context.role == ServiceRole.ADMIN @@ -544,7 +539,7 @@ def _call( if ( input_policy is not None and not last_executed_output.check_input_ids( - kwargs=kwarg2id + kwargs=kwarg2id, ) ): inp_policy_validation = input_policy._is_valid( @@ -572,7 +567,7 @@ def _call( CachedSyftObject( result=res, error_msg=output_policy_message, - ) + ), ) else: return cast(Err, is_valid.to_result()) @@ -583,12 +578,12 @@ def _call( return Err( value="You tried to run a syft function attached to a worker pool in blocking mode," "which is currently not supported. Run your function with `blocking=False` to run" - " as a job on your worker pool" + " as a job on your worker pool", ) action_service: ActionService = context.server.get_service("actionservice") # type: ignore result_action_object: Result[ActionObject | TwinObject, str] = ( action_service._user_code_execute( - context, code, kwarg2id, result_id=result_id + context, code, kwarg2id, result_id=result_id, ) ) if result_action_object.is_err(): @@ -597,7 +592,7 @@ def _call( result_action_object = result_action_object.ok() output_result = action_service.set_result_to_store( - result_action_object, context, code.get_output_policy(context) + result_action_object, context, code.get_output_policy(context), ) if output_result.is_err(): @@ -626,7 +621,7 @@ def _call( # print(res) has_result_read_permission = action_service.has_read_permission( - context, result.id + context, result.id, ) if isinstance(result, TwinObject): @@ -650,14 +645,14 @@ def _call( return Err(value=f"Failed to run. 
{e}, {traceback.format_exc()}") def has_code_permission( - self, code_item: UserCode, context: AuthedServiceContext + self, code_item: UserCode, context: AuthedServiceContext, ) -> SyftSuccess | SyftError: if not ( context.credentials == context.server.verify_key or context.credentials == code_item.user_verify_key ): return SyftError( - message=f"Code Execution Permission: {context.credentials} denied" + message=f"Code Execution Permission: {context.credentials} denied", ) return SyftSuccess(message="you have permission") @@ -707,7 +702,7 @@ def resolve_outputs( if context.server is not None: action_service = context.server.get_service("actionservice") result = action_service.get( - context, uid=output_id, twin_mode=TwinMode.PRIVATE + context, uid=output_id, twin_mode=TwinMode.PRIVATE, ) if result.is_err(): return result diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 0fcb41b2087..8d12338d4a9 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -6,16 +6,20 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionSettings, + QueryKeys, +) from ...util.telemetry import instrument -from .user_code import CodeHashPartitionKey -from .user_code import ServiceFuncNamePartitionKey -from .user_code import SubmitTimePartitionKey -from .user_code import UserCode -from .user_code import UserVerifyKeyPartitionKey +from .user_code import ( + CodeHashPartitionKey, + ServiceFuncNamePartitionKey, + SubmitTimePartitionKey, + UserCode, + UserVerifyKeyPartitionKey, +) @instrument @@ -23,28 +27,28 @@ class UserCodeStash(BaseUIDStoreStash): object_type = UserCode settings: PartitionSettings = PartitionSettings( - name=UserCode.__canonical_name__, object_type=UserCode + name=UserCode.__canonical_name__, object_type=UserCode, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_all_by_user_verify_key( - self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey, ) -> Result[list[UserCode], str]: qks = QueryKeys(qks=[UserVerifyKeyPartitionKey.with_obj(user_verify_key)]) return self.query_one(credentials=credentials, qks=qks) def get_by_code_hash( - self, credentials: SyftVerifyKey, code_hash: str + self, credentials: SyftVerifyKey, code_hash: str, ) -> Result[UserCode | None, str]: qks = QueryKeys(qks=[CodeHashPartitionKey.with_obj(code_hash)]) return self.query_one(credentials=credentials, qks=qks) def get_by_service_func_name( - self, credentials: SyftVerifyKey, service_func_name: str + self, credentials: SyftVerifyKey, service_func_name: str, ) -> Result[list[UserCode], str]: qks = QueryKeys(qks=[ServiceFuncNamePartitionKey.with_obj(service_func_name)]) return self.query_all( - credentials=credentials, qks=qks, order_by=SubmitTimePartitionKey + credentials=credentials, qks=qks, order_by=SubmitTimePartitionKey, ) diff --git a/packages/syft/src/syft/service/code/utils.py b/packages/syft/src/syft/service/code/utils.py index 6c1371a8c64..6a606e3f8db 100644 --- a/packages/syft/src/syft/service/code/utils.py 
+++ b/packages/syft/src/syft/service/code/utils.py @@ -6,10 +6,8 @@ from IPython import get_ipython # relative -from ..response import SyftException -from ..response import SyftWarning -from .code_parse import GlobalsVisitor -from .code_parse import LaunchJobVisitor +from ..response import SyftException, SyftWarning +from .code_parse import GlobalsVisitor, LaunchJobVisitor def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore @@ -30,7 +28,6 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore ) # works only in interactive envs (like jupyter notebooks) except Exception: ipython = None - pass for call in nested_calls: if ipython is not None: @@ -42,22 +39,20 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore def check_for_global_vars(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: - """ - Check that the code does not contain any global variables + """Check that the code does not contain any global variables """ v = GlobalsVisitor() try: v.visit(code_tree) except Exception: raise SyftException( - "Your code contains (a) global variable(s), which is not allowed" + "Your code contains (a) global variable(s), which is not allowed", ) return v def parse_code(raw_code: str) -> ast.Module | SyftWarning: - """ - Parse the code into an AST tree and return a warning if there are syntax errors + """Parse the code into an AST tree and return a warning if there are syntax errors """ try: tree = ast.parse(raw_code) diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index db1464c4add..b9757d88ab2 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -6,9 +6,7 @@ from ...client.api import APIRegistry from ...serde.serializable import serializable from ...service.user.user_roles import ServiceRole -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.syft_object import SyftVerifyKey +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject, SyftVerifyKey from ...types.uid import UID from ...util.notebook_ui.components.tabulator_template import ( build_tabulator_table_with_data, @@ -74,11 +72,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError: if isinstance(index, str): raise TypeError(f"index {index} must be an integer, not a string") api = APIRegistry.api_for( - self.syft_server_location, self.syft_client_verify_key + self.syft_server_location, self.syft_client_verify_key, ) if api is None: return SyftError( - message=f"Can't access the api. You must login to {self.server_uid}" + message=f"Can't access the api. You must login to {self.server_uid}", ) if ( api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value @@ -87,7 +85,7 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError: # negative index would dynamically resolve to a different version return SyftError( message="For security concerns we do not allow negative indexing. \ - Try using absolute values when indexing" + Try using absolute values when indexing", ) return self.user_code_history[index] @@ -141,7 +139,7 @@ def __getitem__(self, key: str | int) -> CodeHistoriesDict | SyftError: api = APIRegistry.api_for(self.server_uid, self.syft_client_verify_key) if api is None: return SyftError( - message=f"Can't access the api. 
You must login to {self.server_uid}" + message=f"Can't access the api. You must login to {self.server_uid}", ) return api.services.code_history.get_history_for_user(key) diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 762d7c5da91..db88d02263f 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -6,21 +6,22 @@ from ...store.document_store import DocumentStore from ...types.uid import UID from ...util.telemetry import instrument -from ..code.user_code import SubmitUserCode -from ..code.user_code import UserCode +from ..code.user_code import SubmitUserCode, UserCode from ..code.user_code_service import UserCodeService from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import service_method -from ..user.user_roles import DATA_OWNER_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import ServiceRole -from .code_history import CodeHistoriesDict -from .code_history import CodeHistory -from .code_history import CodeHistoryView -from .code_history import UsersCodeHistoriesDict +from ..response import SyftError, SyftSuccess +from ..service import AbstractService, service_method +from ..user.user_roles import ( + DATA_OWNER_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + ServiceRole, +) +from .code_history import ( + CodeHistoriesDict, + CodeHistory, + CodeHistoryView, + UsersCodeHistoriesDict, +) from .code_history_stash import CodeHistoryStash @@ -82,7 +83,7 @@ def submit_version( return SyftSuccess(message="Code version submit success") @service_method( - path="code_history.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL + path="code_history.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_all(self, context: AuthedServiceContext) -> list[CodeHistory] | SyftError: """Get a Dataset""" @@ -92,10 +93,10 @@ def get_all(self, context: AuthedServiceContext) -> list[CodeHistory] | SyftErro return SyftError(message=result.err()) @service_method( - path="code_history.get", name="get", roles=DATA_SCIENTIST_ROLE_LEVEL + path="code_history.get", name="get", roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_code_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: """Get a User Code Item""" result = self.stash.get_by_uid(context.credentials, uid=uid) @@ -106,7 +107,7 @@ def get_code_by_uid( @service_method(path="code_history.delete", name="delete") def delete( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: result = self.stash.delete_by_uid(context.credentials, uid) if result.is_ok(): @@ -115,19 +116,19 @@ def delete( return SyftError(message=result.err()) def fetch_histories_for_user( - self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey + self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey, ) -> CodeHistoriesDict | SyftError: if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: result = self.stash.get_by_verify_key( - credentials=context.server.verify_key, user_verify_key=user_verify_key + credentials=context.server.verify_key, user_verify_key=user_verify_key, ) else: result = self.stash.get_by_verify_key( - 
credentials=context.credentials, user_verify_key=user_verify_key + credentials=context.credentials, user_verify_key=user_verify_key, ) user_code_service: UserCodeService = context.server.get_service( - "usercodeservice" + "usercodeservice", ) # type: ignore def get_code(uid: UID) -> UserCode | SyftError: @@ -160,10 +161,10 @@ def get_code(uid: UID) -> UserCode | SyftError: roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_histories_for_current_user( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> CodeHistoriesDict | SyftError: return self.fetch_histories_for_user( - context=context, user_verify_key=context.credentials + context=context, user_verify_key=context.credentials, ) @service_method( @@ -172,16 +173,16 @@ def get_histories_for_current_user( roles=DATA_OWNER_ROLE_LEVEL, ) def get_history_for_user( - self, context: AuthedServiceContext, email: str + self, context: AuthedServiceContext, email: str, ) -> CodeHistoriesDict | SyftError: user_service = context.server.get_service("userservice") result = user_service.stash.get_by_email( - credentials=context.credentials, email=email + credentials=context.credentials, email=email, ) if result.is_ok(): user = result.ok() return self.fetch_histories_for_user( - context=context, user_verify_key=user.verify_key + context=context, user_verify_key=user.verify_key, ) return SyftError(message=result.err()) @@ -191,7 +192,7 @@ def get_history_for_user( roles=DATA_OWNER_ROLE_LEVEL, ) def get_histories_group_by_user( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> UsersCodeHistoriesDict | SyftError: if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: result = self.stash.get_all(context.credentials, has_permission=True) @@ -217,7 +218,7 @@ def get_histories_group_by_user( for code_history in code_histories: user_email = verify_key_2_user_email[code_history.user_verify_key] user_code_histories.user_dict[user_email].append( - code_history.service_func_name + code_history.service_func_name, ) return user_code_histories diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py index c419416664a..4c76b5c3594 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -6,11 +6,13 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, +) from .code_history import CodeHistory NamePartitionKey = PartitionKey(key="service_func_name", type_=str) @@ -21,7 +23,7 @@ class CodeHistoryStash(BaseUIDStoreStash): object_type = CodeHistory settings: PartitionSettings = PartitionSettings( - name=CodeHistory.__canonical_name__, object_type=CodeHistory + name=CodeHistory.__canonical_name__, object_type=CodeHistory, ) def __init__(self, store: DocumentStore) -> None: @@ -37,18 +39,18 @@ def get_by_service_func_name_and_verify_key( qks=[ NamePartitionKey.with_obj(service_func_name), VerifyKeyPartitionKey.with_obj(user_verify_key), - ] + ], ) return self.query_one(credentials=credentials, qks=qks) 
def get_by_service_func_name( - self, credentials: SyftVerifyKey, service_func_name: str + self, credentials: SyftVerifyKey, service_func_name: str, ) -> Result[list[CodeHistory], str]: qks = QueryKeys(qks=[NamePartitionKey.with_obj(service_func_name)]) return self.query_all(credentials=credentials, qks=qks) def get_by_verify_key( - self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey, ) -> Result[CodeHistory | None, str]: if isinstance(user_verify_key, str): user_verify_key = SyftVerifyKey.from_string(user_verify_key) diff --git a/packages/syft/src/syft/service/context.py b/packages/syft/src/syft/service/context.py index bafa07bb5c5..8dabf4bb01e 100644 --- a/packages/syft/src/syft/service/context.py +++ b/packages/syft/src/syft/service/context.py @@ -6,16 +6,15 @@ # relative from ..abstract_server import AbstractServer -from ..server.credentials import SyftVerifyKey -from ..server.credentials import UserLoginCredentials -from ..types.syft_object import Context -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import SyftBaseObject -from ..types.syft_object import SyftObject +from ..server.credentials import SyftVerifyKey, UserLoginCredentials +from ..types.syft_object import ( + SYFT_OBJECT_VERSION_1, + Context, + SyftBaseObject, + SyftObject, +) from ..types.uid import UID -from .user.user_roles import ROLE_TO_CAPABILITIES -from .user.user_roles import ServiceRole -from .user.user_roles import ServiceRoleCapability +from .user.user_roles import ROLE_TO_CAPABILITIES, ServiceRole, ServiceRoleCapability class ServerServiceContext(Context, SyftObject): @@ -46,7 +45,7 @@ def capabilities(self) -> list[ServiceRoleCapability]: def with_credentials(self, credentials: SyftVerifyKey, role: ServiceRole) -> Self: return AuthedServiceContext( - credentials=credentials, role=role, server=self.server + credentials=credentials, role=role, server=self.server, ) def as_root_context(self) -> Self: diff --git a/packages/syft/src/syft/service/data_subject/data_subject.py b/packages/syft/src/syft/service/data_subject/data_subject.py index bdeabdc0ece..70f17e1af75 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject.py +++ b/packages/syft/src/syft/service/data_subject/data_subject.py @@ -1,6 +1,5 @@ # stdlib from collections.abc import Callable -from typing import Any # third party from typing_extensions import Self @@ -8,12 +7,13 @@ # relative from ...serde.serializable import serializable from ...store.document_store import PartitionKey -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.transforms import TransformContext -from ...types.transforms import add_server_uid_for_key -from ...types.transforms import generate_id -from ...types.transforms import transform +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject +from ...types.transforms import ( + TransformContext, + add_server_uid_for_key, + generate_id, + transform, +) from ...types.uid import UID from ...util.markdown import as_markdown_python_code from ..response import SyftError @@ -50,7 +50,7 @@ def members(self) -> list: def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return hash(self) == hash(other) def __repr_syft_nested__(self) -> str: @@ -90,14 +90,14 @@ def member_count(self) -> int: def __hash__(self) -> int: return hash(self.name) - def __eq__(self, 
other: Any) -> bool: + def __eq__(self, other: object) -> bool: return hash(self) == hash(other) def __repr_syft_nested__(self) -> str: return f"DataSubject({self.name})" def _create_member_relationship( - self, data_subject: Self, _relationship_set: set + self, data_subject: Self, _relationship_set: set, ) -> None: for member in data_subject.members.values(): _relationship_set.add((data_subject, member)) diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member.py b/packages/syft/src/syft/service/data_subject/data_subject_member.py index 82767e4b631..3c7500375b6 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member.py @@ -1,11 +1,9 @@ # stdlib -from typing import Any # relative from ...serde.serializable import serializable from ...store.document_store import PartitionKey -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject ParentPartitionKey = PartitionKey(key="parent", type_=str) ChildPartitionKey = PartitionKey(key="child", type_=str) @@ -25,7 +23,7 @@ class DataSubjectMemberRelationship(SyftObject): def __hash__(self) -> int: return hash(self.parent + self.child) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return hash(self) == hash(other) def __repr__(self) -> str: diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index 097c14a6ffa..60ad1a0b591 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -6,20 +6,21 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionSettings, + QueryKeys, +) from ...util.telemetry import instrument from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from .data_subject_member import ChildPartitionKey -from .data_subject_member import DataSubjectMemberRelationship -from .data_subject_member import ParentPartitionKey +from ..response import SyftError, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService +from .data_subject_member import ( + ChildPartitionKey, + DataSubjectMemberRelationship, + ParentPartitionKey, +) @instrument @@ -35,13 +36,13 @@ def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_all_for_parent( - self, credentials: SyftVerifyKey, name: str + self, credentials: SyftVerifyKey, name: str, ) -> Result[DataSubjectMemberRelationship | None, str]: qks = QueryKeys(qks=[ParentPartitionKey.with_obj(name)]) return self.query_all(credentials=credentials, qks=qks) def get_all_for_child( - self, credentials: SyftVerifyKey, name: str + self, credentials: SyftVerifyKey, name: str, ) -> Result[DataSubjectMemberRelationship | None, str]: qks = 
QueryKeys(qks=[ChildPartitionKey.with_obj(name)]) return self.query_all(credentials=credentials, qks=qks) @@ -58,7 +59,7 @@ def __init__(self, store: DocumentStore) -> None: self.stash = DataSubjectMemberStash(store=store) def add( - self, context: AuthedServiceContext, parent: str, child: str + self, context: AuthedServiceContext, parent: str, child: str, ) -> SyftSuccess | SyftError: """Register relationship between data subject and it's member.""" relation = DataSubjectMemberRelationship(parent=parent, child=child) @@ -68,11 +69,11 @@ def add( return SyftSuccess(message=f"Relationship added for: {parent} -> {child}") def get_relatives( - self, context: AuthedServiceContext, data_subject_name: str + self, context: AuthedServiceContext, data_subject_name: str, ) -> list[str] | SyftError: """Get all Members for given data subject""" result = self.stash.get_all_for_parent( - context.credentials, name=data_subject_name + context.credentials, name=data_subject_name, ) if result.is_ok(): data_subject_members = result.ok() diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index dfcd20b7110..165b59e2d91 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -6,21 +6,17 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionSettings, + QueryKeys, +) from ...util.telemetry import instrument from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from .data_subject import DataSubject -from .data_subject import DataSubjectCreate -from .data_subject import NamePartitionKey +from ..response import SyftError, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from .data_subject import DataSubject, DataSubjectCreate, NamePartitionKey from .data_subject_member_service import DataSubjectMemberService @@ -29,14 +25,14 @@ class DataSubjectStash(BaseUIDStoreStash): object_type = DataSubject settings: PartitionSettings = PartitionSettings( - name=DataSubject.__canonical_name__, object_type=DataSubject + name=DataSubject.__canonical_name__, object_type=DataSubject, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_by_name( - self, credentials: SyftVerifyKey, name: str + self, credentials: SyftVerifyKey, name: str, ) -> Result[DataSubject | None, str]: qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) return self.query_one(credentials, qks=qks) @@ -66,12 +62,11 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="data_subject.add", name="add_data_subject") def add( - self, context: AuthedServiceContext, data_subject: DataSubjectCreate + self, context: AuthedServiceContext, data_subject: DataSubjectCreate, ) -> SyftSuccess | SyftError: """Register a data subject.""" - member_relationship_add = context.server.get_service_method( - 
DataSubjectMemberService.add + DataSubjectMemberService.add, ) member_relationships: set[tuple[str, str]] = data_subject.member_relationships @@ -98,7 +93,7 @@ def add( return result return SyftSuccess( - message=f"{len(member_relationships)+1} Data Subjects Registered" + message=f"{len(member_relationships)+1} Data Subjects Registered", ) @service_method(path="data_subject.get_all", name="get_all") @@ -112,10 +107,10 @@ def get_all(self, context: AuthedServiceContext) -> list[DataSubject] | SyftErro @service_method(path="data_subject.get_members", name="members_for") def get_members( - self, context: AuthedServiceContext, data_subject_name: str + self, context: AuthedServiceContext, data_subject_name: str, ) -> list[DataSubject] | SyftError: get_relatives = context.server.get_service_method( - DataSubjectMemberService.get_relatives + DataSubjectMemberService.get_relatives, ) relatives = get_relatives(context, data_subject_name) @@ -134,7 +129,7 @@ def get_members( @service_method(path="data_subject.get_by_name", name="get_by_name") def get_by_name( - self, context: AuthedServiceContext, name: str + self, context: AuthedServiceContext, name: str, ) -> SyftSuccess | SyftError: """Get a Data Subject by its name.""" result = self.stash.get_by_name(context.credentials, name=name) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 5b13e108435..7aead3ee526 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -1,22 +1,19 @@ # stdlib +import logging +import textwrap from collections.abc import Callable from datetime import datetime from enum import Enum -import logging -import textwrap from typing import Any -# third party -from IPython.display import display import itables import markdown import pandas as pd -from pydantic import ConfigDict -from pydantic import field_validator -from pydantic import model_validator -from result import Err -from result import Ok -from result import Result + +# third party +from IPython.display import display +from pydantic import ConfigDict, field_validator, model_validator +from result import Err, Ok, Result from typing_extensions import Self # relative @@ -25,33 +22,26 @@ from ...store.document_store import PartitionKey from ...types.datetime import DateTime from ...types.dicttuple import DictTuple -from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.transforms import TransformContext -from ...types.transforms import generate_id -from ...types.transforms import make_set_default -from ...types.transforms import transform -from ...types.transforms import validate_url +from ...types.syft_object import SYFT_OBJECT_VERSION_1, PartialSyftObject, SyftObject +from ...types.transforms import ( + TransformContext, + generate_id, + make_set_default, + transform, + validate_url, +) from ...types.uid import UID from ...util import options -from ...util.colors import ON_SURFACE_HIGHEST -from ...util.colors import SURFACE -from ...util.colors import SURFACE_SURFACE +from ...util.colors import ON_SURFACE_HIGHEST, SURFACE, SURFACE_SURFACE from ...util.markdown import as_markdown_python_code from ...util.misc_objs import MarkdownDescription from ...util.notebook_ui.icons import Icon -from ...util.notebook_ui.styles import FONT_CSS -from ...util.notebook_ui.styles import ITABLES_CSS +from ...util.notebook_ui.styles import FONT_CSS, 
ITABLES_CSS from ..action.action_data_empty import ActionDataEmpty from ..action.action_object import ActionObject -from ..data_subject.data_subject import DataSubject -from ..data_subject.data_subject import DataSubjectCreate +from ..data_subject.data_subject import DataSubject, DataSubjectCreate from ..data_subject.data_subject_service import DataSubjectService -from ..response import SyftError -from ..response import SyftException -from ..response import SyftSuccess -from ..response import SyftWarning +from ..response import SyftError, SyftException, SyftSuccess, SyftWarning NamePartitionKey = PartitionKey(key="name", type_=str) logger = logging.getLogger(__name__) @@ -151,18 +141,18 @@ def _repr_html_(self) -> Any: private_data_obj = private_data_res.ok_value if isinstance(private_data_obj, ActionObject): data_table_line = itables.to_html_datatable( - df=self.data.syft_action_data, css=itables_css + df=self.data.syft_action_data, css=itables_css, ) elif isinstance(private_data_obj, pd.DataFrame): data_table_line = itables.to_html_datatable( - df=private_data_obj, css=itables_css + df=private_data_obj, css=itables_css, ) else: data_table_line = private_data_res.ok_value if isinstance(self.mock, ActionObject): mock_table_line = itables.to_html_datatable( - df=self.mock.syft_action_data, css=itables_css + df=self.mock.syft_action_data, css=itables_css, ) elif isinstance(self.mock, pd.DataFrame): mock_table_line = itables.to_html_datatable(df=self.mock, css=itables_css) @@ -286,12 +276,13 @@ def has_permission(self, data_result: Any) -> bool: ) def _private_data(self) -> Result[Any, str]: - """ - Retrieves the private data associated with this asset. + """Retrieves the private data associated with this asset. - Returns: + Returns + ------- Result[Any, str]: A Result object containing the private data if the user has permission otherwise an Err object with the message "You do not have permission to access private data." + """ api = APIRegistry.api_for( server_uid=self.server_uid, @@ -373,11 +364,11 @@ def __mock_is_real_for_empty_mock_must_be_false(self) -> Self: def contains_empty(self) -> bool: if isinstance(self.mock, ActionObject) and isinstance( - self.mock.syft_action_data_cache, ActionDataEmpty + self.mock.syft_action_data_cache, ActionDataEmpty, ): return True if isinstance(self.data, ActionObject) and isinstance( - self.data.syft_action_data_cache, ActionDataEmpty + self.data.syft_action_data_cache, ActionDataEmpty, ): return True return False @@ -396,16 +387,16 @@ def add_contributor( try: _role_str = role.value if isinstance(role, Enum) else role contributor = Contributor( - name=name, role=_role_str, email=email, phone=phone, note=note + name=name, role=_role_str, email=email, phone=phone, note=note, ) if contributor in self.contributors: return SyftError( - message=f"Contributor with email: '{email}' already exists in '{self.name}' Asset." + message=f"Contributor with email: '{email}' already exists in '{self.name}' Asset.", ) self.contributors.add(contributor) return SyftSuccess( - message=f"Contributor '{name}' added to '{self.name}' Asset." + message=f"Contributor '{name}' added to '{self.name}' Asset.", ) except Exception as e: return SyftError(message=f"Failed to add contributor. 
Error: {e}") @@ -440,7 +431,7 @@ def set_shape(self, shape: tuple) -> None: def check(self) -> SyftSuccess | SyftError: if not check_mock(self.data, self.mock): return SyftError( - message=f"set_obj type {type(self.data)} must match set_mock type {type(self.mock)}" + message=f"set_obj type {type(self.data)} must match set_mock type {type(self.mock)}", ) # if not _is_action_data_empty(self.mock): # data_shape = get_shape_or_len(self.data) @@ -580,7 +571,7 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: for asset in self.asset_list: if asset.description is not None: description_text = textwrap.shorten( - asset.description.text, width=100, placeholder="..." + asset.description.text, width=100, placeholder="...", ) _repr_str += f"\t{asset.name}: {description_text}\n\n" else: @@ -601,7 +592,7 @@ def client(self) -> Any | None: client = SyftClientSessionCache.get_client_for_server_uid(self.server_uid) if client is None: return SyftError( - message=f"No clients for {self.server_uid} in memory. Please login with sy.login" + message=f"No clients for {self.server_uid} in memory. Please login with sy.login", ) return client @@ -613,7 +604,7 @@ def client(self) -> Any | None: "You can create an asset without a mock with `sy.Asset(..., mock=sy.ActionObject.empty())` or\n" "set the mock of an existing asset to be empty with `asset.no_mock()` or ", "`asset.mock = sy.ActionObject.empty()`.", - ] + ], ) @@ -627,8 +618,8 @@ def _check_asset_must_contain_mock(asset_list: list[CreateAsset]) -> None: *[f"{asset}\n" for asset in assets_without_mock], "\n", _ASSET_WITH_NONE_MOCK_ERROR_MESSAGE, - ] - ) + ], + ), ) @@ -660,7 +651,7 @@ class CreateDataset(Dataset): @field_validator("asset_list") @classmethod def __assets_must_contain_mock( - cls, asset_list: list[CreateAsset] + cls, asset_list: list[CreateAsset], ) -> list[CreateAsset]: _check_asset_must_contain_mock(asset_list) return asset_list @@ -695,21 +686,21 @@ def add_contributor( try: _role_str = role.value if isinstance(role, Enum) else role contributor = Contributor( - name=name, role=_role_str, email=email, phone=phone, note=note + name=name, role=_role_str, email=email, phone=phone, note=note, ) if contributor in self.contributors: return SyftError( - message=f"Contributor with email: '{email}' already exists in '{self.name}' Dataset." + message=f"Contributor with email: '{email}' already exists in '{self.name}' Dataset.", ) self.contributors.add(contributor) return SyftSuccess( - message=f"Contributor '{name}' added to '{self.name}' Dataset." + message=f"Contributor '{name}' added to '{self.name}' Dataset.", ) except Exception as e: return SyftError(message=f"Failed to add contributor. Error: {e}") def add_asset( - self, asset: CreateAsset, force_replace: bool = False + self, asset: CreateAsset, force_replace: bool = False, ) -> SyftSuccess | SyftError: if asset.mock is None: raise ValueError(_ASSET_WITH_NONE_MOCK_ERROR_MESSAGE) @@ -719,18 +710,18 @@ def add_asset( if not force_replace: return SyftError( message=f"""Asset "{asset.name}" already exists in '{self.name}' Dataset.""" - """ Use add_asset(asset, force_replace=True) to replace.""" + """ Use add_asset(asset, force_replace=True) to replace.""", ) else: self.asset_list[i] = asset return SyftSuccess( - message=f"Asset {asset.name} has been successfully replaced." + message=f"Asset {asset.name} has been successfully replaced.", ) self.asset_list.append(asset) return SyftSuccess( - message=f"Asset '{asset.name}' added to '{self.name}' Dataset." 
+ message=f"Asset '{asset.name}' added to '{self.name}' Dataset.", ) def replace_asset(self, asset: CreateAsset) -> SyftSuccess | SyftError: @@ -747,7 +738,7 @@ def remove_asset(self, name: str) -> SyftSuccess | SyftError: return SyftError(message=f"No asset exists with name: {name}") self.asset_list.remove(asset_to_remove) return SyftSuccess( - message=f"Asset '{self.name}' removed from '{self.name}' Dataset." + message=f"Asset '{self.name}' removed from '{self.name}' Dataset.", ) def check(self) -> Result[SyftSuccess, list[SyftError]]: @@ -791,7 +782,7 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: # TODO, upload to blob storage here if context.server is None: raise ValueError( - "f{context}'s server is None, please log in. No trasformation happened" + "f{context}'s server is None, please log in. No trasformation happened", ) action_service = context.server.get_service("actionservice") result = action_service._set( @@ -823,7 +814,7 @@ def set_data_subjects(context: TransformContext) -> TransformContext | SyftError raise ValueError(f"{context}'s output is None. No transformation happened") if context.server is None: return SyftError( - "f{context}'s server is None, please log in. No trasformation happened" + "f{context}'s server is None, please log in. No trasformation happened", ) data_subjects = context.output["data_subjects"] get_data_subject = context.server.get_service_method(DataSubjectService.get_by_name) diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index c3bc68385ad..bd246d3406b 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -1,7 +1,6 @@ # stdlib -from collections.abc import Collection -from collections.abc import Sequence import logging +from collections.abc import Collection, Sequence from typing import cast # relative @@ -10,26 +9,18 @@ from ...types.dicttuple import DictTuple from ...types.uid import UID from ...util.telemetry import instrument -from ..action.action_permissions import ActionObjectPermission -from ..action.action_permissions import ActionPermission +from ..action.action_permissions import ActionObjectPermission, ActionPermission from ..action.action_service import ActionService from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import DATA_OWNER_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from ..warnings import CRUDReminder -from ..warnings import HighSideCRUDWarning -from .dataset import Asset -from .dataset import CreateDataset -from .dataset import Dataset -from .dataset import DatasetPageView -from .dataset import DatasetUpdate +from ..response import SyftError, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + DATA_OWNER_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, +) +from ..warnings import CRUDReminder, HighSideCRUDWarning +from .dataset import Asset, CreateDataset, Dataset, DatasetPageView, DatasetUpdate from .dataset_stash import DatasetStash logger = logging.getLogger(__name__) @@ -87,7 +78,7 @@ def __init__(self, store: DocumentStore) -> 
None: roles=DATA_OWNER_ROLE_LEVEL, ) def add( - self, context: AuthedServiceContext, dataset: CreateDataset + self, context: AuthedServiceContext, dataset: CreateDataset, ) -> SyftSuccess | SyftError: """Add a Dataset""" dataset = dataset.to(Dataset, context=context) @@ -96,7 +87,7 @@ def add( dataset, add_permissions=[ ActionObjectPermission( - uid=dataset.id, permission=ActionPermission.ALL_READ + uid=dataset.id, permission=ActionPermission.ALL_READ, ), ], ) @@ -104,7 +95,7 @@ def add( return SyftError(message=str(result.err())) return SyftSuccess( message=f"Dataset uploaded to '{context.server.name}'. " - f"To see the datasets uploaded by a client on this server, use command `[your_client].datasets`" + f"To see the datasets uploaded by a client on this server, use command `[your_client].datasets`", ) @service_method( @@ -133,7 +124,7 @@ def get_all( datasets.remove(dataset) return _paginate_dataset_collection( - datasets=datasets, page_size=page_size, page_index=page_index + datasets=datasets, page_size=page_size, page_index=page_index, ) @service_method(path="dataset.search", name="search", roles=GUEST_ROLE_LEVEL) @@ -157,7 +148,7 @@ def search( ] return _paginate_dataset_collection( - filtered_results, page_size=page_size, page_index=page_index + filtered_results, page_size=page_size, page_index=page_index, ) @service_method(path="dataset.get_by_id", name="get_by_id") @@ -174,7 +165,7 @@ def get_by_id(self, context: AuthedServiceContext, uid: UID) -> Dataset | SyftEr @service_method(path="dataset.get_by_action_id", name="get_by_action_id") def get_by_action_id( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> list[Dataset] | SyftError: """Get Datasets by an Action ID""" result = self.stash.search_action_ids(context.credentials, uid=uid) @@ -194,7 +185,7 @@ def get_by_action_id( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_assets_by_action_id( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> list[Asset] | SyftError: """Get Assets by an Action ID""" datasets = self.get_by_action_id(context=context, uid=uid) @@ -217,10 +208,9 @@ def get_assets_by_action_id( warning=HighSideCRUDWarning(confirmation=True), ) def delete( - self, context: AuthedServiceContext, uid: UID, delete_assets: bool = True + self, context: AuthedServiceContext, uid: UID, delete_assets: bool = True, ) -> SyftSuccess | SyftError: - """ - Soft delete: keep the dataset object, only remove the blob store entries + """Soft delete: keep the dataset object, only remove the blob store entries After soft deleting a dataset, the user will not be able to see it using the `datasets.get_all` endpoint. 
Delete unique `dataset.name` key and leave UID, just rename it in case the @@ -242,10 +232,10 @@ def delete( ) action_service = cast( - ActionService, context.server.get_service(ActionService) + ActionService, context.server.get_service(ActionService), ) del_res: SyftSuccess | SyftError = action_service.delete( - context=context, uid=asset.action_id, soft_delete=True + context=context, uid=asset.action_id, soft_delete=True, ) if isinstance(del_res, SyftError): @@ -259,7 +249,7 @@ def delete( # soft delete the dataset object from the store dataset_update = DatasetUpdate( - id=uid, name=f"_deleted_{dataset.name}_{uid}", to_be_deleted=True + id=uid, name=f"_deleted_{dataset.name}_{uid}", to_be_deleted=True, ) result = self.stash.update(context.credentials, dataset_update) if result.is_err(): diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index f03715985c0..14fe6b8d239 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -1,22 +1,21 @@ # stdlib # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, +) from ...types.uid import UID from ...util.telemetry import instrument -from .dataset import Dataset -from .dataset import DatasetUpdate +from .dataset import Dataset, DatasetUpdate NamePartitionKey = PartitionKey(key="name", type_=str) ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) @@ -27,14 +26,14 @@ class DatasetStash(BaseUIDStoreStash): object_type = Dataset settings: PartitionSettings = PartitionSettings( - name=Dataset.__canonical_name__, object_type=Dataset + name=Dataset.__canonical_name__, object_type=Dataset, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_by_name( - self, credentials: SyftVerifyKey, name: str + self, credentials: SyftVerifyKey, name: str, ) -> Result[Dataset | None, str]: qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) return self.query_one(credentials=credentials, qks=qks) @@ -52,7 +51,7 @@ def update( return super().update(credentials=credentials, obj=res.ok()) def search_action_ids( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[list[Dataset], str]: qks = QueryKeys(qks=[ActionIDsPartitionKey.with_obj(uid)]) return self.query_all(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/job/html_template.py b/packages/syft/src/syft/service/job/html_template.py index 5ce2cfd2dd6..a2d5f1f8995 100644 --- a/packages/syft/src/syft/service/job/html_template.py +++ b/packages/syft/src/syft/service/job/html_template.py @@ -1,6 +1,5 @@ # relative -from ...util.notebook_ui.styles import CSS_CODE -from ...util.notebook_ui.styles import JS_DOWNLOAD_FONTS +from ...util.notebook_ui.styles import CSS_CODE, JS_DOWNLOAD_FONTS type_html = """
""" -) # noqa: E501 +) attrs_html = """
diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 3cfbe356f09..354d79d794d 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -1,9 +1,8 @@ # stdlib -from collections.abc import Callable import inspect import time -from typing import Any -from typing import cast +from collections.abc import Callable +from typing import Any, cast # relative from ...serde.serializable import serializable @@ -12,28 +11,24 @@ from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_object import ActionObject -from ..action.action_permissions import ActionObjectPermission -from ..action.action_permissions import ActionPermission +from ..action.action_permissions import ActionObjectPermission, ActionPermission from ..code.user_code import UserCode from ..context import AuthedServiceContext from ..log.log_service import LogService from ..queue.queue_stash import ActionQueueItem -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_OWNER_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL -from .job_stash import Job -from .job_stash import JobStash -from .job_stash import JobStatus +from ..response import SyftError, SyftSuccess +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + ADMIN_ROLE_LEVEL, + DATA_OWNER_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, +) +from .job_stash import Job, JobStash, JobStatus def wait_until( - predicate: Callable[[], bool], timeout: int = 10 + predicate: Callable[[], bool], timeout: int = 10, ) -> SyftSuccess | SyftError: start = time.time() code_string = inspect.getsource(predicate).strip() @@ -82,7 +77,7 @@ def get_all(self, context: AuthedServiceContext) -> list[Job] | SyftError: roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_by_user_code_id( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, user_code_id: UID, ) -> list[Job] | SyftError: res = self.stash.get_by_user_code_id(context.credentials, user_code_id) if res.is_err(): @@ -97,7 +92,7 @@ def get_by_user_code_id( roles=ADMIN_ROLE_LEVEL, ) def delete( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: res = self.stash.delete_by_uid(context.credentials, uid) if res.is_err(): @@ -110,7 +105,7 @@ def delete( roles=ADMIN_ROLE_LEVEL, ) def get_by_result_id( - self, context: AuthedServiceContext, result_id: UID + self, context: AuthedServiceContext, result_id: UID, ) -> Job | None | SyftError: res = self.stash.get_by_result_id(context.credentials, result_id) if res.is_err(): @@ -123,7 +118,7 @@ def get_by_result_id( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def restart( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: job_or_err = self.stash.get_by_uid(context.credentials, uid=uid) if job_or_err.is_err(): @@ -134,12 +129,12 @@ def restart( job = job_or_err.ok() if job.parent_job_id is not None: return SyftError( - message="Not possible to restart subjobs. Please restart the parent job." + message="Not possible to restart subjobs. 
Please restart the parent job.", ) if job.status == JobStatus.PROCESSING: return SyftError( message="Jobs in progress cannot be restarted. " - "Please wait for completion or cancel the job via .cancel() to proceed." + "Please wait for completion or cancel the job via .cancel() to proceed.", ) job.status = JobStatus.CREATED @@ -148,7 +143,7 @@ def restart( task_uid = UID() worker_settings = WorkerSettings.from_server(context.server) worker_pool_ref = context.server.get_worker_pool_ref_by_name( - context.credentials + context.credentials, ) if isinstance(worker_pool_ref, SyftError): return worker_pool_ref @@ -181,7 +176,7 @@ def restart( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def update( - self, context: AuthedServiceContext, job: Job + self, context: AuthedServiceContext, job: Job, ) -> SyftSuccess | SyftError: res = self.stash.update(context.credentials, obj=job) if res.is_err(): @@ -214,7 +209,7 @@ def _kill(self, context: AuthedServiceContext, job: Job) -> SyftSuccess | SyftEr wait_until( lambda: all( subjob.fetched_status == JobStatus.INTERRUPTED for subjob in job.subjobs - ) + ), ) return SyftSuccess(message="Job killed successfully!") @@ -234,14 +229,14 @@ def kill(self, context: AuthedServiceContext, id: UID) -> SyftSuccess | SyftErro job = job_or_err.ok() if job.parent_job_id is not None: return SyftError( - message="Not possible to cancel subjobs. To stop execution, please cancel the parent job." + message="Not possible to cancel subjobs. To stop execution, please cancel the parent job.", ) if job.status != JobStatus.PROCESSING: return SyftError(message="Job is not running") if job.job_pid is None: return SyftError( message="Job termination disabled in dev mode. " - "Set 'dev_mode=False' or 'thread_workers=False' to enable." + "Set 'dev_mode=False' or 'thread_workers=False' to enable.", ) return self._kill(context, job) @@ -252,7 +247,7 @@ def kill(self, context: AuthedServiceContext, id: UID) -> SyftSuccess | SyftErro roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_subjobs( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> list[Job] | SyftError: res = self.stash.get_by_parent_id(context.credentials, uid=uid) if res.is_err(): @@ -261,7 +256,7 @@ def get_subjobs( return res.ok() @service_method( - path="job.get_active", name="get_active", roles=DATA_SCIENTIST_ROLE_LEVEL + path="job.get_active", name="get_active", roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_active(self, context: AuthedServiceContext) -> list[Job] | SyftError: res = self.stash.get_active(context.credentials) @@ -275,10 +270,10 @@ def get_active(self, context: AuthedServiceContext) -> list[Job] | SyftError: roles=DATA_OWNER_ROLE_LEVEL, ) def add_read_permission_job_for_code_owner( - self, context: AuthedServiceContext, job: Job, user_code: UserCode + self, context: AuthedServiceContext, job: Job, user_code: UserCode, ) -> None: permission = ActionObjectPermission( - job.id, ActionPermission.READ, user_code.user_verify_key + job.id, ActionPermission.READ, user_code.user_verify_key, ) return self.stash.add_permission(permission=permission) @@ -288,14 +283,14 @@ def add_read_permission_job_for_code_owner( roles=DATA_OWNER_ROLE_LEVEL, ) def add_read_permission_log_for_code_owner( - self, context: AuthedServiceContext, log_id: UID, user_code: UserCode + self, context: AuthedServiceContext, log_id: UID, user_code: UserCode, ) -> Any: log_service = context.server.get_service("logservice") log_service = cast(LogService, log_service) return log_service.stash.add_permission( 
ActionObjectPermission( - log_id, ActionPermission.READ, user_code.user_verify_key - ) + log_id, ActionPermission.READ, user_code.user_verify_key, + ), ) @service_method( diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 2ad27746f57..c535aaa04f6 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -1,38 +1,32 @@ # stdlib -from datetime import datetime -from datetime import timedelta -from datetime import timezone -from enum import Enum import random +from datetime import datetime, timedelta, timezone +from enum import Enum from string import Template from time import sleep from typing import Any # third party -from pydantic import Field -from pydantic import model_validator -from result import Err -from result import Ok -from result import Result +from pydantic import Field, model_validator +from result import Err, Ok, Result from typing_extensions import Self # relative -from ...client.api import APIRegistry -from ...client.api import SyftAPICall +from ...client.api import APIRegistry, SyftAPICall from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...service.context import AuthedServiceContext from ...service.worker.worker_pool import SyftWorker -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey -from ...types.datetime import DateTime -from ...types.datetime import format_timedelta -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, + UIDPartitionKey, +) +from ...types.datetime import DateTime, format_timedelta +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID from ...util import options @@ -40,13 +34,10 @@ from ...util.markdown import as_markdown_code from ...util.telemetry import instrument from ...util.util import prompt_warning_message -from ..action.action_object import Action -from ..action.action_object import ActionObject +from ..action.action_object import Action, ActionObject from ..action.action_permissions import ActionObjectPermission from ..log.log import SyftLog -from ..response import SyftError -from ..response import SyftNotReady -from ..response import SyftSuccess +from ..response import SyftError, SyftNotReady, SyftSuccess from ..user.user import UserView from .html_template import job_repr_template @@ -100,7 +91,7 @@ class Job(SyncableSyftObject): n_iters: int | None = 0 current_iter: int | None = None creation_time: str | None = Field( - default_factory=lambda: str(datetime.now(tz=timezone.utc)) + default_factory=lambda: str(datetime.now(tz=timezone.utc)), ) action: Action | None = None job_pid: int | None = None @@ -204,7 +195,7 @@ def worker(self) -> SyftWorker | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. 
You must login to {self.syft_server_location}", ) return api.services.worker.get(self.job_worker_id) @@ -242,16 +233,15 @@ def progress(self) -> str | None: if self.status in [JobStatus.PROCESSING, JobStatus.COMPLETED]: if self.current_iter is None: return "" - else: - if self.n_iters is not None: - return self.time_remaining_string - # if self.current_iter !=0 - # we can compute the remaining time + elif self.n_iters is not None: + return self.time_remaining_string + # if self.current_iter !=0 + # we can compute the remaining time - # we cannot compute the remaining time - else: - n_iters_str = "?" if self.n_iters is None else str(self.n_iters) - return f"{self.current_iter}/{n_iters_str}" + # we cannot compute the remaining time + else: + n_iters_str = "?" if self.n_iters is None else str(self.n_iters) + return f"{self.current_iter}/{n_iters_str}" else: return "" @@ -277,7 +267,7 @@ def restart(self, kill: bool = False) -> None: ) if api is None: raise ValueError( - f"Can't access Syft API. You must login to {self.syft_server_location}" + f"Can't access Syft API. You must login to {self.syft_server_location}", ) call = SyftAPICall( server_uid=self.server_uid, @@ -297,7 +287,7 @@ def kill(self) -> SyftError | SyftSuccess: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) call = SyftAPICall( server_uid=self.server_uid, @@ -317,7 +307,7 @@ def fetch(self) -> None: ) if api is None: raise ValueError( - f"Can't access Syft API. You must login to {self.syft_server_location}" + f"Can't access Syft API. You must login to {self.syft_server_location}", ) call = SyftAPICall( server_uid=self.server_uid, @@ -328,7 +318,7 @@ def fetch(self) -> None: ) job: Job | None = api.make_call(call) if job is None: - return None + return self.resolved = job.resolved if job.resolved: self.result = job.result @@ -345,7 +335,7 @@ def subjobs(self) -> list["Job"] | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.services.job.get_subjobs(self.id) @@ -361,7 +351,7 @@ def owner(self) -> UserView | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.services.user.get_current_user(self.id) @@ -375,7 +365,7 @@ def _get_log_objs(self) -> SyftLog | SyftError: return api.services.log.get(self.log_id) def logs( - self, stdout: bool = True, stderr: bool = True, _print: bool = True + self, stdout: bool = True, stderr: bool = True, _print: bool = True, ) -> str | None: api = APIRegistry.api_for( server_uid=self.syft_server_location, @@ -409,18 +399,16 @@ def logs( # no access if isinstance(self.result, Err): results.append(self.result.value) - else: - # add short error - if isinstance(self.result, Err): - results.append(self.result.value) + elif isinstance(self.result, Err): + results.append(self.result.value) if has_permissions: has_storage_permission = api.services.log.has_storage_permission( - self.log_id + self.log_id, ) if not has_storage_permission: prompt_warning_message( - message="This is a placeholder object, the real data lives on a different server and is not synced." 
+ message="This is a placeholder object, the real data lives on a different server and is not synced.", ) results_str = "\n".join(results) @@ -460,7 +448,7 @@ def summary_html(self) -> str: worker_summary = "" if self.job_worker_id: worker_copy_button = CopyIDButton( - copy_text=str(self.job_worker_id), max_width=60 + copy_text=str(self.job_worker_id), max_width=60, ) worker_summary = f"""
@@ -516,7 +504,7 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: logs = self.logs(_print=False) if logs is not None: logs_w_linenr = "\n".join( - [f"{i} {line}" for i, line in enumerate(logs.rstrip().split("\n"))] + [f"{i} {line}" for i, line in enumerate(logs.rstrip().split("\n"))], ) if self.status == JobStatus.COMPLETED: @@ -546,7 +534,7 @@ def requesting_user(self) -> UserView | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.services.user.view(self.requested_by) @@ -558,7 +546,7 @@ def server_name(self) -> str | SyftError | None: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.server_name @@ -570,7 +558,7 @@ def parent(self) -> Self | SyftError: ) if api is None: return SyftError( - message=f"Can't access Syft API. You must login to {self.syft_server_location}" + message=f"Can't access Syft API. You must login to {self.syft_server_location}", ) return api.services.job.get(self.parent_job_id) @@ -606,7 +594,7 @@ def _repr_html_(self) -> str: user_repr = "--" if self.requested_by and not isinstance( - requesting_user := self.requesting_user, SyftError + requesting_user := self.requesting_user, SyftError, ): user_repr = f"{requesting_user.name} {requesting_user.email}" @@ -615,7 +603,7 @@ def _repr_html_(self) -> str: worker = self.worker if not isinstance(worker, SyftError): worker_pool_id_button = CopyIDButton( - copy_text=str(worker.worker_pool_name), max_width=60 + copy_text=str(worker.worker_pool_name), max_width=60, ) worker_attr = f"""
@@ -652,7 +640,7 @@ def _repr_html_(self) -> str: ) def wait( - self, job_only: bool = False, timeout: int | None = None + self, job_only: bool = False, timeout: int | None = None, ) -> Any | SyftNotReady | SyftError: self.fetch() if self.resolved: @@ -665,7 +653,7 @@ def wait( if api is None: raise ValueError( - f"Can't access Syft API. You must login to server with id '{self.syft_server_location}'" + f"Can't access Syft API. You must login to server with id '{self.syft_server_location}'", ) workers = api.services.worker.get_all() @@ -673,7 +661,7 @@ def wait( return SyftError( message=f"Server {self.syft_server_location} has no workers. " f"You need to start a worker to run jobs " - f"by setting n_consumers > 0." + f"by setting n_consumers > 0.", ) print_warning = True @@ -689,13 +677,13 @@ def wait( break if print_warning and self.result is not None: result_obj = api.services.action.get( # type: ignore[unreachable] - self.result.id, resolve_nested=False + self.result.id, resolve_nested=False, ) if result_obj.is_link and job_only: print( "You're trying to wait on a job that has a link as a result." "This means that the job may be ready but the linked result may not." - "Use job.wait().get() instead to wait for the linked result." + "Use job.wait().get() instead to wait for the linked result.", ) print_warning = False @@ -742,7 +730,7 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # dependencies.append(self.user_code_id) output = context.server.get_service("outputservice").get_by_job_id( # type: ignore - context, self.id + context, self.id, ) if isinstance(output, SyftError): return output @@ -796,7 +784,7 @@ def _repr_html_(self) -> str: result_str = "

...Result..."
         if self.includes_result:
-            result_str += f"...{str(self.result)}..."
+            result_str += f"...{self.result!s}..."
         else:
             result_str += "...No result included..."
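In the hunk above, the ellipses stand in for HTML markup that was stripped from this copy of the patch; the substantive change is replacing an explicit str() call inside an f-string with the equivalent !s conversion flag (ruff's explicit-conversion fix, RUF010, if I recall the rule code correctly). A standalone illustration:

value = 3.14

# An explicit str() call inside an f-string...
before = f"Result: {str(value)}"
# ...is equivalent to the !s conversion flag, which ruff prefers:
after = f"Result: {value!s}"

assert before == after  # identical output, one less nested call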
" @@ -843,7 +831,7 @@ def from_job( class JobStash(BaseUIDStoreStash): object_type = Job settings: PartitionSettings = PartitionSettings( - name=Job.__canonical_name__, object_type=Job + name=Job.__canonical_name__, object_type=Job, ) def __init__(self, store: DocumentStore) -> None: @@ -875,7 +863,7 @@ def get_by_result_id( result_id: UID, ) -> Result[Job | None, str]: qks = QueryKeys( - qks=[PartitionKey(key="result_id", type_=UID).with_obj(result_id)] + qks=[PartitionKey(key="result_id", type_=UID).with_obj(result_id)], ) res = self.query_all(credentials=credentials, qks=qks) if res.is_err(): @@ -890,16 +878,16 @@ def get_by_result_id( return Ok(res[0]) def get_by_parent_id( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[Job | None, str]: qks = QueryKeys( - qks=[PartitionKey(key="parent_job_id", type_=UID).with_obj(uid)] + qks=[PartitionKey(key="parent_job_id", type_=UID).with_obj(uid)], ) item = self.query_all(credentials=credentials, qks=qks) return item def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[SyftSuccess, str]: qk = UIDPartitionKey.with_obj(uid) result = super().delete(credentials=credentials, qk=qk) @@ -911,25 +899,25 @@ def get_active(self, credentials: SyftVerifyKey) -> Result[SyftSuccess, str]: qks = QueryKeys( qks=[ PartitionKey(key="status", type_=JobStatus).with_obj( - JobStatus.PROCESSING - ) - ] + JobStatus.PROCESSING, + ), + ], ) return self.query_all(credentials=credentials, qks=qks) def get_by_worker( - self, credentials: SyftVerifyKey, worker_id: str + self, credentials: SyftVerifyKey, worker_id: str, ) -> Result[list[Job], str]: qks = QueryKeys( - qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)] + qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)], ) return self.query_all(credentials=credentials, qks=qks) def get_by_user_code_id( - self, credentials: SyftVerifyKey, user_code_id: UID + self, credentials: SyftVerifyKey, user_code_id: UID, ) -> Result[list[Job], str]: qks = QueryKeys( - qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)] + qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)], ) return self.query_all(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index 204409b0079..402d6f4da03 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,6 +1,5 @@ # stdlib -from typing import Any -from typing import ClassVar +from typing import Any, ClassVar # relative from ...serde.serializable import serializable @@ -37,6 +36,6 @@ def restart(self) -> None: self.stdout = "" def get_sync_dependencies( - self, context: AuthedServiceContext, **kwargs: dict + self, context: AuthedServiceContext, **kwargs: dict, ) -> list[UID]: # type: ignore return [self.job_id] diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index 8925b88a1a0..153c2c59245 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -5,13 +5,9 @@ from ...util.telemetry import instrument from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from 
..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL +from ..response import SyftError, SyftSuccess +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL, DATA_SCIENTIST_ROLE_LEVEL from .log import SyftLog from .log_stash import LogStash @@ -74,7 +70,7 @@ def get(self, context: AuthedServiceContext, uid: UID) -> SyftLog | SyftError: return result.ok() @service_method( - path="log.get_stdout", name="get_stdout", roles=DATA_SCIENTIST_ROLE_LEVEL + path="log.get_stdout", name="get_stdout", roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_stdout(self, context: AuthedServiceContext, uid: UID) -> str | SyftError: result = self.get(context, uid) @@ -115,7 +111,7 @@ def get_all(self, context: AuthedServiceContext) -> SyftSuccess | SyftError: @service_method(path="log.delete", name="delete", roles=DATA_SCIENTIST_ROLE_LEVEL) def delete( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: result = self.stash.delete_by_uid(context.credentials, uid) if result.is_ok(): diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index 54657982633..6c3f0a14dc4 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -1,8 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings +from ...store.document_store import BaseUIDStoreStash, DocumentStore, PartitionSettings from ...util.telemetry import instrument from .log import SyftLog @@ -12,7 +10,7 @@ class LogStash(BaseUIDStoreStash): object_type = SyftLog settings: PartitionSettings = PartitionSettings( - name=SyftLog.__canonical_name__, object_type=SyftLog + name=SyftLog.__canonical_name__, object_type=SyftLog, ) def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index ccdb0c0a8ec..ef46499cdc1 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -5,8 +5,7 @@ from ...store.document_store import DocumentStore from ...util.telemetry import instrument from ..context import AuthedServiceContext -from ..service import AbstractService -from ..service import service_method +from ..service import AbstractService, service_method from ..user.user_roles import GUEST_ROLE_LEVEL from .server_metadata import ServerMetadata @@ -18,7 +17,7 @@ def __init__(self, store: DocumentStore) -> None: self.store = store @service_method( - path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL + path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL, ) def get_metadata(self, context: AuthedServiceContext) -> ServerMetadata: return context.server.metadata # type: ignore diff --git a/packages/syft/src/syft/service/metadata/server_metadata.py b/packages/syft/src/syft/service/metadata/server_metadata.py index c56eb8a49ae..93c5199600f 100644 --- a/packages/syft/src/syft/service/metadata/server_metadata.py +++ b/packages/syft/src/syft/service/metadata/server_metadata.py @@ -6,26 +6,20 @@ # third party from packaging import version -from 
pydantic import BaseModel -from pydantic import model_validator +from pydantic import BaseModel, model_validator # relative from ...abstract_server import ServerType from ...protocol.data_protocol import get_data_protocol from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import StorableObjectType -from ...types.syft_object import SyftObject -from ...types.transforms import convert_types -from ...types.transforms import drop -from ...types.transforms import rename -from ...types.transforms import transform +from ...types.syft_object import SYFT_OBJECT_VERSION_1, StorableObjectType, SyftObject +from ...types.transforms import convert_types, drop, rename, transform from ...types.uid import UID def check_version( - client_version: str, server_version: str, server_name: str, silent: bool = False + client_version: str, server_version: str, server_name: str, silent: bool = False, ) -> bool: client_syft_version = version.parse(client_version) server_syft_version = version.parse(server_version) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 0075768f9ce..f33fcac6ac0 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,34 +1,29 @@ # stdlib -from collections import defaultdict import sys +from collections import defaultdict from typing import cast # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...store.document_store import StorePartition +from ...store.document_store import DocumentStore, StorePartition from ...types.blob_storage import BlobStorageEntry from ...types.syft_object import SyftObject -from ..action.action_object import Action -from ..action.action_object import ActionObject -from ..action.action_permissions import ActionObjectPermission -from ..action.action_permissions import StoragePermission +from ..action.action_object import Action, ActionObject +from ..action.action_permissions import ActionObjectPermission, StoragePermission from ..action.action_store import KeyValueActionStore from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import service_method +from ..response import SyftError, SyftSuccess +from ..service import AbstractService, service_method from ..user.user_roles import ADMIN_ROLE_LEVEL -from .object_migration_state import MigrationData -from .object_migration_state import StoreMetadata -from .object_migration_state import SyftMigrationStateStash -from .object_migration_state import SyftObjectMigrationState +from .object_migration_state import ( + MigrationData, + StoreMetadata, + SyftMigrationStateStash, + SyftObjectMigrationState, +) @serializable(canonical_name="MigrationService", version=1) @@ -42,12 +37,11 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="migration", name="get_version") def get_version( - self, context: AuthedServiceContext, canonical_name: str + self, context: AuthedServiceContext, canonical_name: str, ) -> int | SyftError: """Search for the metadata for an object.""" - result = self.stash.get_by_name( - 
canonical_name=canonical_name, credentials=context.credentials + canonical_name=canonical_name, credentials=context.credentials, ) if result.is_err(): @@ -57,17 +51,17 @@ def get_version( if migration_state is None: return SyftError( - message=f"No migration state exists for canonical name: {canonical_name}" + message=f"No migration state exists for canonical name: {canonical_name}", ) return migration_state.current_version @service_method(path="migration", name="get_state") def get_state( - self, context: AuthedServiceContext, canonical_name: str + self, context: AuthedServiceContext, canonical_name: str, ) -> bool | SyftError: result = self.stash.get_by_name( - canonical_name=canonical_name, credentials=context.credentials + canonical_name=canonical_name, credentials=context.credentials, ) if result.is_err(): @@ -83,7 +77,7 @@ def register_migration_state( canonical_name: str, ) -> SyftObjectMigrationState | SyftError: obj = SyftObjectMigrationState( - current_version=current_version, canonical_name=canonical_name + current_version=current_version, canonical_name=canonical_name, ) result = self.stash.set(migration_state=obj, credentials=context.credentials) @@ -93,7 +87,7 @@ def register_migration_state( return result.ok() def _find_klasses_pending_for_migration( - self, context: AuthedServiceContext, object_types: list[type[SyftObject]] + self, context: AuthedServiceContext, object_types: list[type[SyftObject]], ) -> list[type[SyftObject]]: klasses_to_be_migrated = [] @@ -104,7 +98,7 @@ def _find_klasses_pending_for_migration( migration_state = self.get_state(context, canonical_name) if isinstance(migration_state, SyftError): raise Exception( - f"Failed to get migration state for {canonical_name}. Error: {migration_state}" + f"Failed to get migration state for {canonical_name}. 
Error: {migration_state}", ) if ( migration_state is not None @@ -184,7 +178,7 @@ def _get_store_metadata( object_type=object_type, permissions=permissions, storage_permissions=storage_permissions, - ) + ), ) def _get_all_store_metadata( @@ -217,7 +211,7 @@ def _get_all_store_metadata( roles=ADMIN_ROLE_LEVEL, ) def update_store_metadata( - self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] + self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata], ) -> SyftSuccess | SyftError: res = self._update_store_metadata(context, store_metadata) if res.is_err(): @@ -226,7 +220,7 @@ def update_store_metadata( return SyftSuccess(message=res.ok()) def _update_store_metadata_for_klass( - self, context: AuthedServiceContext, metadata: StoreMetadata + self, context: AuthedServiceContext, metadata: StoreMetadata, ) -> Result[str, str]: object_partition = self._get_partition_from_type(context, metadata.object_type) if object_partition.is_err(): @@ -251,7 +245,7 @@ def _update_store_metadata_for_klass( return Ok("success") def _update_store_metadata( - self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] + self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata], ) -> Result[str, str]: print("Updating store metadata") for metadata in store_metadata.values(): @@ -290,7 +284,7 @@ def _get_migration_objects( klasses_to_migrate = document_store_object_types else: klasses_to_migrate = self._find_klasses_pending_for_migration( - context=context, object_types=document_store_object_types + context=context, object_types=document_store_object_types, ) result = defaultdict(list) @@ -301,7 +295,7 @@ def _get_migration_objects( if object_partition is None: continue objects_result = object_partition.all( - context.credentials, has_permission=True + context.credentials, has_permission=True, ) if objects_result.is_err(): return objects_result @@ -318,7 +312,7 @@ def _get_migration_objects( return Ok(dict(result)) def _search_partition_for_object( - self, context: AuthedServiceContext, obj: SyftObject + self, context: AuthedServiceContext, obj: SyftObject, ) -> Result[StorePartition, str]: klass = type(obj) mro = klass.__mro__ @@ -359,7 +353,7 @@ def _create_migrated_objects( ) -> Result[str, str]: for migrated_object in migrated_objects: object_partition_or_err = self._search_partition_for_object( - context, migrated_object + context, migrated_object, ) if object_partition_or_err.is_err(): return object_partition_or_err @@ -373,7 +367,7 @@ def _create_migrated_objects( if result.is_err(): if ignore_existing and "Duplication Key Error" in result.value: print( - f"{type(migrated_object)} #{migrated_object.id} already exists" + f"{type(migrated_object)} #{migrated_object.id} already exists", ) continue else: @@ -387,7 +381,7 @@ def _create_migrated_objects( roles=ADMIN_ROLE_LEVEL, ) def update_migrated_objects( - self, context: AuthedServiceContext, migrated_objects: list[SyftObject] + self, context: AuthedServiceContext, migrated_objects: list[SyftObject], ) -> SyftSuccess | SyftError: res = self._update_migrated_objects(context, migrated_objects) if res.is_err(): @@ -396,11 +390,11 @@ def update_migrated_objects( return SyftSuccess(message=res.ok()) def _update_migrated_objects( - self, context: AuthedServiceContext, migrated_objects: list[SyftObject] + self, context: AuthedServiceContext, migrated_objects: list[SyftObject], ) -> Result[str, str]: for migrated_object in migrated_objects: object_partition_or_err = 
self._search_partition_for_object( - context, migrated_object + context, migrated_object, ) if object_partition_or_err.is_err(): return object_partition_or_err @@ -447,7 +441,7 @@ def _migrate_objects( print(traceback.format_exc()) return Err( - f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}" + f"Failed to migrate data to {klass} for qk {klass.__version__}: {object.id}", ) return Ok(migrated_objects) @@ -479,7 +473,7 @@ def migrate_data( # client.migration.write_migrated_values(migrated_values) migration_objects_result = self._get_migration_objects( - context, document_store_object_types + context, document_store_object_types, ) if migration_objects_result.is_err(): return migration_objects_result @@ -491,7 +485,7 @@ def migrate_data( migrated_objects = migrated_objects_result.ok() objects_update_update_result = self._update_migrated_objects( - context, migrated_objects + context, migrated_objects, ) if objects_update_update_result.is_err(): return SyftError(message=objects_update_update_result.value) @@ -508,7 +502,7 @@ def migrate_data( migrated_actionobjects = migrated_actionobjects.ok() actionobjects_update_update_result = self._update_migrated_actionobjects( - context, migrated_actionobjects + context, migrated_actionobjects, ) if actionobjects_update_update_result.is_err(): return SyftError(message=actionobjects_update_update_result.err()) @@ -521,7 +515,7 @@ def migrate_data( roles=ADMIN_ROLE_LEVEL, ) def get_migration_actionobjects( - self, context: AuthedServiceContext, get_all: bool = False + self, context: AuthedServiceContext, get_all: bool = False, ) -> dict | SyftError: res = self._get_migration_actionobjects(context, get_all=get_all) if res.is_ok(): @@ -530,7 +524,7 @@ def get_migration_actionobjects( return SyftError(message=res.value) def _get_migration_actionobjects( - self, context: AuthedServiceContext, get_all: bool = False + self, context: AuthedServiceContext, get_all: bool = False, ) -> Result[dict[type[SyftObject], list[SyftObject]], str]: # Track all object types from action store action_object_types = [Action, ActionObject] @@ -540,12 +534,12 @@ def _get_migration_actionobjects( } action_object_pending_migration = self._find_klasses_pending_for_migration( - context=context, object_types=action_object_types + context=context, object_types=action_object_types, ) result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list) action_store = context.server.action_store action_store_objects_result = action_store._all( - context.credentials, has_permission=True + context.credentials, has_permission=True, ) if action_store_objects_result.is_err(): return action_store_objects_result @@ -563,7 +557,7 @@ def _get_migration_actionobjects( roles=ADMIN_ROLE_LEVEL, ) def update_migrated_actionobjects( - self, context: AuthedServiceContext, objects: list[SyftObject] + self, context: AuthedServiceContext, objects: list[SyftObject], ) -> SyftSuccess | SyftError: res = self._update_migrated_actionobjects(context, objects) if res.is_ok(): @@ -572,13 +566,13 @@ def update_migrated_actionobjects( return SyftError(message=res.value) def _update_migrated_actionobjects( - self, context: AuthedServiceContext, objects: list[SyftObject] + self, context: AuthedServiceContext, objects: list[SyftObject], ) -> Result[str, str]: # Track all object types from action store action_store = context.server.action_store for obj in objects: res = action_store.set( - uid=obj.id, credentials=context.credentials, syft_object=obj + uid=obj.id, 
credentials=context.credentials, syft_object=obj, ) if res.is_err(): return res @@ -590,7 +584,7 @@ def _update_migrated_actionobjects( roles=ADMIN_ROLE_LEVEL, ) def get_migration_data( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> MigrationData | SyftError: store_objects_result = self._get_migration_objects(context, get_all=True) if store_objects_result.is_err(): @@ -633,12 +627,12 @@ def apply_migration_data( if len(migration_data.blobs): return SyftError( message="Blob storage migration is not supported by this endpoint, " - "please use 'client.load_migration_data' instead." + "please use 'client.load_migration_data' instead.", ) # migrate + apply store objects migrated_objects_result = self._migrate_objects( - context, migration_data.store_objects + context, migration_data.store_objects, ) if migrated_objects_result.is_err(): return SyftError(message=migrated_objects_result.err()) @@ -649,13 +643,13 @@ def apply_migration_data( # migrate+apply action objects migrated_actionobjects = self._migrate_objects( - context, migration_data.action_objects + context, migration_data.action_objects, ) if migrated_actionobjects.is_err(): return SyftError(message=migrated_actionobjects.err()) migrated_actionobjects = migrated_actionobjects.ok() action_objects_result = self._update_migrated_actionobjects( - context, migrated_actionobjects + context, migrated_actionobjects, ) if action_objects_result.is_err(): return SyftError(message=action_objects_result.err()) diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 815bda81c83..bfddb0c70ec 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -1,36 +1,38 @@ # stdlib +import sys from io import BytesIO from pathlib import Path -import sys from typing import Any +import yaml + # third party from result import Result from typing_extensions import Self -import yaml # relative from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize -from ...server.credentials import SyftSigningKey -from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...types.blob_storage import BlobStorageEntry -from ...types.blob_storage import CreateBlobStorageEntry -from ...types.syft_object import Context -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftBaseObject -from ...types.syft_object import SyftObject +from ...server.credentials import SyftSigningKey, SyftVerifyKey +from ...store.document_store import ( + BaseStash, + DocumentStore, + PartitionKey, + PartitionSettings, +) +from ...types.blob_storage import BlobStorageEntry, CreateBlobStorageEntry +from ...types.syft_object import ( + SYFT_OBJECT_VERSION_1, + Context, + SyftBaseObject, + SyftObject, +) from ...types.syft_object_registry import SyftObjectRegistry from ...types.uid import UID from ...util.util import prompt_warning_message from ..action.action_permissions import ActionObjectPermission -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess @serializable() @@ -93,7 +95,7 @@ def set( ) def 
get_by_name( - self, canonical_name: str, credentials: SyftVerifyKey + self, canonical_name: str, credentials: SyftVerifyKey, ) -> Result[SyftObjectMigrationState, str]: qks = KlassNamePartitionKey.with_obj(canonical_name) return self.query_one(credentials=credentials, qks=qks) @@ -155,8 +157,8 @@ def make_migration_config(self) -> dict[str, Any]: "env": [ {"name": "SERVER_UID", "value": server_uid}, {"name": "SERVER_PRIVATE_KEY", "value": server_private_key}, - ] - } + ], + }, } return migration_config @@ -164,7 +166,7 @@ def make_migration_config(self) -> dict[str, Any]: def from_file(self, path: str | Path) -> Self | SyftError: path = Path(path) if not path.exists(): - return SyftError(f"File {str(path)} does not exist.") + return SyftError(f"File {path!s} does not exist.") with open(path, "rb") as f: res: MigrationData = _deserialize(f.read(), from_bytes=True) @@ -191,7 +193,7 @@ def save(self, path: str | Path, yaml_path: str | Path) -> SyftSuccess | SyftErr with open(yaml_path, "w") as f: yaml.dump(migration_config, f) - return SyftSuccess(message=f"Migration data saved to {str(path)}.") + return SyftSuccess(message=f"Migration data saved to {path!s}.") def download_blobs(self) -> None | SyftError: for obj in self.blob_storage_objects: @@ -233,7 +235,7 @@ def migrate_and_upload_blob(self, obj: BlobStorageEntry) -> SyftSuccess | SyftEr blob_create = CreateBlobStorageEntry.from_blob_storage_entry(migrated_obj) blob_create.file_size = size blob_deposit_object = api.services.blob_storage.allocate_for_user( - blob_create, migrated_obj.uploaded_by + blob_create, migrated_obj.uploaded_by, ) if isinstance(blob_deposit_object, SyftError): diff --git a/packages/syft/src/syft/service/network/association_request.py b/packages/syft/src/syft/service/network/association_request.py index ad0d358dd4b..b325ef88573 100644 --- a/packages/syft/src/syft/service/network/association_request.py +++ b/packages/syft/src/syft/service/network/association_request.py @@ -3,9 +3,7 @@ from typing import cast # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...client.client import SyftClient @@ -13,8 +11,7 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ..context import ChangeContext from ..request.request import Change -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from .routes import ServerRoute from .server_peer import ServerPeer @@ -31,17 +28,19 @@ class AssociationRequestChange(Change): __repr_attrs__ = ["self_server_route", "remote_peer"] def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[tuple[bytes, ServerPeer], SyftError]: - """ - Executes the association request. + """Executes the association request. Args: + ---- context (ChangeContext): The change context. apply (bool): A flag indicating whether to apply the association request. Returns: + ------- Result[tuple[bytes, ServerPeer], SyftError]: The result of the association request. 
+ """ # relative from .network_service import NetworkService @@ -49,13 +48,13 @@ def _run( if not apply: # TODO: implement undo for AssociationRequestChange return Err( - SyftError(message="Undo not supported for AssociationRequestChange") + SyftError(message="Undo not supported for AssociationRequestChange"), ) # Get the network service service_ctx = context.to_service_ctx() network_service = cast( - NetworkService, service_ctx.server.get_service(NetworkService) + NetworkService, service_ctx.server.get_service(NetworkService), ) network_stash = network_service.stash @@ -76,17 +75,17 @@ def _run( # Pinging the remote peer to verify the connection try: remote_client: SyftClient = self.remote_peer.client_with_context( - context=service_ctx + context=service_ctx, ) if remote_client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{self.remote_peer.id}. Error: {remote_client.err()}" + f"{self.remote_peer.id}. Error: {remote_client.err()}", ) remote_client = remote_client.ok() random_challenge = secrets.token_bytes(16) remote_res = remote_client.api.services.network.ping( - challenge=random_challenge + challenge=random_challenge, ) except Exception as e: return SyftError(message="Remote Peer cannot ping peer:" + str(e)) @@ -99,14 +98,14 @@ def _run( # Verifying if the challenge is valid try: self.remote_peer.verify_key.verify_key.verify( - random_challenge, challenge_signature + random_challenge, challenge_signature, ) except Exception as e: return Err(SyftError(message=str(e))) # Adding the remote peer to the network stash result = network_stash.create_or_update_peer( - service_ctx.server.verify_key, self.remote_peer + service_ctx.server.verify_key, self.remote_peer, ) if result.is_err(): @@ -115,7 +114,7 @@ def _run( # this way they can match up who we are with who they think we are # Sending a signed messages for the peer to verify self_server_peer = self.self_server_route.validate_with_context( - context=service_ctx + context=service_ctx, ) if isinstance(self_server_peer, SyftError): @@ -123,8 +122,8 @@ def _run( return Ok( SyftSuccess( - message=f"Routes successfully added for peer: {self.remote_peer.name}" - ) + message=f"Routes successfully added for peer: {self.remote_peer.name}", + ), ) def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 24f117b7323..06fdc5a0d3f 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -1,65 +1,51 @@ # stdlib -from collections.abc import Callable -from enum import Enum import logging import secrets -from typing import Any -from typing import cast +from collections.abc import Callable +from enum import Enum +from typing import Any, cast # third party from result import Result # relative from ...abstract_server import ServerType -from ...client.client import HTTPConnection -from ...client.client import PythonConnection -from ...client.client import SyftClient +from ...client.client import HTTPConnection, PythonConnection, SyftClient from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import 
PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, +) from ...types.server_url import ServerURL -from ...types.transforms import TransformContext -from ...types.transforms import keep -from ...types.transforms import make_set_default -from ...types.transforms import transform -from ...types.transforms import transform_method +from ...types.transforms import ( + TransformContext, + keep, + make_set_default, + transform, + transform_method, +) from ...types.uid import UID from ...util.telemetry import instrument -from ...util.util import generate_token -from ...util.util import get_env -from ...util.util import prompt_warning_message -from ...util.util import str_to_bool +from ...util.util import generate_token, get_env, prompt_warning_message, str_to_bool from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey from ..metadata.server_metadata import ServerMetadata -from ..request.request import Request -from ..request.request import RequestStatus -from ..request.request import SubmitRequest +from ..request.request import Request, RequestStatus, SubmitRequest from ..request.request_service import RequestService -from ..response import SyftError -from ..response import SyftInfo -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import DATA_OWNER_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL +from ..response import SyftError, SyftInfo, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import DATA_OWNER_ROLE_LEVEL, GUEST_ROLE_LEVEL from ..warnings import CRUDWarning from .association_request import AssociationRequestChange from .reverse_tunnel_service import ReverseTunnelService -from .routes import HTTPServerRoute -from .routes import PythonServerRoute -from .routes import ServerRoute -from .routes import ServerRouteType -from .server_peer import ServerPeer -from .server_peer import ServerPeerUpdate +from .routes import HTTPServerRoute, PythonServerRoute, ServerRoute, ServerRouteType +from .server_peer import ServerPeer, ServerPeerUpdate logger = logging.getLogger(__name__) @@ -86,14 +72,14 @@ class ServerPeerAssociationStatus(Enum): class NetworkStash(BaseUIDStoreStash): object_type = ServerPeer settings: PartitionSettings = PartitionSettings( - name=ServerPeer.__canonical_name__, object_type=ServerPeer + name=ServerPeer.__canonical_name__, object_type=ServerPeer, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_by_name( - self, credentials: SyftVerifyKey, name: str + self, credentials: SyftVerifyKey, name: str, ) -> Result[ServerPeer | None, str]: qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) return self.query_one(credentials=credentials, qks=qks) @@ -110,31 +96,33 @@ def update( return super().update(credentials, peer_update, has_permission=has_permission) def create_or_update_peer( - self, credentials: SyftVerifyKey, peer: ServerPeer + self, credentials: SyftVerifyKey, peer: ServerPeer, ) -> Result[ServerPeer, str]: - """ - Update the selected peer and its route priorities if the peer already exists + """Update the selected peer and its route priorities if the 
peer already exists If the peer does not exist, simply adds it to the database. Args: + ---- credentials (SyftVerifyKey): The credentials used to authenticate the request. peer (ServerPeer): The peer to be updated or added. Returns: + ------- Result[ServerPeer, str]: The updated or added peer if the operation was successful, or an error message if the operation failed. + """ valid = self.check_type(peer, ServerPeer) if valid.is_err(): return SyftError(message=valid.err()) existing: Result | ServerPeer = self.get_by_uid( - credentials=credentials, uid=peer.id + credentials=credentials, uid=peer.id, ) if existing.is_ok() and existing.ok(): existing_peer = existing.ok() existing_peer.update_routes(peer.server_routes) peer_update = ServerPeerUpdate( - id=peer.id, server_routes=existing_peer.server_routes + id=peer.id, server_routes=existing_peer.server_routes, ) result = self.update(credentials, peer_update) else: @@ -142,17 +130,17 @@ def create_or_update_peer( return result def get_by_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> Result[ServerPeer | None, SyftError]: qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) return self.query_one(credentials, qks) def get_by_server_type( - self, credentials: SyftVerifyKey, server_type: ServerType + self, credentials: SyftVerifyKey, server_type: ServerType, ) -> Result[list[ServerPeer], SyftError]: qks = QueryKeys(qks=[ServerTypePartitionKey.with_obj(server_type)]) return self.query_all( - credentials=credentials, qks=qks, order_by=OrderByNamePartitionKey + credentials=credentials, qks=qks, order_by=OrderByNamePartitionKey, ) @@ -182,10 +170,8 @@ def exchange_credentials_with( remote_server_verify_key: SyftVerifyKey, reverse_tunnel: bool = False, ) -> Request | SyftSuccess | SyftError: + """Exchange Route With Another Server. If there is a pending association request, return it """ - Exchange Route With Another Server. If there is a pending association request, return it - """ - # Step 1: Validate the Route self_server_peer = self_server_route.validate_with_context(context=context) @@ -204,30 +190,30 @@ def exchange_credentials_with( # Also give them their own to validate that it belongs to them # random challenge prevents replay attacks remote_client: SyftClient = remote_server_route.client_with_context( - context=context + context=context, ) remote_server_peer = ServerPeer.from_client(remote_client) # Step 3: Check remotely if the self server already exists as a peer # Update the peer if it exists, otherwise add it remote_self_server_peer = remote_client.api.services.network.get_peer_by_name( - name=self_server_peer.name + name=self_server_peer.name, ) association_request_approved = True if isinstance(remote_self_server_peer, ServerPeer): updated_peer = ServerPeerUpdate( - id=self_server_peer.id, server_routes=self_server_peer.server_routes + id=self_server_peer.id, server_routes=self_server_peer.server_routes, ) result = remote_client.api.services.network.update_peer( - peer_update=updated_peer + peer_update=updated_peer, ) if isinstance(result, SyftError): logger.error( - f"Failed to update peer information on remote client. {result.message}" + f"Failed to update peer information on remote client. 
{result.message}", ) return SyftError( - message=f"Failed to add peer information on remote client : {remote_client.id}" + message=f"Failed to add peer information on remote client : {remote_client.id}", ) # If peer does not exist, ask the remote client to add this server @@ -243,7 +229,7 @@ def exchange_credentials_with( if isinstance(remote_res, SyftError): return SyftError( - message=f"Failed to add peer to remote client: {remote_client.id}. Error: {remote_res.message}" + message=f"Failed to add peer to remote client: {remote_client.id}. Error: {remote_res.message}", ) association_request_approved = not isinstance(remote_res, Request) @@ -255,7 +241,7 @@ def exchange_credentials_with( ) if result.is_err(): logging.error( - f"Failed to save peer: {remote_server_peer}. Error: {result.err()}" + f"Failed to save peer: {remote_server_peer}. Error: {result.err()}", ) return SyftError(message="Failed to update route information.") @@ -292,24 +278,24 @@ def add_peer( message=( f"The {type(peer).__name__}.verify_key: " f"{peer.verify_key} does not match the signature of the message" - ) + ), ) if verify_key != context.server.verify_key: return SyftError( - message="verify_key does not match the remote server's verify_key for add_peer" + message="verify_key does not match the remote server's verify_key for add_peer", ) # check if the peer already is a server peer existing_peer_res = self.stash.get_by_uid(context.server.verify_key, peer.id) if existing_peer_res.is_err(): return SyftError( - message=f"Failed to query peer from stash: {existing_peer_res.err()}" + message=f"Failed to query peer from stash: {existing_peer_res.err()}", ) if isinstance(existing_peer := existing_peer_res.ok(), ServerPeer): msg = [ - f"The peer '{peer.name}' is already associated with '{context.server.name}'" + f"The peer '{peer.name}' is already associated with '{context.server.name}'", ] if existing_peer != peer: @@ -330,7 +316,7 @@ def add_peer( # check if the peer already submitted an association request association_requests: list[Request] = self._get_association_requests_by_peer_id( - context=context, peer_id=peer.id + context=context, peer_id=peer.id, ) if ( association_requests @@ -341,7 +327,7 @@ def add_peer( # only create and submit a new request if there is no requests yet # or all previous requests have been rejected association_request_change = AssociationRequestChange( - self_server_route=self_server_route, challenge=challenge, remote_peer=peer + self_server_route=self_server_route, challenge=challenge, remote_peer=peer, ) submit_request = SubmitRequest( changes=[association_request_change], @@ -354,7 +340,7 @@ def add_peer( and context.server.settings.association_request_auto_approval ): request_apply_method = context.server.get_service_method( - RequestService.apply + RequestService.apply, ) return request_apply_method(context, uid=request.id) @@ -362,10 +348,9 @@ def add_peer( @service_method(path="network.ping", name="ping", roles=GUEST_ROLE_LEVEL) def ping( - self, context: AuthedServiceContext, challenge: bytes + self, context: AuthedServiceContext, challenge: bytes, ) -> bytes | SyftError: """To check alivesness/authenticity of a peer""" - # # Only the root user can ping the server to check its state # if context.server.verify_key != context.credentials: # return SyftError(message=("Only the root user can access ping endpoint")) @@ -374,7 +359,7 @@ def ping( # Sending a signed messages for the peer to verify challenge_signature = context.server.signing_key.signing_key.sign( - challenge + challenge, 
).signature return challenge_signature @@ -385,10 +370,9 @@ def ping( roles=GUEST_ROLE_LEVEL, ) def check_peer_association( - self, context: AuthedServiceContext, peer_id: UID + self, context: AuthedServiceContext, peer_id: UID, ) -> ServerPeerAssociationStatus | SyftError: """Check if a peer exists in the network stash""" - # get the server peer for the given sender peer_id peer = self.stash.get_by_uid(context.server.verify_key, peer_id) if err := peer.is_err(): @@ -400,7 +384,7 @@ def check_peer_association( if peer.ok() is None: # peer is either pending or not found association_requests: list[Request] = ( self._get_association_requests_by_peer_id( - context=context, peer_id=peer_id + context=context, peer_id=peer_id, ) ) if ( @@ -412,13 +396,12 @@ def check_peer_association( return ServerPeerAssociationStatus.PEER_NOT_FOUND @service_method( - path="network.get_all_peers", name="get_all_peers", roles=GUEST_ROLE_LEVEL + path="network.get_all_peers", name="get_all_peers", roles=GUEST_ROLE_LEVEL, ) def get_all_peers( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[ServerPeer] | SyftError: """Get all Peers""" - result = self.stash.get_all( credentials=context.server.verify_key, order_by=OrderByNamePartitionKey, @@ -429,13 +412,12 @@ def get_all_peers( return SyftError(message=result.err()) @service_method( - path="network.get_peer_by_name", name="get_peer_by_name", roles=GUEST_ROLE_LEVEL + path="network.get_peer_by_name", name="get_peer_by_name", roles=GUEST_ROLE_LEVEL, ) def get_peer_by_name( - self, context: AuthedServiceContext, name: str + self, context: AuthedServiceContext, name: str, ) -> ServerPeer | None | SyftError: """Get Peer by Name""" - result = self.stash.get_by_name( credentials=context.server.verify_key, name=name, @@ -451,7 +433,7 @@ def get_peer_by_name( roles=GUEST_ROLE_LEVEL, ) def get_peers_by_type( - self, context: AuthedServiceContext, server_type: ServerType + self, context: AuthedServiceContext, server_type: ServerType, ) -> list[ServerPeer] | SyftError: result = self.stash.get_by_server_type( credentials=context.server.verify_key, @@ -482,14 +464,14 @@ def update_peer( ) if result.is_err(): return SyftError( - message=f"Failed to update peer '{peer_update.name}'. Error: {result.err()}" + message=f"Failed to update peer '{peer_update.name}'. Error: {result.err()}", ) peer = result.ok() self.set_reverse_tunnel_config(context=context, remote_server_peer=peer) return SyftSuccess( - message=f"Peer '{result.ok().name}' information successfully updated." + message=f"Peer '{result.ok().name}' information successfully updated.", ) def set_reverse_tunnel_config( @@ -528,13 +510,13 @@ def set_reverse_tunnel_config( roles=DATA_OWNER_ROLE_LEVEL, ) def delete_peer_by_id( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: """Delete Server Peer""" retrieve_result = self.stash.get_by_uid(context.credentials, uid) if err := retrieve_result.is_err(): return SyftError( - message=f"Failed to retrieve peer with UID {uid}: {retrieve_result.err()}." 
+ message=f"Failed to retrieve peer with UID {uid}: {retrieve_result.err()}.", ) peer_to_delete = cast(ServerPeer, retrieve_result.ok()) @@ -554,11 +536,11 @@ def delete_peer_by_id( return SyftError(message=f"Failed to delete peer with UID {uid}: {err}.") # Delete all the association requests from this peer association_requests: list[Request] = self._get_association_requests_by_peer_id( - context=context, peer_id=uid + context=context, peer_id=uid, ) for request in association_requests: request_delete_method = context.server.get_service_method( - RequestService.delete_by_uid + RequestService.delete_by_uid, ) res = request_delete_method(context, request.id) if isinstance(res, SyftError): @@ -574,17 +556,19 @@ def add_route_on_peer( peer: ServerPeer, route: ServerRoute, ) -> SyftSuccess | SyftError: - """ - Add or update the route information on the remote peer. + """Add or update the route information on the remote peer. Args: + ---- context (AuthedServiceContext): The authentication context. peer (ServerPeer): The peer representing the remote server. route (ServerRoute): The route to be added. Returns: + ------- SyftSuccess | SyftError: A success message if the route is verified, otherwise an error message. + """ # creates a client on the remote server based on the credentials # of the current server's client @@ -592,7 +576,7 @@ def add_route_on_peer( if remote_client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{peer.id}. Error: {remote_client.err()}" + f"{peer.id}. Error: {remote_client.err()}", ) remote_client = remote_client.ok() # ask the remote server to add the route to the self server @@ -611,17 +595,19 @@ def add_route( route: ServerRoute, called_by_peer: bool = False, ) -> SyftSuccess | SyftError: - """ - Add a route to the peer. If the route already exists, update its priority. + """Add a route to the peer. If the route already exists, update its priority. Args: + ---- context (AuthedServiceContext): The authentication context of the remote server. peer_verify_key (SyftVerifyKey): The verify key of the remote server peer. route (ServerRoute): The route to be added. called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: + ------- SyftSuccess | SyftError + """ # verify if the peer is truly the one sending the request to add the route to itself if called_by_peer and peer_verify_key != context.credentials: @@ -629,7 +615,7 @@ def add_route( message=( f"The {type(peer_verify_key).__name__}: " f"{peer_verify_key} does not match the signature of the message" - ) + ), ) # get the full peer object from the store to update its routes remote_server_peer: ServerPeer | SyftError = ( @@ -641,13 +627,13 @@ def add_route( if route in remote_server_peer.server_routes: return SyftSuccess( message=f"The route already exists between '{context.server.name}' and " - f"peer '{remote_server_peer.name}'." 
+ f"peer '{remote_server_peer.name}'.", ) remote_server_peer.update_route(route=route) # update the peer in the store with the updated routes peer_update = ServerPeerUpdate( - id=remote_server_peer.id, server_routes=remote_server_peer.server_routes + id=remote_server_peer.id, server_routes=remote_server_peer.server_routes, ) result = self.stash.update( credentials=context.server.verify_key, @@ -656,9 +642,9 @@ def add_route( if result.is_err(): return SyftError(message=str(result.err())) return SyftSuccess( - message=f"New route ({str(route)}) with id '{route.id}' " + message=f"New route ({route!s}) with id '{route.id}' " f"to peer {remote_server_peer.server_type.value} '{remote_server_peer.name}' " - f"was added for {str(context.server.server_type)} '{context.server.name}'" + f"was added for {context.server.server_type!s} '{context.server.name}'", ) @service_method(path="network.delete_route_on_peer", name="delete_route_on_peer") @@ -668,19 +654,21 @@ def delete_route_on_peer( peer: ServerPeer, route: ServerRoute, ) -> SyftSuccess | SyftError | SyftInfo: - """ - Delete the route on the remote peer. + """Delete the route on the remote peer. Args: + ---- context (AuthedServiceContext): The authentication context for the service. peer (ServerPeer): The peer for which the route will be deleted. route (ServerRoute): The route to be deleted. Returns: + ------- SyftSuccess: If the route is successfully deleted. SyftError: If there is an error deleting the route. SyftInfo: If there is only one route left for the peer and the admin chose not to remove it + """ # creates a client on the remote server based on the credentials # of the current server's client @@ -688,7 +676,7 @@ def delete_route_on_peer( if remote_client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{peer.id}. Error: {remote_client.err()}" + f"{peer.id}. Error: {remote_client.err()}", ) remote_client = remote_client.ok() # ask the remote server to delete the route to the self server, @@ -700,7 +688,7 @@ def delete_route_on_peer( return result @service_method( - path="network.delete_route", name="delete_route", roles=GUEST_ROLE_LEVEL + path="network.delete_route", name="delete_route", roles=GUEST_ROLE_LEVEL, ) def delete_route( self, @@ -709,22 +697,24 @@ def delete_route( route: ServerRoute | None = None, called_by_peer: bool = False, ) -> SyftSuccess | SyftError | SyftInfo: - """ - Delete a route for a given peer. + """Delete a route for a given peer. If a peer has no routes left, there will be a prompt asking if the user want to remove it. If the answer is yes, it will be removed from the stash and will no longer be a peer. Args: + ---- context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the remote server peer. route (ServerRoute): The route to be deleted. called_by_peer (bool): The flag to indicate that it's called by a remote peer. Returns: + ------- SyftSuccess: If the route is successfully deleted. SyftError: If there is an error deleting the route. 
SyftInfo: If there is only one route left for the peer and the admin chose not to remove it + """ if called_by_peer and peer_verify_key != context.credentials: # verify if the peer is truly the one sending the request to delete the route to itself @@ -732,12 +722,12 @@ def delete_route( message=( f"The {type(peer_verify_key).__name__}: " f"{peer_verify_key} does not match the signature of the message" - ) + ), ) remote_server_peer: ServerPeer | SyftError = ( self._get_remote_server_peer_by_verify_key( - context=context, peer_verify_key=peer_verify_key + context=context, peer_verify_key=peer_verify_key, ) ) @@ -746,7 +736,7 @@ def delete_route( f"There is only one route left to peer " f"{remote_server_peer.server_type.value} '{remote_server_peer.name}'. " f"Removing this route will remove the peer for " - f"{str(context.server.server_type)} '{context.server.name}'." + f"{context.server.server_type!s} '{context.server.name}'." ) response: bool = prompt_warning_message( message=warning_message, @@ -756,14 +746,14 @@ def delete_route( return SyftInfo( message=f"The last route to {remote_server_peer.server_type.value} " f"'{remote_server_peer.name}' with id " - f"'{remote_server_peer.server_routes[0].id}' was not deleted." + f"'{remote_server_peer.server_routes[0].id}' was not deleted.", ) result = remote_server_peer.delete_route(route=route) return_message = ( - f"Route '{str(route)}' to peer " + f"Route '{route!s}' to peer " f"{remote_server_peer.server_type.value} '{remote_server_peer.name}' " - f"was deleted for {str(context.server.server_type)} '{context.server.name}'." + f"was deleted for {context.server.server_type!s} '{context.server.name}'." ) if isinstance(result, SyftError): return result @@ -772,22 +762,22 @@ def delete_route( # remove the peer # TODO: should we do this as we are deleting the peer with a guest role level? result = self.stash.delete_by_uid( - credentials=context.server.verify_key, uid=remote_server_peer.id + credentials=context.server.verify_key, uid=remote_server_peer.id, ) if isinstance(result, SyftError): return result return_message += ( f" There is no routes left to connect to peer " f"{remote_server_peer.server_type.value} '{remote_server_peer.name}', so it is deleted for " - f"{str(context.server.server_type)} '{context.server.name}'." + f"{context.server.server_type!s} '{context.server.name}'." ) else: # update the peer with the route removed peer_update = ServerPeerUpdate( - id=remote_server_peer.id, server_routes=remote_server_peer.server_routes + id=remote_server_peer.id, server_routes=remote_server_peer.server_routes, ) result = self.stash.update( - credentials=context.server.verify_key, peer_update=peer_update + credentials=context.server.verify_key, peer_update=peer_update, ) if result.is_err(): return SyftError(message=str(result.err())) @@ -805,10 +795,10 @@ def update_route_priority_on_peer( route: ServerRoute, priority: int | None = None, ) -> SyftSuccess | SyftError: - """ - Update the route priority on the remote peer. + """Update the route priority on the remote peer. Args: + ---- context (AuthedServiceContext): The authentication context. peer (ServerPeer): The peer representing the remote server. route (ServerRoute): The route to be added. @@ -816,8 +806,10 @@ def update_route_priority_on_peer( provided, it will be assigned the highest priority among all peers Returns: + ------- SyftSuccess | SyftError: A success message if the route is verified, otherwise an error message. 
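The SyftSuccess | SyftError | SyftInfo return annotations in these hunks reflect the codebase's convention of returning response objects instead of raising, with callers branching on isinstance. A minimal self-contained sketch of that convention (the three classes here are simplified stand-ins, not syft's real implementations):

    from dataclasses import dataclass

    @dataclass
    class SyftSuccess:
        message: str

    @dataclass
    class SyftError:
        message: str

    @dataclass
    class SyftInfo:
        message: str

    def delete_route(routes: list[str], route: str) -> SyftSuccess | SyftError | SyftInfo:
        # Mirrors the control flow above: refuse to silently drop the last
        # route, error on unknown routes, succeed otherwise.
        if route not in routes:
            return SyftError(message=f"Route {route} not found.")
        if len(routes) == 1:
            return SyftInfo(message="The last route was not deleted.")
        routes.remove(route)
        return SyftSuccess(message=f"Route {route} deleted.")

    result = delete_route(["http", "python"], "http")
    if isinstance(result, SyftError):
        raise RuntimeError(result.message)
    print(result.message)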
+ """ # creates a client on the remote server based on the credentials # of the current server's client @@ -825,7 +817,7 @@ def update_route_priority_on_peer( if remote_client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{peer.id}. Error: {remote_client.err()}" + f"{peer.id}. Error: {remote_client.err()}", ) remote_client = remote_client.ok() result = remote_client.api.services.network.update_route_priority( @@ -849,10 +841,10 @@ def update_route_priority( priority: int | None = None, called_by_peer: bool = False, ) -> SyftSuccess | SyftError: - """ - Updates a route's priority for the given peer + """Updates a route's priority for the given peer Args: + ---- context (AuthedServiceContext): The authentication context for the service. peer_verify_key (SyftVerifyKey): The verify key of the peer whose route priority needs to be updated. route (ServerRoute): The route for which the priority needs to be updated. @@ -860,14 +852,16 @@ def update_route_priority( provided, it will be assigned the highest priority among all peers Returns: + ------- SyftSuccess | SyftError: Successful / Error response + """ if called_by_peer and peer_verify_key != context.credentials: return SyftError( message=( f"The {type(peer_verify_key).__name__}: " f"{peer_verify_key} does not match the signature of the message" - ) + ), ) # get the full peer object from the store to update its routes remote_server_peer: ServerPeer | SyftError = ( @@ -878,7 +872,7 @@ def update_route_priority( # update the route's priority for the peer updated_server_route: ServerRouteType | SyftError = ( remote_server_peer.update_existed_route_priority( - route=route, priority=priority + route=route, priority=priority, ) ) if isinstance(updated_server_route, SyftError): @@ -886,7 +880,7 @@ def update_route_priority( new_priority: int = updated_server_route.priority # update the peer in the store peer_update = ServerPeerUpdate( - id=remote_server_peer.id, server_routes=remote_server_peer.server_routes + id=remote_server_peer.id, server_routes=remote_server_peer.server_routes, ) result = self.stash.update(context.server.verify_key, peer_update) if result.is_err(): @@ -894,14 +888,13 @@ def update_route_priority( return SyftSuccess( message=f"Route {route.id}'s priority updated to " - f"{new_priority} for peer {remote_server_peer.name}" + f"{new_priority} for peer {remote_server_peer.name}", ) def _get_remote_server_peer_by_verify_key( - self, context: AuthedServiceContext, peer_verify_key: SyftVerifyKey + self, context: AuthedServiceContext, peer_verify_key: SyftVerifyKey, ) -> ServerPeer | SyftError: - """ - Helper function to get the full server peer object from t + """Helper function to get the full server peer object from t he stash using its verify key """ remote_server_peer: Result[ServerPeer | None, SyftError] = ( @@ -915,18 +908,17 @@ def _get_remote_server_peer_by_verify_key( remote_server_peer = remote_server_peer.ok() if remote_server_peer is None: return SyftError( - message=f"Can't retrieve {remote_server_peer.name} from the store of peers (None)." + message=f"Can't retrieve {remote_server_peer.name} from the store of peers (None).", ) return remote_server_peer def _get_association_requests_by_peer_id( - self, context: AuthedServiceContext, peer_id: UID + self, context: AuthedServiceContext, peer_id: UID, ) -> list[Request]: - """ - Get all the association requests from a peer. The association requests are sorted by request_time. + """Get all the association requests from a peer. 
The association requests are sorted by request_time. """ request_get_all_method: Callable = context.server.get_service_method( - RequestService.get_all + RequestService.get_all, ) all_requests: list[Request] = request_get_all_method(context) association_requests: list[Request] = [ @@ -940,7 +932,7 @@ def _get_association_requests_by_peer_id( ] return sorted( - association_requests, key=lambda request: request.request_time.utc_timestamp + association_requests, key=lambda request: request.request_time.utc_timestamp, ) @@ -971,7 +963,7 @@ def get_python_server_route(context: TransformContext) -> TransformContext: if context.output is not None and context.obj is not None: context.output["id"] = context.obj.server.id context.output["worker_settings"] = WorkerSettings.from_server( - context.obj.server + context.obj.server, ) context.output["proxy_target_uid"] = context.obj.proxy_target_uid return context @@ -984,17 +976,17 @@ def python_connection_to_server_route() -> list[Callable]: @transform_method(PythonServerRoute, PythonConnection) def server_route_to_python_connection( - obj: Any, context: TransformContext | None = None + obj: Any, context: TransformContext | None = None, ) -> list[Callable]: return PythonConnection(server=obj.server, proxy_target_uid=obj.proxy_target_uid) @transform_method(HTTPServerRoute, HTTPConnection) def server_route_to_http_connection( - obj: Any, context: TransformContext | None = None + obj: Any, context: TransformContext | None = None, ) -> list[Callable]: url = ServerURL( - protocol=obj.protocol, host_or_ip=obj.host_or_ip, port=obj.port + protocol=obj.protocol, host_or_ip=obj.host_or_ip, port=obj.port, ).as_container_host() return HTTPConnection( url=url, @@ -1012,7 +1004,7 @@ def metadata_to_peer() -> list[Callable]: "name", "verify_key", "server_type", - ] + ], ), make_set_default("admin_email", ""), ] diff --git a/packages/syft/src/syft/service/network/rathole_config_builder.py b/packages/syft/src/syft/service/network/rathole_config_builder.py index fb0ef01d798..b0303d6cf27 100644 --- a/packages/syft/src/syft/service/network/rathole_config_builder.py +++ b/packages/syft/src/syft/service/network/rathole_config_builder.py @@ -2,18 +2,16 @@ import secrets from typing import cast +import yaml + # third party from kr8s.objects import Service -import yaml # relative -from ...custom_worker.k8s import KubeUtils -from ...custom_worker.k8s import get_kr8s_client +from ...custom_worker.k8s import KubeUtils, get_kr8s_client from ...types.uid import UID -from .rathole import RatholeConfig -from .rathole import get_rathole_port -from .rathole_toml import RatholeClientToml -from .rathole_toml import RatholeServerToml +from .rathole import RatholeConfig, get_rathole_port +from .rathole_toml import RatholeClientToml, RatholeServerToml from .server_peer import ServerPeer RATHOLE_TOML_CONFIG_MAP = "rathole-config" @@ -30,12 +28,14 @@ def add_host_to_server(self, peer: ServerPeer) -> None: """Add a host to the rathole server toml file. Args: + ---- peer (ServerPeer): The peer to be added to the rathole server. 
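The recurring docstring edits in this file, and in the services above, pull the summary onto the opening quotes and add dashed underlines beneath Args: and Returns:; these look like ruff's pydocstyle fixes under the numpy convention (plausibly D212 and D407, depending on the configured rule set). Before and after, on a toy function:

    def add_host_before(peer):
        """
        Add a host to the rathole server toml file.

        Args:
            peer: The peer to be added.
        """

    def add_host_after(peer):
        """Add a host to the rathole server toml file.

        Args:
        ----
            peer: The peer to be added.

        """

    # Both are no-ops; only the docstring layout differs.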
Returns: + ------- None - """ + """ rathole_route = peer.get_rtunnel_route() if not rathole_route: raise Exception(f"Peer: {peer} has no rathole route: {rathole_route}") @@ -54,7 +54,7 @@ def add_host_to_server(self, peer: ServerPeer) -> None: # Get rathole toml config map rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP, ) if rathole_config_map is None: @@ -85,15 +85,17 @@ def remove_host_from_server(self, peer_id: str, server_name: str) -> None: """Remove a host from the rathole server toml file. Args: + ---- peer_id (str): The id of the peer to be removed. server_name (str): The name of the peer to be removed. Returns: + ------- None - """ + """ rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP, ) if rathole_config_map is None: @@ -120,10 +122,9 @@ def _get_random_port(self) -> int: return secrets.randbits(15) def add_host_to_client( - self, peer_name: str, peer_id: str, rtunnel_token: str, remote_addr: str + self, peer_name: str, peer_id: str, rtunnel_token: str, remote_addr: str, ) -> None: """Add a host to the rathole client toml file.""" - config = RatholeConfig( uuid=peer_id, secret_token=rtunnel_token, @@ -134,7 +135,7 @@ def add_host_to_client( # Get rathole toml config map rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP, ) if rathole_config_map is None: @@ -157,9 +158,8 @@ def add_host_to_client( def remove_host_from_client(self, peer_id: str) -> None: """Remove a host from the rathole client toml file.""" - rathole_config_map = KubeUtils.get_configmap( - client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP + client=self.k8rs_client, name=RATHOLE_TOML_CONFIG_MAP, ) if rathole_config_map is None: @@ -181,12 +181,11 @@ def remove_host_from_client(self, peer_id: str) -> None: KubeUtils.update_configmap(config_map=rathole_config_map, patch={"data": data}) def _add_dynamic_addr_to_rathole( - self, config: RatholeConfig, entrypoint: str = "web" + self, config: RatholeConfig, entrypoint: str = "web", ) -> None: """Add a port to the rathole proxy config map.""" - rathole_proxy_config_map = KubeUtils.get_configmap( - self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP, ) if rathole_proxy_config_map is None: @@ -201,12 +200,12 @@ def _add_dynamic_addr_to_rathole( rathole_proxy["http"]["services"][config.server_name] = { "loadBalancer": { - "servers": [{"url": f"http://rathole:{config.local_addr_port}"}] - } + "servers": [{"url": f"http://rathole:{config.local_addr_port}"}], + }, } rathole_proxy["http"]["middlewares"]["strip-rathole-prefix"] = { - "replacePathRegex": {"regex": "^/rathole/(.*)", "replacement": "/$1"} + "replacePathRegex": {"regex": "^/rathole/(.*)", "replacement": "/$1"}, } proxy_rule = ( @@ -230,9 +229,8 @@ def _add_dynamic_addr_to_rathole( def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: """Remove a port from the rathole proxy config map.""" - rathole_proxy_config_map = KubeUtils.get_configmap( - self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP + self.k8rs_client, RATHOLE_PROXY_CONFIG_MAP, ) if rathole_proxy_config_map is None: @@ -260,7 +258,6 @@ def _remove_dynamic_addr_from_rathole(self, server_name: str) -> None: def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: 
"""Expose a port on the rathole service.""" - rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) @@ -284,14 +281,13 @@ def _expose_port_on_rathole_service(self, port_name: str, port: int) -> None: "port": port, "targetPort": port, "protocol": "TCP", - } + }, ) rathole_service.patch(config) def _remove_port_on_rathole_service(self, port_name: str) -> None: """Remove a port from the rathole service.""" - rathole_service = KubeUtils.get_service(self.k8rs_client, "rathole") rathole_service = cast(Service, rathole_service) diff --git a/packages/syft/src/syft/service/network/rathole_toml.py b/packages/syft/src/syft/service/network/rathole_toml.py index 8ded821279e..0d4a5ef7e63 100644 --- a/packages/syft/src/syft/service/network/rathole_toml.py +++ b/packages/syft/src/syft/service/network/rathole_toml.py @@ -42,7 +42,6 @@ class RatholeClientToml(RatholeBaseToml): def set_remote_addr(self, remote_host: str) -> None: """Add a new remote address to the client toml file.""" - toml = self.read() # Add the new remote address @@ -55,7 +54,6 @@ def set_remote_addr(self, remote_host: str) -> None: def clear_remote_addr(self) -> None: """Clear the remote address from the client toml file.""" - toml = self.read() # Clear the remote address @@ -68,7 +66,6 @@ def clear_remote_addr(self) -> None: def add_config(self, config: RatholeConfig) -> None: """Add a new config to the toml file.""" - toml = self.read() # Add the new config @@ -87,7 +84,6 @@ def add_config(self, config: RatholeConfig) -> None: def remove_config(self, uuid: str) -> None: """Remove a config from the toml file.""" - toml = self.read() # Remove the config @@ -103,7 +99,6 @@ def remove_config(self, uuid: str) -> None: def update_config(self, config: RatholeConfig) -> None: """Update a config in the toml file.""" - toml = self.read() # Update the config @@ -122,7 +117,6 @@ def update_config(self, config: RatholeConfig) -> None: def get_config(self, uuid: str) -> RatholeConfig | None: """Get a config from the toml file.""" - toml = self.read() # Get the config @@ -162,7 +156,6 @@ class RatholeServerToml(RatholeBaseToml): def set_rathole_listener_addr(self, bind_addr: str) -> None: """Set the bind address in the server toml file.""" - toml = self.read() # Set the bind address @@ -172,14 +165,12 @@ def set_rathole_listener_addr(self, bind_addr: str) -> None: def get_rathole_listener_addr(self) -> str: """Get the bind address from the server toml file.""" - toml = self.read() return toml["server"]["bind_addr"] def add_config(self, config: RatholeConfig) -> None: """Add a new config to the toml file.""" - toml = self.read() # Add the new config @@ -198,7 +189,6 @@ def add_config(self, config: RatholeConfig) -> None: def remove_config(self, uuid: str) -> None: """Remove a config from the toml file.""" - toml = self.read() # Remove the config @@ -214,7 +204,6 @@ def remove_config(self, uuid: str) -> None: def update_config(self, config: RatholeConfig) -> None: """Update a config in the toml file.""" - toml = self.read() # Update the config diff --git a/packages/syft/src/syft/service/network/reverse_tunnel_service.py b/packages/syft/src/syft/service/network/reverse_tunnel_service.py index e155dd3f5b5..8d1f186961f 100644 --- a/packages/syft/src/syft/service/network/reverse_tunnel_service.py +++ b/packages/syft/src/syft/service/network/reverse_tunnel_service.py @@ -18,11 +18,11 @@ def set_client_config( if not rathole_route: raise Exception( "Failed to exchange routes via . 
" - + f"Peer: {self_server_peer} has no rathole route: {rathole_route}" + + f"Peer: {self_server_peer} has no rathole route: {rathole_route}", ) remote_url = ServerURL( - host_or_ip=remote_server_route.host_or_ip, port=remote_server_route.port + host_or_ip=remote_server_route.host_or_ip, port=remote_server_route.port, ) rathole_remote_addr = remote_url.as_container_host() @@ -44,5 +44,5 @@ def clear_client_config(self, self_server_peer: ServerPeer) -> None: def clear_server_config(self, remote_peer: ServerPeer) -> None: self.builder.remove_host_from_server( - str(remote_peer.id), server_name=remote_peer.name + str(remote_peer.id), server_name=remote_peer.name, ) diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 04e758b1bdc..a5b9ccf34f3 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -3,7 +3,6 @@ # stdlib import secrets -from typing import Any from typing import TYPE_CHECKING # third party @@ -11,18 +10,18 @@ # relative from ...abstract_server import AbstractServer -from ...client.client import HTTPConnection -from ...client.client import PythonConnection -from ...client.client import ServerConnection -from ...client.client import SyftClient +from ...client.client import ( + HTTPConnection, + PythonConnection, + ServerConnection, + SyftClient, +) from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.transforms import TransformContext from ...types.uid import UID -from ..context import AuthedServiceContext -from ..context import ServerServiceContext +from ..context import AuthedServiceContext, ServerServiceContext from ..response import SyftError if TYPE_CHECKING: @@ -33,29 +32,31 @@ @serializable(canonical_name="ServerRoute", version=1) class ServerRoute: def client_with_context( - self, context: ServerServiceContext + self, context: ServerServiceContext, ) -> SyftClient | SyftError: - """ - Convert the current route (self) to a connection (either HTTP, Veilid or Python) + """Convert the current route (self) to a connection (either HTTP, Veilid or Python) and create a SyftClient from the connection. Args: + ---- context (ServerServiceContext): The ServerServiceContext containing the server information. Returns: + ------- SyftClient | SyftError: Returns the created SyftClient, or SyftError if the client type is not valid or if the context's server is None. 
+ """ connection = route_to_connection(route=self, context=context) client_type = connection.get_client_type() if isinstance(client_type, SyftError): return client_type return client_type( - connection=connection, credentials=context.server.signing_key + connection=connection, credentials=context.server.signing_key, ) def validate_with_context( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> ServerPeer | SyftError: # relative from .server_peer import ServerPeer @@ -74,7 +75,7 @@ def validate_with_context( try: # Verifying if the challenge is valid context.server.verify_key.verify_key.verify( - random_challenge, challenge_signature + random_challenge, challenge_signature, ) except Exception: return SyftError(message="Signature Verification Failed in ping") @@ -100,7 +101,7 @@ class HTTPServerRoute(SyftObject, ServerRoute): priority: int = 1 rtunnel_token: str | None = None - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, HTTPServerRoute): return False return hash(self) == hash(other) @@ -150,7 +151,7 @@ def with_server(cls, server: AbstractServer) -> Self: worker_settings = WorkerSettings.from_server(server) return cls(id=worker_settings.id, worker_settings=worker_settings) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, PythonServerRoute): return False return hash(self) == hash(other) @@ -177,7 +178,7 @@ class VeilidServerRoute(SyftObject, ServerRoute): proxy_target_uid: UID | None = None priority: int = 1 - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, VeilidServerRoute): return False return hash(self) == hash(other) @@ -191,7 +192,7 @@ def __hash__(self) -> int: def route_to_connection( - route: ServerRoute, context: TransformContext | None = None + route: ServerRoute, context: TransformContext | None = None, ) -> ServerConnection: if isinstance(route, HTTPServerRoute): return route.to(HTTPConnection, context=context) diff --git a/packages/syft/src/syft/service/network/server_peer.py b/packages/syft/src/syft/service/network/server_peer.py index 941396820d5..664b3d1f7ae 100644 --- a/packages/syft/src/syft/service/network/server_peer.py +++ b/packages/syft/src/syft/service/network/server_peer.py @@ -1,36 +1,32 @@ # stdlib +import logging from collections.abc import Callable from enum import Enum -import logging # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...abstract_server import ServerType -from ...client.client import ServerConnection -from ...client.client import SyftClient +from ...client.client import ServerConnection, SyftClient from ...serde.serializable import serializable -from ...server.credentials import SyftSigningKey -from ...server.credentials import SyftVerifyKey +from ...server.credentials import SyftSigningKey, SyftVerifyKey from ...service.response import SyftError from ...types.datetime import DateTime -from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, PartialSyftObject, SyftObject from ...types.transforms import TransformContext from ...types.uid import UID from ..context import ServerServiceContext from ..metadata.server_metadata import ServerMetadata -from .routes import HTTPServerRoute -from .routes import 
PythonServerRoute -from .routes import ServerRoute -from .routes import ServerRouteType -from .routes import VeilidServerRoute -from .routes import connection_to_route -from .routes import route_to_connection +from .routes import ( + HTTPServerRoute, + PythonServerRoute, + ServerRoute, + ServerRouteType, + VeilidServerRoute, + connection_to_route, + route_to_connection, +) logger = logging.getLogger(__name__) @@ -73,17 +69,19 @@ def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: """Check if a route exists in self.server_routes Args: + ---- route: the route to be checked. For now it can be either HTTPServerRoute or PythonServerRoute Returns: + ------- if the route exists, returns (True, index of the existed route in self.server_routes) if the route does not exist returns (False, None) - """ + """ if route: if not isinstance( - route, HTTPServerRoute | PythonServerRoute | VeilidServerRoute + route, HTTPServerRoute | PythonServerRoute | VeilidServerRoute, ): raise ValueError(f"Unsupported route type: {type(route)}") for i, r in enumerate(self.server_routes): @@ -93,24 +91,26 @@ def existed_route(self, route: ServerRouteType) -> tuple[bool, int | None]: return (False, None) def update_route_priority(self, route: ServerRoute) -> ServerRoute: - """ - Assign the new_route's priority to be current max + 1 + """Assign the new_route's priority to be current max + 1 Args: + ---- route (ServerRoute): The new route whose priority is to be updated. Returns: + ------- ServerRoute: The new route with the updated priority + """ current_max_priority: int = max(route.priority for route in self.server_routes) route.priority = current_max_priority + 1 return route def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: - """ - Picks the route with the highest priority from the list of server routes. + """Picks the route with the highest priority from the list of server routes. Args: + ---- oldest (bool): If True, picks the oldest route to have the highest priority, meaning the route with min priority value. @@ -118,6 +118,7 @@ def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: meaning the route with max priority value. Returns: + ------- ServerRoute: The route with the highest priority. """ @@ -126,19 +127,19 @@ def pick_highest_priority_route(self, oldest: bool = True) -> ServerRoute: if oldest: if route.priority < highest_priority_route.priority: highest_priority_route = route - else: - if route.priority > highest_priority_route.priority: - highest_priority_route = route + elif route.priority > highest_priority_route.priority: + highest_priority_route = route return highest_priority_route def update_route(self, route: ServerRoute) -> None: - """ - Update the route for the server. + """Update the route for the server. If the route already exists, return it. If the route is new, assign it to have the priority of (current_max + 1) Args: + ---- route (ServerRoute): The new route to be added to the peer. + """ existed, idx = self.existed_route(route) if existed: @@ -148,8 +149,7 @@ def update_route(self, route: ServerRoute) -> None: self.server_routes.append(new_route) def update_routes(self, new_routes: list[ServerRoute]) -> None: - """ - Update multiple routes of the server peer. + """Update multiple routes of the server peer. This method takes a list of new routes as input. It first updates the priorities of the new routes. 
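The route-priority rules spelled out in server_peer.py reduce to two operations: a newly added route receives max(existing priorities) + 1, and pick_highest_priority_route reads the minimum value as the oldest route or the maximum as the latest. Condensed into a runnable sketch:

    # Priorities as assigned by update_route_priority: each new route gets
    # current max + 1, so lower numbers belong to older routes.
    priorities = [1, 2, 3]

    new_priority = max(priorities) + 1  # what a newly added route receives
    priorities.append(new_priority)

    oldest = min(priorities)  # pick_highest_priority_route(oldest=True)
    latest = max(priorities)  # pick_highest_priority_route(oldest=False)
    assert (oldest, latest) == (1, 4)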
@@ -158,32 +158,37 @@ def update_routes(self, new_routes: list[ServerRoute]) -> None: If it doesn't, it adds the new route to the server. Args: + ---- new_routes (list[ServerRoute]): The new routes to be added to the server. Returns: + ------- None + """ for new_route in new_routes: self.update_route(new_route) def update_existed_route_priority( - self, route: ServerRoute, priority: int | None = None + self, route: ServerRoute, priority: int | None = None, ) -> ServerRouteType | SyftError: - """ - Update the priority of an existed route. + """Update the priority of an existing route. Args: + ---- route (ServerRoute): The route whose priority is to be updated. priority (int | None): The new priority of the route. If not given, the route will be assigned with the highest priority. Returns: + ------- ServerRoute: The route with updated priority if the route exists SyftError: If the route does not exist or the priority is invalid + """ if priority is not None and priority <= 0: return SyftError( - message="Priority must be greater than 0. Now it is {priority}." + message=f"Priority must be greater than 0. Now it is {priority}.", ) existed, index = self.existed_route(route=route) @@ -195,7 +200,7 @@ def update_existed_route_priority( self.server_routes[index].priority = priority else: self.server_routes[index].priority = self.update_route_priority( route, ).priority return self.server_routes[index] @@ -212,16 +217,17 @@ def from_client(client: SyftClient) -> "ServerPeer": @property def latest_added_route(self) -> ServerRoute | None: - """ - Returns the latest added route from the list of server routes. + """Returns the latest added route from the list of server routes. - Returns: + Returns + ------- ServerRoute | None: The latest added route, or None if there are no routes. + """ return self.server_routes[-1] if self.server_routes else None def client_with_context( - self, context: ServerServiceContext + self, context: ServerServiceContext, ) -> Result[type[SyftClient], str]: # third party
Exception: {e}", ) return None diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index c5b9e0c084e..af9f3983c43 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -9,11 +9,8 @@ from ...types.datetime import DateTime from ..context import AuthedServiceContext from ..response import SyftError -from .network_service import NetworkService -from .network_service import ServerPeerAssociationStatus -from .server_peer import ServerPeer -from .server_peer import ServerPeerConnectionStatus -from .server_peer import ServerPeerUpdate +from .network_service import NetworkService, ServerPeerAssociationStatus +from .server_peer import ServerPeer, ServerPeerConnectionStatus, ServerPeerUpdate logger = logging.getLogger(__name__) @@ -28,20 +25,21 @@ def __init__(self) -> None: self._stop = False def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | None: - """ - Perform a health check on the peers in the network stash. + """Perform a health check on the peers in the network stash. - If peer is accessible, ping the peer. - Peer is connected to the network. Args: + ---- context (AuthedServiceContext): The authenticated service context. Returns: + ------- None - """ + """ network_service = cast( - NetworkService, context.server.get_service(NetworkService) + NetworkService, context.server.get_service(NetworkService), ) network_stash = network_service.stash @@ -60,7 +58,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | No peer_client = peer.client_with_context(context=context) if peer_client.is_err(): logger.error( - f"Failed to create client for peer: {peer}: {peer_client.err()}" + f"Failed to create client for peer: {peer}: {peer_client.err()}", ) peer_update.ping_status = ServerPeerConnectionStatus.TIMEOUT peer_client = None @@ -73,7 +71,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> SyftError | No if peer_client is not None: peer_client = peer_client.ok() peer_status = peer_client.api.services.network.check_peer_association( - peer_id=context.server.id + peer_id=context.server.id, ) peer_update.ping_status = ( ServerPeerConnectionStatus.ACTIVE @@ -113,13 +111,13 @@ def run(self, context: AuthedServiceContext) -> None: if self.thread is not None: logger.info( f"Peer health check task is already running in thread " - f"{self.thread.name} with ID: {self.thread.ident}." + f"{self.thread.name} with ID: {self.thread.ident}.", ) else: self.thread = threading.Thread(target=self._run, args=(context,)) logger.info( f"Start running peers health check in thread " - f"{self.thread.name} with ID: {self.thread.ident}." 
+ f"{self.thread.name} with ID: {self.thread.ident}.", ) self.thread.start() diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py index fec2810b02a..696f8ca9973 100644 --- a/packages/syft/src/syft/service/notification/email_templates.py +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -1,7 +1,6 @@ # stdlib from datetime import datetime -from typing import TYPE_CHECKING -from typing import cast +from typing import TYPE_CHECKING, cast # relative from ...serde.serializable import serializable @@ -38,12 +37,12 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s raise Exception("User not found!") user.reset_token = user_service.generate_new_password_reset_token( - context.server.settings.pwd_token_config + context.server.settings.pwd_token_config, ) user.reset_token_date = datetime.now() result = user_service.stash.update( - credentials=context.credentials, user=user, has_permission=True + credentials=context.credentials, user=user, has_permission=True, ) if result.is_err(): raise Exception("Couldn't update the user password") @@ -118,7 +117,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) -> def email_body(notification: "Notification", context: AuthedServiceContext) -> str: user_service = context.server.get_service("userservice") admin_name = user_service.get_by_verify_key( - user_service.admin_verify_key() + user_service.admin_verify_key(), ).name head = ( diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 9873a78ad2c..610a62ede53 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -8,21 +8,21 @@ from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext from ..notifier.notifier import NotifierSettings -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL +from ..response import SyftError, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + ADMIN_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, +) from .notification_stash import NotificationStash -from .notifications import CreateNotification -from .notifications import LinkedObject -from .notifications import Notification -from .notifications import NotificationStatus -from .notifications import ReplyNotification +from .notifications import ( + CreateNotification, + LinkedObject, + Notification, + NotificationStatus, + ReplyNotification, +) @instrument @@ -37,7 +37,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="notifications.send", name="send") def send( - self, context: AuthedServiceContext, notification: CreateNotification + self, context: AuthedServiceContext, notification: CreateNotification, ) -> Notification | SyftError: """Send a new notification""" new_notification = notification.to(Notification, context=context) @@ -45,12 +45,12 @@ def send( # Add read permissions to 
person receiving this message permissions = [ ActionObjectREAD( - uid=new_notification.id, credentials=new_notification.to_user_verify_key - ) + uid=new_notification.id, credentials=new_notification.to_user_verify_key, + ), ] result = self.stash.set( - context.credentials, new_notification, add_permissions=permissions + context.credentials, new_notification, add_permissions=permissions, ) notifier_service = context.server.get_service("notifierservice") @@ -70,11 +70,11 @@ def reply( reply: ReplyNotification, ) -> ReplyNotification | SyftError: msg = self.stash.get_by_uid( - credentials=context.credentials, uid=reply.target_msg + credentials=context.credentials, uid=reply.target_msg, ) if msg.is_err(): return SyftError( - message=f"The target notification id {reply.target_msg} was not found!. Error: {msg.err()}" + message=f"The target notification id {reply.target_msg} was not found!. Error: {msg.err()}", ) msg = msg.ok() reply.from_user_verify_key = context.credentials @@ -83,7 +83,7 @@ def reply( if result.is_err(): return SyftError( - message=f"Couldn't add a new notification reply in the target notification. Error: {result.err()}" + message=f"Couldn't add a new notification reply in the target notification. Error: {result.err()}", ) return result.ok() @@ -162,10 +162,10 @@ def get_all( roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_all_sent( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[Notification] | SyftError: result = self.stash.get_all_sent_for_verify_key( - context.credentials, context.credentials + context.credentials, context.credentials, ) if result.err(): return SyftError(message=str(result.err())) @@ -181,7 +181,7 @@ def get_all_for_status( status: NotificationStatus, ) -> list[Notification] | SyftError: result = self.stash.get_all_by_verify_key_for_status( - context.credentials, verify_key=context.credentials, status=status + context.credentials, verify_key=context.credentials, status=status, ) if result.err(): return SyftError(message=str(result.err())) @@ -218,10 +218,10 @@ def get_all_unread( @service_method(path="notifications.mark_as_read", name="mark_as_read") def mark_as_read( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> Notification | SyftError: result = self.stash.update_notification_status( - context.credentials, uid=uid, status=NotificationStatus.READ + context.credentials, uid=uid, status=NotificationStatus.READ, ) if result.is_err(): return SyftError(message=str(result.err())) @@ -229,10 +229,10 @@ def mark_as_read( @service_method(path="notifications.mark_as_unread", name="mark_as_unread") def mark_as_unread( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> Notification | SyftError: result = self.stash.update_notification_status( - context.credentials, uid=uid, status=NotificationStatus.UNREAD + context.credentials, uid=uid, status=NotificationStatus.UNREAD, ) if result.is_err(): return SyftError(message=str(result.err())) @@ -244,7 +244,7 @@ def mark_as_unread( roles=GUEST_ROLE_LEVEL, ) def resolve_object( - self, context: AuthedServiceContext, linked_obj: LinkedObject + self, context: AuthedServiceContext, linked_obj: LinkedObject, ) -> Notification | SyftError: service = context.server.get_service(linked_obj.service_type) result = service.resolve_link(context=context, linked_obj=linked_obj) @@ -255,14 +255,14 @@ def resolve_object( @service_method(path="notifications.clear", name="clear") def clear(self, context: 
AuthedServiceContext) -> SyftError | SyftSuccess: result = self.stash.delete_all_for_verify_key( - credentials=context.credentials, verify_key=context.credentials + credentials=context.credentials, verify_key=context.credentials, ) if result.is_ok(): return SyftSuccess(message="All notifications cleared!") return SyftError(message=str(result.err())) def filter_by_obj( - self, context: AuthedServiceContext, obj_uid: UID + self, context: AuthedServiceContext, obj_uid: UID, ) -> Notification | SyftError: notifications = self.stash.get_all(context.credentials) if notifications.is_err(): diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index 8521080fee5..c81f3d261fd 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -1,29 +1,28 @@ # stdlib # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + PartitionKey, + PartitionSettings, + QueryKeys, +) from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.uid import UID from ...util.telemetry import instrument -from .notifications import Notification -from .notifications import NotificationStatus +from .notifications import Notification, NotificationStatus FromUserVerifyKeyPartitionKey = PartitionKey( - key="from_user_verify_key", type_=SyftVerifyKey + key="from_user_verify_key", type_=SyftVerifyKey, ) ToUserVerifyKeyPartitionKey = PartitionKey( - key="to_user_verify_key", type_=SyftVerifyKey + key="to_user_verify_key", type_=SyftVerifyKey, ) StatusPartitionKey = PartitionKey(key="status", type_=NotificationStatus) @@ -42,29 +41,29 @@ class NotificationStash(BaseUIDStoreStash): ) def get_all_inbox_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> Result[list[Notification], str]: qks = QueryKeys( qks=[ ToUserVerifyKeyPartitionKey.with_obj(verify_key), - ] + ], ) return self.get_all_for_verify_key( - credentials=credentials, verify_key=verify_key, qks=qks + credentials=credentials, verify_key=verify_key, qks=qks, ) def get_all_sent_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> Result[list[Notification], str]: qks = QueryKeys( qks=[ FromUserVerifyKeyPartitionKey.with_obj(verify_key), - ] + ], ) return self.get_all_for_verify_key(credentials, verify_key=verify_key, qks=qks) def get_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, qks: QueryKeys + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, qks: QueryKeys, ) -> Result[list[Notification], str]: if isinstance(verify_key, str): verify_key = SyftVerifyKey.from_string(verify_key) @@ -84,7 +83,7 @@ def get_all_by_verify_key_for_status( qks=[ ToUserVerifyKeyPartitionKey.with_obj(verify_key), StatusPartitionKey.with_obj(status), - ] + ], ) return self.query_all( credentials, @@ -100,12 +99,12 @@ 
def get_notification_for_linked_obj( qks = QueryKeys( qks=[ LinkedObjectPartitionKey.with_obj(linked_obj), - ] + ], ) return self.query_one(credentials=credentials, qks=qks) def update_notification_status( - self, credentials: SyftVerifyKey, uid: UID, status: NotificationStatus + self, credentials: SyftVerifyKey, uid: UID, status: NotificationStatus, ) -> Result[Notification, str]: result = self.get_by_uid(credentials, uid=uid) if result.is_err(): @@ -118,7 +117,7 @@ def update_notification_status( return self.update(credentials, obj=notification) def delete_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> Result[bool, str]: result = self.get_all_inbox_for_verify_key( credentials, diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index 6a95861566f..5430822d63b 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -4,19 +4,19 @@ from typing import cast # relative -from ...client.api import APIRegistry -from ...client.api import SyftAPI +from ...client.api import APIRegistry, SyftAPI from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.transforms import TransformContext -from ...types.transforms import add_credentials_for_key -from ...types.transforms import add_server_uid_for_key -from ...types.transforms import generate_id -from ...types.transforms import transform +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject +from ...types.transforms import ( + TransformContext, + add_credentials_for_key, + add_server_uid_for_key, + generate_id, + transform, +) from ...types.uid import UID from ...util import options from ...util.colors import SURFACE @@ -110,7 +110,7 @@ def mark_read(self) -> None: api: SyftAPI = cast( SyftAPI, APIRegistry.api_for( - self.server_uid, user_verify_key=self.syft_client_verify_key + self.server_uid, user_verify_key=self.syft_client_verify_key, ), ) return api.services.notifications.mark_as_read(uid=self.id) @@ -119,7 +119,7 @@ def mark_unread(self) -> None: api: SyftAPI = cast( SyftAPI, APIRegistry.api_for( - self.server_uid, user_verify_key=self.syft_client_verify_key + self.server_uid, user_verify_key=self.syft_client_verify_key, ), ) return api.services.notifications.mark_as_unread(uid=self.id) diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 1f5e34cf1c3..171f4e634d1 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -13,30 +13,28 @@ # third party from pydantic import BaseModel -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.syft_object import SyftObject -from ...types.transforms import drop -from ...types.transforms import make_set_default +from 
...types.syft_object import ( + SYFT_OBJECT_VERSION_1, + SYFT_OBJECT_VERSION_2, + SyftObject, +) +from ...types.transforms import drop, make_set_default from ..context import AuthedServiceContext from ..notification.notifications import Notification -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from .notifier_enums import NOTIFIERS from .smtp_client import SMTPClient class BaseNotifier(BaseModel): def send( - self, target: SyftVerifyKey, notification: Notification + self, target: SyftVerifyKey, notification: Notification, ) -> SyftSuccess | SyftError: return SyftError(message="Not implemented") @@ -89,7 +87,7 @@ def check_credentials( ) def send( - self, context: AuthedServiceContext, notification: Notification + self, context: AuthedServiceContext, notification: Notification, ) -> Result[Ok, Err]: try: user_service = context.server.get_service("userservice") @@ -98,17 +96,17 @@ def send( if not receiver.notifications_enabled[NOTIFIERS.EMAIL]: return Ok( - "Email notifications are disabled for this user." + "Email notifications are disabled for this user.", ) # TODO: Should we return an error here? receiver_email = receiver.email if notification.email_template: subject = notification.email_template.email_title( - notification, context=context + notification, context=context, ) body = notification.email_template.email_body( - notification, context=context + notification, context=context, ) else: subject = notification.subject @@ -118,12 +116,12 @@ def send( receiver_email = [receiver_email] self.smtp_client.send( - sender=self.sender, receiver=receiver_email, subject=subject, body=body + sender=self.sender, receiver=receiver_email, subject=subject, body=body, ) return Ok("Email sent successfully!") except Exception: return Err( - "Some notifications failed to be delivered. Please check the health of the mailing server." + "Some notifications failed to be delivered. 
Please check the health of the mailing server.", ) @@ -251,13 +249,14 @@ def send_notifications( return Ok("Notification sent successfully!") def select_notifiers(self, notification: Notification) -> list[BaseNotifier]: - """ - Return a list of the notifiers enabled for the given notification" + """Return a list of the notifiers enabled for the given notification. Args: + ---- notification (Notification): The notification object Returns: List[BaseNotifier]: A list of enabled notifier objects + """ notifier_objs = [] for notifier_type in notification.notifier_types: @@ -274,7 +273,7 @@ def select_notifiers(self, notification: Notification) -> list[BaseNotifier]: password=self.email_password, sender=self.email_sender, server=self.email_server, - ) + ), ) # If notifier is not email, we just create the notifier object # TODO: Add the other notifiers, and their auth methods diff --git a/packages/syft/src/syft/service/notifier/notifier_enums.py b/packages/syft/src/syft/service/notifier/notifier_enums.py index f8c2d887ff4..bfc3d717a28 100644 --- a/packages/syft/src/syft/service/notifier/notifier_enums.py +++ b/packages/syft/src/syft/service/notifier/notifier_enums.py @@ -1,6 +1,5 @@ # stdlib -from enum import Enum -from enum import auto +from enum import Enum, auto # relative from ...serde.serializable import serializable diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index c8c09ba3d50..90525a2d36c 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -1,13 +1,11 @@ # stdlib -from datetime import datetime import logging import traceback +from datetime import datetime # third party from pydantic import EmailStr -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...abstract_server import AbstractServer @@ -16,14 +14,14 @@ from ..context import AuthedServiceContext from ..notification.email_templates import PasswordResetTemplate from ..notification.notifications import Notification -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from ..service import AbstractService -from .notifier import NotificationPreferences -from .notifier import NotifierSettings -from .notifier import UserNotificationActivity -from .notifier_enums import EMAIL_TYPES -from .notifier_enums import NOTIFIERS +from .notifier import ( + NotificationPreferences, + NotifierSettings, + UserNotificationActivity, +) +from .notifier_enums import EMAIL_TYPES, NOTIFIERS from .notifier_stash import NotifierStash logger = logging.getLogger(__name__) @@ -45,9 +43,11 @@ def settings( # Maybe just notifier.settings """Get Notifier Settings Args: + ---- context: The request context Returns: Union[NotifierSettings, SyftError]: Notifier Settings or SyftError + """ result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -70,7 +70,7 @@ def user_settings( ) def set_notifier_active_to_true( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> SyftSuccess | SyftError: result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -86,10 +86,9 @@ def set_notifier_active_to_true( return SyftSuccess(message="notifier.active set to true.") def set_notifier_active_to_false( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> SyftSuccess: - """ - 
Essentially a duplicate of turn_off method. + """Essentially a duplicate of turn_off method. """ result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -117,17 +116,20 @@ def turn_on( """Turn on email notifications. Args: + ---- email_username (Optional[str]): Email server username. Defaults to None. email_password (Optional[str]): Email server password. Defaults to None. sender_email (Optional[str]): Email sender address. Defaults to None. + Returns: + ------- Union[SyftSuccess, SyftError]: A union type representing the success or error response. Raises: + ------ None """ - result = self.stash.get(credentials=context.credentials) # 1 - If something went wrong at db level, return the error @@ -148,7 +150,7 @@ def turn_on( # 3 - If notifier doesn't have an email server / port and the user didn't provide them, return an error if not (email_server and email_port) and not notifier.email_server: return SyftError( - message="You must provide both server and port to enable notifications." + message="You must provide both server and port to enable notifications.", ) logging.debug("Got notifier from db") @@ -158,7 +160,7 @@ def turn_on( return SyftError( message="No valid token has been added to the datasite. " + "You can add a pair of SMTP credentials via " - + ".settings.enable_notifications(email=<>, password=<>)" + + ".settings.enable_notifications(email=<>, password=<>)", ) else: logging.debug("No new credentials provided. Using existing ones.") @@ -175,7 +177,7 @@ def turn_on( if validation_result.is_err(): logging.error(f"Invalid SMTP credentials {validation_result.err()}") return SyftError( - message="Invalid SMTP credentials. Please check your username and password." + message="Invalid SMTP credentials. Please check your username and password.", ) notifier.email_password = email_password @@ -189,7 +191,7 @@ def turn_on( # Email sender verification if not email_sender and not notifier.email_sender: return SyftError( - message="You must provide a sender email address to enable notifications." + message="You must provide a sender email address to enable notifications.", ) # If email_rate_limit isn't defined yet. @@ -201,13 +203,13 @@ def turn_on( EmailStr._validate(email_sender) except ValueError: return SyftError( - message="Invalid sender email address. Please check your email address." + message="Invalid sender email address. Please check your email address.", ) notifier.email_sender = email_sender notifier.active = True logging.debug( - "Email credentials are valid. Updating the notifier settings in the db." + "Email credentials are valid. Updating the notifier settings in the db.", ) result = self.stash.update(credentials=context.credentials, settings=notifier) @@ -225,11 +227,9 @@ def turn_off( self, context: AuthedServiceContext, ) -> SyftSuccess | SyftError: - """ - Turn off email notifications service. + """Turn off email notifications service. PySyft notifications will still work. """ - result = self.stash.get(credentials=context.credentials) if result.is_err(): @@ -249,23 +249,20 @@ def turn_off( return SyftSuccess(message="Notifications disabled successfully") def activate( - self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL + self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL, ) -> SyftSuccess | SyftError: - """ - Activate email notifications for the authenticated user. 
This will only work if the datasite owner has enabled notifications. """ - user_service = context.server.get_service("userservice") return user_service.enable_notifications(context, notifier_type=notifier_type) def deactivate( - self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL + self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL, ) -> SyftSuccess | SyftError: """Deactivate email notifications for the authenticated user. This will only work if the datasite owner has enabled notifications. """ - user_service = context.server.get_service("userservice") return user_service.disable_notifications(context, notifier_type=notifier_type) @@ -283,6 +280,7 @@ def init_notifier( If not, it will create a new one. Args: + ---- server: Server to initialize the notifier active: If notifier should be active email_username: Email username to send notifications email_password: Email password to send notifications @@ -291,6 +289,7 @@ def init_notifier( Exception: If something went wrong Returns: Union: SyftSuccess or SyftError + """ try: # Create a new NotifierStash since it's a static method. @@ -339,7 +338,7 @@ def init_notifier( raise Exception(f"Error initializing notifier. \n {traceback.format_exc()}") def set_email_rate_limit( - self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int + self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int, ) -> SyftSuccess | SyftError: notifier = self.stash.get(context.credentials) if notifier.is_err(): @@ -357,14 +356,14 @@ def set_email_rate_limit( # This is not a public API. # This method is used by other services to dispatch notifications internally def dispatch_notification( - self, context: AuthedServiceContext, notification: Notification + self, context: AuthedServiceContext, notification: Notification, ) -> SyftError: admin_key = context.server.get_service("userservice").admin_verify_key() notifier = self.stash.get(admin_key) if notifier.is_err(): return SyftError( message="The mail service ran out of quota or some notifications failed to be delivered.\n" - + "Please check the health of the mailing server." + + "Please check the health of the mailing server.", ) notifier = notifier.ok() @@ -378,7 +377,7 @@ def dispatch_notification( # If there's no user activity if user_activity is None: notifier.email_activity[notification.email_template.__name__][ - notification.to_user_verify_key, None + notification.to_user_verify_key, None, ] = UserNotificationActivity(count=1, date=datetime.now()) else: # If there's a previous user activity current_state: UserNotificationActivity = notifier.email_activity[ @@ -387,7 +386,7 @@ def dispatch_notification( date_refresh = abs(datetime.now() - current_state.date).days > 1 limit = notifier.email_rate_limit.get( - notification.email_template.__name__, 0 + notification.email_template.__name__, 0, ) still_in_limit = current_state.count < limit # Time interval reset. @@ -401,13 +400,13 @@ def dispatch_notification( else: return SyftError( message="Couldn't send the email. You have surpassed the" - + " email threshold limit. Please try again later." + + " email threshold limit. 
Please try again later.", ) else: notifier.email_activity[notification.email_template.__name__] = { notification.to_user_verify_key: UserNotificationActivity( - count=1, date=datetime.now() - ) + count=1, date=datetime.now(), + ), } result = self.stash.update(credentials=admin_key, settings=notifier) @@ -415,7 +414,7 @@ def dispatch_notification( return SyftError(message="Couldn't update the notifier.") resp = notifier.send_notifications( - context=context, notification=notification + context=context, notification=notification, ) if resp.is_err(): return SyftError(message=resp.err()) diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 1d28c90a380..5c05451f71d 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -1,18 +1,18 @@ # stdlib # third party -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...service.response import SyftError -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings +from ...store.document_store import ( + BaseStash, + DocumentStore, + PartitionKey, + PartitionSettings, +) from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission @@ -27,7 +27,7 @@ class NotifierStash(BaseStash): object_type = NotifierSettings settings: PartitionSettings = PartitionSettings( - name=NotifierSettings.__canonical_name__, object_type=NotifierSettings + name=NotifierSettings.__canonical_name__, object_type=NotifierSettings, ) def __init__(self, store: DocumentStore) -> None: @@ -44,7 +44,7 @@ def get(self, credentials: SyftVerifyKey) -> Result[NotifierSettings, Err]: settings = result.ok() if len(settings) == 0: return Ok( - None + None, ) # TODO: Stash shouldn't be empty after init. Return Err instead? 
result = settings[ 0 @@ -66,7 +66,7 @@ def set( if result.is_err(): return Err(SyftError(message=result.err())) return super().set( - credentials=credentials, obj=result.ok() + credentials=credentials, obj=result.ok(), ) # TODO check if result isInstance(Ok) def update( @@ -80,5 +80,5 @@ def update( if result.is_err(): return Err(SyftError(message=result.err())) return super().update( - credentials=credentials, obj=result.ok() + credentials=credentials, obj=result.ok(), ) # TODO check if result isInstance(Ok) diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py index f7041c9f722..9f99f568ace 100644 --- a/packages/syft/src/syft/service/notifier/smtp_client.py +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -1,13 +1,11 @@ # stdlib +import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -import smtplib # third party from pydantic import BaseModel -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result class SMTPClient(BaseModel): @@ -39,7 +37,7 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non msg.attach(MIMEText(body, "html")) with smtplib.SMTP( - self.server, self.port, timeout=self.SOCKET_TIMEOUT + self.server, self.port, timeout=self.SOCKET_TIMEOUT, ) as server: server.ehlo() if server.has_extn("STARTTLS"): @@ -52,12 +50,14 @@ def send(self, sender: str, receiver: list[str], subject: str, body: str) -> Non @classmethod def check_credentials( - cls, server: str, port: int, username: str, password: str + cls, server: str, port: int, username: str, password: str, ) -> Result[Ok, Err]: """Check if the credentials are valid. - Returns: + Returns + ------- bool: True if the credentials are valid, False otherwise. 
+ """ try: with smtplib.SMTP(server, port, timeout=cls.SOCKET_TIMEOUT) as smtp_server: diff --git a/packages/syft/src/syft/service/object_search/object_migration_state.py b/packages/syft/src/syft/service/object_search/object_migration_state.py index 1b81dbb1d07..9d32f1fd094 100644 --- a/packages/syft/src/syft/service/object_search/object_migration_state.py +++ b/packages/syft/src/syft/service/object_search/object_migration_state.py @@ -6,13 +6,17 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftMigrationRegistry -from ...types.syft_object import SyftObject +from ...store.document_store import ( + BaseStash, + DocumentStore, + PartitionKey, + PartitionSettings, +) +from ...types.syft_object import ( + SYFT_OBJECT_VERSION_1, + SyftMigrationRegistry, + SyftObject, +) from ..action.action_permissions import ActionObjectPermission @@ -76,7 +80,7 @@ def set( ) def get_by_name( - self, canonical_name: str, credentials: SyftVerifyKey + self, canonical_name: str, credentials: SyftVerifyKey, ) -> Result[SyftObjectMigrationState, str]: qks = KlassNamePartitionKey.with_obj(canonical_name) return self.query_one(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 386aaa926a9..c41fe360258 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -3,19 +3,19 @@ # third party from pydantic import model_validator -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...client.api import APIRegistry from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, +) from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 @@ -26,11 +26,8 @@ from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext from ..response import SyftError -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL, GUEST_ROLE_LEVEL CreatedAtPartitionKey = PartitionKey(key="created_at", type_=DateTime) UserCodeIdPartitionKey = PartitionKey(key="user_code_id", type_=UID) @@ -91,10 +88,8 @@ def from_ids( input_ids: dict[str, UID] | None = None, ) -> "ExecutionOutput": # relative - from ..code.user_code_service import UserCode - from ..code.user_code_service import UserCodeService - from ..job.job_service import Job - from 
..job.job_service import JobService + from ..code.user_code_service import UserCode, UserCodeService + from ..job.job_service import Job, JobService if isinstance(output_ids, UID): output_ids = [output_ids] @@ -132,7 +127,7 @@ def outputs(self) -> list[ActionObject] | dict[str, ActionObject] | None: ) if api is None: raise ValueError( - f"Can't access the api. Please log in to {self.syft_server_location}" + f"Can't access the api. Please log in to {self.syft_server_location}", ) action_service = api.services.action @@ -161,14 +156,16 @@ def input_id_list(self) -> list[UID]: return [] def check_input_ids(self, kwargs: dict[str, UID]) -> bool: - """ - Checks the input IDs against the stored input IDs. + """Checks the input IDs against the stored input IDs. Args: + ---- kwargs (dict[str, UID]): A dictionary containing the input IDs to be checked. Returns: + ------- bool: True if the input IDs are valid, False otherwise. + """ if not self.input_ids: return True @@ -194,7 +191,7 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: class OutputStash(BaseUIDStoreStash): object_type = ExecutionOutput settings: PartitionSettings = PartitionSettings( - name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput + name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput, ) def __init__(self, store: DocumentStore) -> None: @@ -204,23 +201,23 @@ def __init__(self, store: DocumentStore) -> None: self._object_type = self.object_type def get_by_user_code_id( - self, credentials: SyftVerifyKey, user_code_id: UID + self, credentials: SyftVerifyKey, user_code_id: UID, ) -> Result[list[ExecutionOutput], str]: qks = QueryKeys( qks=[UserCodeIdPartitionKey.with_obj(user_code_id)], ) return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey, ) def get_by_job_id( - self, credentials: SyftVerifyKey, user_code_id: UID + self, credentials: SyftVerifyKey, user_code_id: UID, ) -> Result[ExecutionOutput | None, str]: qks = QueryKeys( qks=[JobIdPartitionKey.with_obj(user_code_id)], ) res = self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey, ) if res.is_err(): return res @@ -234,13 +231,13 @@ def get_by_job_id( return Ok(res[0]) def get_by_output_policy_id( - self, credentials: SyftVerifyKey, output_policy_id: UID + self, credentials: SyftVerifyKey, output_policy_id: UID, ) -> Result[list[ExecutionOutput], str]: qks = QueryKeys( qks=[OutputPolicyIdPartitionKey.with_obj(output_policy_id)], ) return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey, ) @@ -288,7 +285,7 @@ def create( roles=GUEST_ROLE_LEVEL, ) def get_by_user_code_id( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, user_code_id: UID, ) -> list[ExecutionOutput] | SyftError: result = self.stash.get_by_user_code_id( credentials=context.server.verify_key, # type: ignore @@ -338,7 +335,7 @@ def has_output_read_permissions( roles=ADMIN_ROLE_LEVEL, ) def get_by_job_id( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, user_code_id: UID, ) -> ExecutionOutput | None | SyftError: result = self.stash.get_by_job_id( credentials=context.server.verify_key, # type: ignore @@ -354,7 +351,7 @@ def get_by_job_id( roles=GUEST_ROLE_LEVEL, ) def 
get_by_output_policy_id( - self, context: AuthedServiceContext, output_policy_id: UID + self, context: AuthedServiceContext, output_policy_id: UID, ) -> list[ExecutionOutput] | SyftError: result = self.stash.get_by_output_policy_id( credentials=context.server.verify_key, # type: ignore @@ -370,7 +367,7 @@ def get_by_output_policy_id( roles=GUEST_ROLE_LEVEL, ) def get( - self, context: AuthedServiceContext, id: UID + self, context: AuthedServiceContext, id: UID, ) -> ExecutionOutput | SyftError: result = self.stash.get_by_uid(context.credentials, id) if result.is_ok(): @@ -379,7 +376,7 @@ def get( @service_method(path="output.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[ExecutionOutput] | SyftError: result = self.stash.get_all(context.credentials) if result.is_ok(): diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 4bf96a58c5a..0dfac3dd8f5 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -3,63 +3,50 @@ # stdlib import ast +import hashlib +import inspect +import sys from collections.abc import Callable from copy import deepcopy from enum import Enum -import hashlib -import inspect -from inspect import Parameter -from inspect import Signature +from inspect import Parameter, Signature from io import StringIO -import sys -from typing import Any -from typing import ClassVar +from typing import Any, ClassVar + +import requests +from pydantic import field_validator, model_validator # third party from RestrictedPython import compile_restricted -from pydantic import field_validator -from pydantic import model_validator -import requests -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result # relative from ...abstract_server import ServerType -from ...client.api import APIRegistry -from ...client.api import RemoteFunction -from ...client.api import ServerIdentity +from ...client.api import APIRegistry, RemoteFunction, ServerIdentity from ...serde.recursive_primitives import recursive_serde_register_type from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.document_store import PartitionKey from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.syft_object_registry import SyftObjectRegistry -from ...types.transforms import TransformContext -from ...types.transforms import generate_id -from ...types.transforms import transform +from ...types.transforms import TransformContext, generate_id, transform from ...types.twin_object import TwinObject from ...types.uid import UID from ...util.util import is_interpreter_jupyter from ..action.action_endpoint import CustomEndpointActionObject from ..action.action_object import ActionObject -from ..action.action_permissions import ActionObjectPermission -from ..action.action_permissions import ActionPermission +from ..action.action_permissions import ActionObjectPermission, ActionPermission from ..code.code_parse import GlobalsVisitor from ..code.unparse import unparse -from ..context import AuthedServiceContext -from ..context import ChangeContext -from ..context import ServerServiceContext +from ..context import AuthedServiceContext, ChangeContext, 
ServerServiceContext from ..dataset.dataset import Asset -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess DEFAULT_USER_POLICY_VERSION = 1 PolicyUserVerifyKeyPartitionKey = PartitionKey( - key="user_verify_key", type_=SyftVerifyKey + key="user_verify_key", type_=SyftVerifyKey, ) PyCodeObject = Any @@ -140,9 +127,7 @@ class UserPolicyStatus(Enum): def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str, UID]]: # relative - from ...client.api import APIRegistry - from ...client.api import RemoteFunction - from ...client.api import ServerIdentity + from ...client.api import APIRegistry, RemoteFunction, ServerIdentity from ...types.twin_object import TwinObject from ..action.action_object import ActionObject @@ -194,7 +179,7 @@ class PolicyRule(SyftObject): requires_input: bool = True def is_met( - self, context: AuthedServiceContext, action_object: ActionObject + self, context: AuthedServiceContext, action_object: ActionObject, ) -> bool: return False @@ -245,7 +230,7 @@ class Matches(PolicyRule): val: UID def is_met( - self, context: AuthedServiceContext, action_object: ActionObject + self, context: AuthedServiceContext, action_object: ActionObject, ) -> bool: return action_object.id == self.val @@ -263,7 +248,7 @@ def is_met(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> bo return True def transform_kwarg( - self, context: AuthedServiceContext, val: Any + self, context: AuthedServiceContext, val: Any, ) -> Result[Any, str]: if isinstance(self.val, UID): if issubclass(self.klass, CustomEndpointActionObject): @@ -284,39 +269,29 @@ class UserOwned(PolicyRule): # str, float, int, bool, dict, list, set, tuple type: ( - type[str] - | type[float] - | type[int] - | type[bool] - | type[dict] - | type[list] - | type[set] - | type[tuple] - | None + type[str | float | int | bool | dict | list | set | tuple] | None ) def is_owned( - self, context: AuthedServiceContext, action_object: ActionObject + self, context: AuthedServiceContext, action_object: ActionObject, ) -> bool: action_store = context.server.get_service("actionservice").store return action_store.has_permission( ActionObjectPermission( - action_object.id, ActionPermission.OWNER, context.credentials - ) + action_object.id, ActionPermission.OWNER, context.credentials, + ), ) def is_met( - self, context: AuthedServiceContext, action_object: ActionObject + self, context: AuthedServiceContext, action_object: ActionObject, ) -> bool: return type(action_object.syft_action_data) == self.type and self.is_owned( - context, action_object + context, action_object, ) def user_code_arg2id(arg: Any) -> UID: - if isinstance(arg, ActionObject): - uid = arg.id - elif isinstance(arg, TwinObject): + if isinstance(arg, ActionObject) or isinstance(arg, TwinObject): uid = arg.id elif isinstance(arg, Asset): uid = arg.action_id @@ -335,7 +310,7 @@ def retrieve_item_from_db(id: UID, context: AuthedServiceContext) -> ActionObjec action_service = context.server.get_service("actionservice") root_context = AuthedServiceContext( - server=context.server, credentials=context.server.verify_key + server=context.server, credentials=context.server.verify_key, ) value = action_service._get( context=root_context, @@ -386,7 +361,7 @@ def _inputs_for_context(self, context: ChangeContext) -> dict | SyftError: user_server_view = ServerIdentity.from_change_context(context) inputs = self.inputs[user_server_view] root_context = AuthedServiceContext( - server=context.server, 
credentials=context.approving_user_credentials + server=context.server, credentials=context.approving_user_credentials, ).as_root_context() action_service = context.server.get_service("actionservice") @@ -414,7 +389,7 @@ class MixedInputPolicy(InputPolicy): kwarg_rules: dict[ServerIdentity, dict[str, PolicyRule]] def __init__( - self, init_kwargs: Any = None, client: Any = None, *args: Any, **kwargs: Any + self, init_kwargs: Any = None, client: Any = None, *args: Any, **kwargs: Any, ) -> None: if init_kwargs is not None: kwarg_rules = init_kwargs @@ -424,10 +399,10 @@ def __init__( kwarg_rules_current_server = {} for kw, arg in kwargs.items(): if isinstance( - arg, UID | Asset | ActionObject | TwinObject | RemoteFunction + arg, UID | Asset | ActionObject | TwinObject | RemoteFunction, ): kwarg_rules_current_server[kw] = Matches( - kw=kw, val=user_code_arg2id(arg) + kw=kw, val=user_code_arg2id(arg), ) elif arg in [str, float, int, bool, dict, list, set, tuple]: # type: ignore[unreachable] kwarg_rules_current_server[kw] = UserOwned(kw=kw, type=arg) @@ -438,16 +413,16 @@ def __init__( kwarg_rules = {server_identity: kwarg_rules_current_server} super().__init__( - *args, kwarg_rules=kwarg_rules, init_kwargs=kwarg_rules, **kwargs + *args, kwarg_rules=kwarg_rules, init_kwargs=kwarg_rules, **kwargs, ) def transform_kwargs( - self, context: AuthedServiceContext, kwargs: dict[str, Any] + self, context: AuthedServiceContext, kwargs: dict[str, Any], ) -> dict[str, Any]: for _, rules in self.kwarg_rules.items(): for kw, rule in rules.items(): if hasattr(rule, "transform_kwarg"): - res_val = rule.transform_kwarg(context, kwargs.get(kw, None)) + res_val = rule.transform_kwarg(context, kwargs.get(kw)) if res_val.is_err(): return res_val else: @@ -455,7 +430,7 @@ def transform_kwargs( return Ok(kwargs) def find_server_identity( - self, kwargs: dict[str, Any], client: Any = None + self, kwargs: dict[str, Any], client: Any = None, ) -> ServerIdentity: if client is not None: return ServerIdentity.from_api(client.api) @@ -467,7 +442,7 @@ def find_server_identity( # we mostly get the UID here because we don't want to store all those # other objects, so we need to create a global UID obj lookup service if isinstance( - val, UID | Asset | ActionObject | TwinObject | RemoteFunction + val, UID | Asset | ActionObject | TwinObject | RemoteFunction, ): has_ids = True id = user_code_arg2id(val) @@ -494,7 +469,7 @@ def find_server_identity( return ServerIdentity.from_api(api) else: raise ValueError( - "Multiple Server Identities, please only login to one client (for this policy) and try again" + "Multiple Server Identities, please only login to one client (for this policy) and try again", ) else: raise ValueError("No Server Identities") @@ -503,7 +478,7 @@ def find_server_identity( raise ValueError("Multiple Server Identities") # we need to fix this as its possible we could # grab the wrong API and call a different user context in jupyter testing - pass # just grab the first one + # just grab the first one return matches.pop() def filter_kwargs( @@ -519,7 +494,7 @@ def filter_kwargs( if rule.requires_input: passed_id = kwargs[kw] actionobject: ActionObject = retrieve_item_from_db( - passed_id, context + passed_id, context, ) rule_check_args = (actionobject,) else: @@ -562,13 +537,13 @@ def _is_valid( not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs) if len(not_approved_kwargs) > 0: return Err( - f"Input arguments: {not_approved_kwargs} to the function are not approved yet." 
+ f"Input arguments: {not_approved_kwargs} to the function are not approved yet.", ) return Ok(True) def retrieve_from_db( - code_item_id: UID, allowed_inputs: dict[str, UID], context: AuthedServiceContext + code_item_id: UID, allowed_inputs: dict[str, UID], context: AuthedServiceContext, ) -> Result[dict[str, Any], str]: # relative from ...service.action.action_object import TwinMode @@ -582,7 +557,7 @@ def retrieve_from_db( # but we are not modifying the permissions of the private data root_context = AuthedServiceContext( - server=context.server, credentials=context.server.verify_key + server=context.server, credentials=context.server.verify_key, ) if context.server.server_type == ServerType.DATASITE: for var_name, arg_id in allowed_inputs.items(): @@ -597,7 +572,7 @@ def retrieve_from_db( code_inputs[var_name] = kwarg_value.ok() else: raise Exception( - f"Invalid Server Type for Code Submission:{context.server.server_type}" + f"Invalid Server Type for Code Submission:{context.server.server_type}", ) return Ok(code_inputs) @@ -616,7 +591,7 @@ def allowed_ids_only( allowed_inputs = allowed_inputs.get(server_identity, {}) else: raise Exception( - f"Invalid Server Type for Code Submission:{context.server.server_type}" + f"Invalid Server Type for Code Submission:{context.server.server_type}", ) filtered_kwargs = {} for key in allowed_inputs.keys(): @@ -628,7 +603,7 @@ def allowed_ids_only( if uid != allowed_inputs[key]: raise Exception( - f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}" + f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}", ) filtered_kwargs[key] = value return filtered_kwargs @@ -648,7 +623,7 @@ def filter_kwargs( ) -> Result[dict[Any, Any], str]: try: allowed_inputs = allowed_ids_only( - allowed_inputs=self.inputs, kwargs=kwargs, context=context + allowed_inputs=self.inputs, kwargs=kwargs, context=context, ) results = retrieve_from_db( @@ -688,7 +663,7 @@ def _is_valid( not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs) if len(not_approved_kwargs) > 0: return Err( - f"Function arguments: {not_approved_kwargs} are not approved yet." + f"Function arguments: {not_approved_kwargs} are not approved yet.", ) return Ok(True) @@ -732,7 +707,7 @@ def apply_to_output( return outputs def is_valid(self, context: AuthedServiceContext) -> SyftSuccess | SyftError: # type: ignore - raise NotImplementedError() + raise NotImplementedError @serializable() @@ -745,11 +720,11 @@ class OutputPolicyExecuteCount(OutputPolicy): @property def count(self) -> SyftError | int: api = APIRegistry.api_for( - self.syft_server_location, self.syft_client_verify_key + self.syft_server_location, self.syft_client_verify_key, ) if api is None: raise ValueError( - f"api is None. You must login to {self.syft_server_location}" + f"api is None. You must login to {self.syft_server_location}", ) output_history = api.services.output.get_by_output_policy_id(self.id) @@ -763,10 +738,10 @@ def is_valid(self) -> SyftSuccess | SyftError: # type: ignore is_valid = execution_count < self.limit if is_valid: return SyftSuccess( - message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}" + message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}", ) return SyftError( - message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}" + message=f"Policy is no longer valid. 
count: {execution_count} >= limit: {self.limit}", ) def _is_valid(self, context: AuthedServiceContext) -> SyftSuccess | SyftError: @@ -779,10 +754,10 @@ def _is_valid(self, context: AuthedServiceContext) -> SyftSuccess | SyftError: is_valid = execution_count < self.limit if is_valid: return SyftSuccess( - message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}" + message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}", ) return SyftError( - message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}" + message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}", ) def public_state(self) -> dict[str, int]: @@ -829,19 +804,16 @@ class UserOutputPolicy(OutputPolicy): # Do not validate private attributes of user-defined policies, User annotations can # contain any type and throw a NameError when resolving. __validate_private_attrs__ = False - pass class UserInputPolicy(InputPolicy): __canonical_name__ = "UserInputPolicy" __validate_private_attrs__ = False - pass @serializable() class EmpyInputPolicy(InputPolicy): __canonical_name__ = "EmptyInputPolicy" - pass class CustomInputPolicy(metaclass=CustomPolicy): @@ -900,12 +872,11 @@ def new_getfile(object: Any) -> Any: # TODO: fix the mypy issue and object.__qualname__ + "." + member.__name__ == member.__qualname__ ): return inspect.getfile(member) - else: - raise TypeError(f"Source for {object!r} not found") + raise TypeError(f"Source for {object!r} not found") def get_code_from_class(policy: type[CustomPolicy]) -> str: - klasses = [inspect.getmro(policy)[0]] # + klasses = [inspect.getmro(policy)[0]] whole_str = "" for klass in klasses: if is_interpreter_jupyter(): @@ -983,7 +954,7 @@ def process_class_code(raw_code: str, class_name: str) -> str: v.visit(tree) if len(tree.body) != 1 or not isinstance(tree.body[0], ast.ClassDef): raise Exception( - "Class code should only contain the Class definition for your policy" + "Class code should only contain the Class definition for your policy", ) old_class = tree.body[0] if len(old_class.bases) != 1 or old_class.bases[0].attr not in [ @@ -992,7 +963,7 @@ def process_class_code(raw_code: str, class_name: str) -> str: ]: raise Exception( f"Class code should either implement {CustomInputPolicy.__name__} " - f"or {CustomOutputPolicy.__name__}" + f"or {CustomOutputPolicy.__name__}", ) # TODO: changes the bases @@ -1027,7 +998,7 @@ def process_class_code(raw_code: str, class_name: str) -> str: module="__future__", names=[ast.alias(name="annotations", asname="annotations")], level=0, - ) + ), ) new_body.append(ast.Import(names=[ast.alias(name="syft", asname="sy")], level=0)) typing_types = [ @@ -1049,7 +1020,7 @@ def process_class_code(raw_code: str, class_name: str) -> str: for typing_type in typing_types ], level=0, - ) + ), ) new_body.append(new_class) module = ast.Module(new_body, type_ignores=[]) @@ -1087,7 +1058,7 @@ def compile_code(context: TransformContext) -> TransformContext: if byte_code is None: raise Exception( "Unable to compile byte code from parsed code. " - + context.output["parsed_code"] + + context.output["parsed_code"], ) else: raise ValueError(f"{context}'s output is None. 
No transformation happened") @@ -1158,7 +1129,7 @@ def register_policy_class(klass: type, unique_name: str) -> None: ) SyftObjectRegistry.register_cls( - canonical_name=unique_name, version=version, serde_attributes=serde_attributes + canonical_name=unique_name, version=version, serde_attributes=serde_attributes, ) @@ -1177,7 +1148,7 @@ def execute_policy_code(user_policy: UserPolicy) -> Any: try: policy_class = SyftObjectRegistry.get_serde_class( - class_name, version=DEFAULT_USER_POLICY_VERSION + class_name, version=DEFAULT_USER_POLICY_VERSION, ) except Exception: exec(user_policy.byte_code) # nosec diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index bea4c2a66ae..dd0ae9013b5 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -5,13 +5,9 @@ from ...store.document_store import DocumentStore from ...types.uid import UID from ..context import AuthedServiceContext -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from .policy import SubmitUserPolicy -from .policy import UserPolicy +from ..response import SyftError, SyftSuccess +from ..service import TYPE_TO_SERVICE, AbstractService, service_method +from .policy import SubmitUserPolicy, UserPolicy from .user_policy_stash import UserPolicyStash @@ -26,7 +22,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method(path="policy.get_all", name="get_all") def get_all_user_policy( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[UserPolicy] | SyftError: result = self.stash.get_all(context.credentials) if result.is_ok(): @@ -48,7 +44,7 @@ def add_user_policy( @service_method(path="policy.get_by_uid", name="get_by_uid") def get_policy_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> SyftSuccess | SyftError: result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py index 4779be12215..5d29005ef83 100644 --- a/packages/syft/src/syft/service/policy/user_policy_stash.py +++ b/packages/syft/src/syft/service/policy/user_policy_stash.py @@ -6,26 +6,27 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from .policy import PolicyUserVerifyKeyPartitionKey -from .policy import UserPolicy +from ...store.document_store import ( + BaseUIDStoreStash, + DocumentStore, + PartitionSettings, + QueryKeys, +) +from .policy import PolicyUserVerifyKeyPartitionKey, UserPolicy @serializable(canonical_name="UserPolicyStash", version=1) class UserPolicyStash(BaseUIDStoreStash): object_type = UserPolicy settings: PartitionSettings = PartitionSettings( - name=UserPolicy.__canonical_name__, object_type=UserPolicy + name=UserPolicy.__canonical_name__, object_type=UserPolicy, ) def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def get_all_by_user_verify_key( - self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey + self, 
credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey, ) -> Result[list[UserPolicy], str]: qks = QueryKeys(qks=[PolicyUserVerifyKeyPartitionKey.with_obj(user_verify_key)]) return self.query_one(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 6854503ff8a..9307b52a4e4 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -1,40 +1,32 @@ # future from __future__ import annotations -# stdlib -from collections.abc import Callable -from collections.abc import Iterable import copy import hashlib import textwrap import time + +# stdlib +from collections.abc import Callable, Iterable from typing import Any # third party -from pydantic import Field -from pydantic import field_validator +from pydantic import Field, field_validator from rich.progress import Progress from typing_extensions import Self # relative from ...client.api import ServerIdentity -from ...client.client import SyftClient -from ...client.client import SyftClientSessionCache +from ...client.client import SyftClient, SyftClientSessionCache from ...serde.serializable import serializable from ...serde.serialize import _serialize -from ...server.credentials import SyftSigningKey -from ...server.credentials import SyftVerifyKey +from ...server.credentials import SyftSigningKey, SyftVerifyKey from ...service.metadata.server_metadata import ServerMetadata from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.identity import Identity -from ...types.identity import UserIdentity -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject -from ...types.syft_object import short_qual_name -from ...types.transforms import TransformContext -from ...types.transforms import rename -from ...types.transforms import transform +from ...types.identity import Identity, UserIdentity +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject, short_qual_name +from ...types.transforms import TransformContext, rename, transform from ...types.uid import UID from ...util import options from ...util.colors import SURFACE @@ -43,15 +35,9 @@ from ...util.util import full_name_with_qualname from ..code.user_code import SubmitUserCode from ..network.network_service import ServerPeer -from ..network.routes import ServerRoute -from ..network.routes import connection_to_route -from ..request.request import Request -from ..request.request import RequestStatus -from ..response import SyftError -from ..response import SyftException -from ..response import SyftInfo -from ..response import SyftNotReady -from ..response import SyftSuccess +from ..network.routes import ServerRoute, connection_to_route +from ..request.request import Request, RequestStatus +from ..response import SyftError, SyftException, SyftInfo, SyftNotReady, SyftSuccess from ..user.user import UserView @@ -119,7 +105,7 @@ def valid(self) -> SyftSuccess | SyftError: event_hash_bytes, current_hash = create_project_event_hash(self) if current_hash != self.event_hash: raise Exception( - f"Event hash {current_hash} does not match {self.event_hash}" + f"Event hash {current_hash} does not match {self.event_hash}", ) if self.creator_verify_key is None: return SyftError(message=f"{self}'s creator_verify_key is None") @@ -129,7 +115,7 @@ def valid(self) -> SyftSuccess | SyftError: return SyftError(message=f"Failed to validate message. 
{e}") def valid_descendant( - self, project: Project, prev_event: Self | None + self, project: Project, prev_event: Self | None, ) -> SyftSuccess | SyftError: valid = self.valid if not valid: @@ -147,13 +133,13 @@ def valid_descendant( if self.prev_event_uid != prev_event_id: return SyftError( message=f"{self} prev_event_uid: {self.prev_event_uid} " - "does not match {prev_event_id}" + "does not match {prev_event_id}", ) if self.prev_event_hash != prev_event_hash: return SyftError( message=f"{self} prev_event_hash: {self.prev_event_hash} " - "does not match {prev_event_hash}" + "does not match {prev_event_hash}", ) if ( @@ -163,13 +149,13 @@ def valid_descendant( ): return SyftError( message=f"{self} seq_no: {self.seq_no} " - "is not subsequent to {prev_seq_no}" + "is not subsequent to {prev_seq_no}", ) if self.project_id != project.id: return SyftError( message=f"{self} project_id: {self.project_id} " - "does not match {project.id}" + "does not match {project.id}", ) if hasattr(self, "parent_event_id"): @@ -179,7 +165,7 @@ def valid_descendant( and type(self) not in parent_event.allowed_sub_types ): return SyftError( - message=f"{self} is not a valid subevent" f"for {parent_event}" + message=f"{self} is not a valid subeventfor {parent_event}", ) return SyftSuccess(message=f"{self} is valid descendant of {prev_event}") @@ -187,7 +173,7 @@ def sign(self, signing_key: SyftSigningKey) -> None: if self.creator_verify_key != signing_key.verify_key: raise Exception( f"creator_verify_key has changed from: {self.creator_verify_key} to " - f"{signing_key.verify_key}" + f"{signing_key.verify_key}", ) # Calculate Hash event_hash_bytes, event_hash = create_project_event_hash(self) @@ -277,7 +263,7 @@ def _validate_linked_request(cls, v: Any) -> LinkedObject: return v else: raise ValueError( - f"linked_request should be either Request or LinkedObject, got {type(v)}" + f"linked_request should be either Request or LinkedObject, got {type(v)}", ) @property @@ -306,7 +292,7 @@ def approve(self) -> ProjectRequestResponse: return ProjectRequestResponse(response=True, parent_event_id=self.id) def accept_by_depositing_result( - self, result: Any, force: bool = False + self, result: Any, force: bool = False, ) -> SyftError | SyftSuccess: return self.request.accept_by_depositing_result(result=result, force=force) @@ -316,17 +302,20 @@ def status(self, project: Project) -> SyftInfo | SyftError | None: """Returns the status of the request. Args: + ---- project (Project): Project object to check the status Returns: + ------- str: Status of the request. During Request status calculation, we do not allow multiple responses + """ responses: list[ProjectEvent] = project.get_children(self) if len(responses) == 0: return SyftInfo( - "No one has responded to the request yet. Kindly recheck later 🙂" + "No one has responded to the request yet. 
Kindly recheck later 🙂", ) elif len(responses) > 1: return SyftError( @@ -334,12 +323,12 @@ def status(self, project: Project) -> SyftInfo | SyftError | None: "which is currently not possible" "The request should contain only one response" "Kindly re-submit a new request" - "The Syft Team is working on this issue to handle multiple responses" + "The Syft Team is working on this issue to handle multiple responses", ) response = responses[0] if not isinstance(response, ProjectRequestResponse): return SyftError( # type: ignore[unreachable] - message=f"Response : {type(response)} is not of type ProjectRequestResponse" + message=f"Response : {type(response)} is not of type ProjectRequestResponse", ) print("Request Status : ", "Approved" if response.response else "Denied") @@ -411,8 +400,8 @@ def poll_creation_wizard() -> tuple[str, list[str]]: print() print( w.fill( - "Question 2: Enter the number of choices, you would like to have in the poll" - ) + "Question 2: Enter the number of choices, you would like to have in the poll", + ), ) print() while True: @@ -421,7 +410,7 @@ def poll_creation_wizard() -> tuple[str, list[str]]: except ValueError: print() print( - w.fill("Number of choices, should be an integer.Kindly re-enter again.") + w.fill("Number of choices, should be an integer.Kindly re-enter again."), ) print() continue @@ -444,7 +433,7 @@ def poll_creation_wizard() -> tuple[str, list[str]]: print() print( - w.fill("All done! You have successfully completed the Poll Creation Wizard! 🎩") + w.fill("All done! You have successfully completed the Poll Creation Wizard! 🎩"), ) return (question, choices) @@ -508,13 +497,13 @@ def poll_answer_wizard(poll: ProjectMultipleChoicePoll) -> int: try: choice: int = int(input("\t")) if choice < 1 or choice > len(poll.choices): - raise ValueError() + raise ValueError except ValueError: print() print( w.fill( - f"Poll Answer should be a natural number between 1 and {len(poll.choices)}" - ) + f"Poll Answer should be a natural number between 1 and {len(poll.choices)}", + ), ) print() continue @@ -523,7 +512,7 @@ def poll_answer_wizard(poll: ProjectMultipleChoicePoll) -> int: print("\t" + "=" * 69) print() print( - w.fill("All done! You have successfully completed the Poll Answer Wizard! 🎩") + w.fill("All done! You have successfully completed the Poll Answer Wizard! 
🎩"), ) print() @@ -558,18 +547,21 @@ def answer(self, answer: int) -> ProjectMessage: return AnswerProjectPoll(answer=answer, parent_event_id=self.id) def status( - self, project: Project, pretty_print: bool = True + self, project: Project, pretty_print: bool = True, ) -> dict | SyftError | SyftInfo | None: """Returns the status of the poll Args: + ---- project (Project): Project object to check the status Returns: + ------- str: Status of the poll During Poll calculation, a user would have answered the poll many times The status of the poll would be calculated based on the latest answer of the user + """ poll_answers = project.get_children(self) if len(poll_answers) == 0: @@ -579,7 +571,7 @@ def status( for poll_answer in poll_answers[::-1]: if not isinstance(poll_answer, AnswerProjectPoll): return SyftError( # type: ignore[unreachable] - message=f"Poll answer: {type(poll_answer)} is not of type AnswerProjectPoll" + message=f"Poll answer: {type(poll_answer)} is not of type AnswerProjectPoll", ) creator_verify_key = poll_answer.creator_verify_key @@ -624,7 +616,7 @@ def add_code_request_to_project( # TODO: fix the mypy issue if not isinstance(code, SubmitUserCode): return SyftError( # type: ignore[unreachable] - message=f"Currently we are only support creating requests for SubmitUserCode: {type(code)}" + message=f"Currently we are only support creating requests for SubmitUserCode: {type(code)}", ) if not isinstance(client, SyftClient): @@ -634,7 +626,7 @@ def add_code_request_to_project( reason = f"Code Request for Project: {project.name} has been submitted by {project.created_by}" submitted_req = client.api.services.code.request_code_execution( - code=code, reason=reason + code=code, reason=reason, ) if isinstance(submitted_req, SyftError): return submitted_req @@ -650,7 +642,7 @@ def add_code_request_to_project( return SyftSuccess( message=f"Code request for '{code.func_name}' successfully added to '{project.name}' Project. 
" - f"To see code requests by a client, run `[your_client].code`" + f"To see code requests by a client, run `[your_client].code`", ) @@ -737,7 +729,7 @@ def key_in_project(self, verify_key: SyftVerifyKey) -> bool: return verify_key in project_verify_keys def get_identity_from_key( - self, verify_key: SyftVerifyKey + self, verify_key: SyftVerifyKey, ) -> list[ServerIdentity | UserIdentity]: identities: list[Identity] = self.get_all_identities() for identity in identities: @@ -755,7 +747,7 @@ def get_leader_client(self, signing_key: SyftSigningKey) -> SyftClient: verify_key = signing_key.verify_key leader_client = SyftClientSessionCache.get_client_by_uid_and_verify_key( - verify_key=verify_key, server_uid=self.leader_server_peer.id + verify_key=verify_key, server_uid=self.leader_server_peer.id, ) if leader_client is None: @@ -776,7 +768,7 @@ def has_permission(self, verify_key: SyftVerifyKey) -> bool: return self.key_in_project(verify_key) def _append_event( - self, event: ProjectEvent, credentials: SyftSigningKey + self, event: ProjectEvent, credentials: SyftSigningKey, ) -> SyftSuccess | SyftError: prev_event = self.events[-1] if self.events else None valid = event.valid_descendant(self, prev_event) @@ -850,7 +842,7 @@ def valid_str(current_hash: int) -> str: prev_event = last_event if last_event is not None else self print( f"{icon} {type(event).__name__}: {event.id} " - f"after {type(prev_event).__name__}: {prev_event.id}" + f"after {type(prev_event).__name__}: {prev_event.id}", ) if not result: @@ -989,12 +981,12 @@ def reply_message( reply_event = message.reply(reply) elif isinstance(message, ProjectThreadMessage): # type: ignore[unreachable] reply_event = ProjectThreadMessage( - message=reply, parent_event_id=message.parent_event_id + message=reply, parent_event_id=message.parent_event_id, ) else: return SyftError( message=f"You can only reply to a message: {type(message)}" - "Kindly re-check the msg" + "Kindly re-check the msg", ) result = self.add_event(reply_event) @@ -1034,7 +1026,7 @@ def answer_poll( if not isinstance(poll, ProjectMultipleChoicePoll): return SyftError( # type: ignore[unreachable] message=f"You can only reply to a poll: {type(poll)}" - "Kindly re-check the poll" + "Kindly re-check the poll", ) if not isinstance(answer, int) or answer <= 0 or answer > len(poll.choices): @@ -1078,7 +1070,7 @@ def approve_request( else: return SyftError( # type: ignore[unreachable] message=f"You can only approve a request: {type(request)}" - "Kindly re-check the request" + "Kindly re-check the request", ) result = self.add_event(request_event) if isinstance(result, SyftSuccess): @@ -1087,11 +1079,10 @@ def approve_request( def sync(self, verbose: bool | None = True) -> SyftSuccess | SyftError: """Sync the latest project with the state sync leader""" - leader_client = self.get_leader_client(self.user_signing_key) unsynced_events = leader_client.api.services.project.sync( - project_id=self.id, seq_no=self.get_last_seq_no() + project_id=self.id, seq_no=self.get_last_seq_no(), ) if isinstance(unsynced_events, SyftError): return unsynced_events @@ -1135,7 +1126,7 @@ def requests(self) -> list[Request]: @property def pending_requests(self) -> int: return sum( - [request.status == RequestStatus.PENDING for request in self.requests] + [request.status == RequestStatus.PENDING for request in self.requests], ) @@ -1223,7 +1214,7 @@ def _repr_html_(self) -> Any: @field_validator("members", mode="before") @classmethod def verify_members( - cls, val: list[SyftClient] | list[ServerIdentity] + cls, 
val: list[SyftClient] | list[ServerIdentity], ) -> list[SyftClient] | list[ServerIdentity]: # SyftClients must be logged in by the same emails clients = cls.get_syft_clients(val) @@ -1231,7 +1222,7 @@ def verify_members( emails = {client.logged_in_user for client in clients} if len(emails) > 1: raise ValueError( - f"All clients must be logged in from the same account. Found multiple: {emails}" + f"All clients must be logged in from the same account. Found multiple: {emails}", ) return val @@ -1250,11 +1241,11 @@ def to_server_identity(val: SyftClient | ServerIdentity) -> ServerIdentity: return metadata.to(ServerIdentity) else: raise SyftException( - f"members must be SyftClient or ServerIdentity. Received: {type(val)}" + f"members must be SyftClient or ServerIdentity. Received: {type(val)}", ) def create_code_request( - self, obj: SubmitUserCode, client: SyftClient, reason: str | None = None + self, obj: SubmitUserCode, client: SyftClient, reason: str | None = None, ) -> SyftError | SyftSuccess: return add_code_request_to_project( project=self, @@ -1264,7 +1255,7 @@ def create_code_request( ) @deprecated( - reason="Project.start has been renamed to Project.send", return_syfterror=True + reason="Project.start has been renamed to Project.send", return_syfterror=True, ) def start(self, return_all_projects: bool = False) -> Project | list[Project]: return self.send(return_all_projects=return_all_projects) @@ -1371,7 +1362,7 @@ def check_permissions(context: TransformContext) -> TransformContext: if len(context.output["project_permissions"]) == 0: project_permissions = context.output["project_permissions"] project_permissions = project_permissions.union( - add_members_as_owners(context.output["members"]) + add_members_as_owners(context.output["members"]), ) context.output["project_permissions"] = project_permissions @@ -1393,10 +1384,13 @@ def hash_object(obj: Any) -> tuple[bytes, str]: """Hashes an object using sha256 Args: + ---- obj (Any): Object to be hashed Returns: + ------- str: Hashed value of the object + """ hash_bytes = _serialize(obj, to_bytes=True, for_hashing=True) hash = hashlib.sha256(hash_bytes) @@ -1418,7 +1412,7 @@ def create_project_hash(project: Project) -> tuple[bytes, str]: project.created_by, [hash_object(member) for member in project.members], [hash_object(user) for user in project.users], - ] + ], ) @@ -1440,5 +1434,5 @@ def create_project_event_hash(project_event: ProjectEvent) -> tuple[bytes, str]: project_event.timestamp.utc_timestamp, project_event.prev_event_hash, hash_object(project_event.creator_verify_key)[1], - ] + ], ) diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 0da9d043e18..09153a6600d 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -9,21 +9,20 @@ from ..context import AuthedServiceContext from ..notification.notification_service import NotificationService from ..notification.notifications import CreateNotification -from ..response import SyftError -from ..response import SyftNotReady -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method -from ..user.user_roles import GUEST_ROLE_LEVEL -from ..user.user_roles import ONLY_DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import ServiceRole -from .project import Project -from .project import 
ProjectEvent -from .project import ProjectRequest -from .project import ProjectSubmit -from .project import create_project_hash +from ..response import SyftError, SyftNotReady, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method +from ..user.user_roles import ( + GUEST_ROLE_LEVEL, + ONLY_DATA_SCIENTIST_ROLE_LEVEL, + ServiceRole, +) +from .project import ( + Project, + ProjectEvent, + ProjectRequest, + ProjectSubmit, + create_project_hash, +) from .project_stash import ProjectStash @@ -55,10 +54,9 @@ def can_create_project(self, context: AuthedServiceContext) -> bool | SyftError: roles=ONLY_DATA_SCIENTIST_ROLE_LEVEL, ) def create_project( - self, context: AuthedServiceContext, project: ProjectSubmit + self, context: AuthedServiceContext, project: ProjectSubmit, ) -> SyftSuccess | SyftError: """Start a Project""" - check_role = self.can_create_project(context) if isinstance(check_role, SyftError): return check_role @@ -66,7 +64,7 @@ def create_project( try: # Check if the project with given id already exists project_id_check = self.stash.get_by_uid( - credentials=context.server.verify_key, uid=project.id + credentials=context.server.verify_key, uid=project.id, ) if project_id_check.is_err(): @@ -74,7 +72,7 @@ def create_project( if project_id_check.ok() is not None: return SyftError( - message=f"Project with id: {project.id} already exists." + message=f"Project with id: {project.id} already exists.", ) project_obj: Project = project.to(Project, context=context) @@ -103,23 +101,19 @@ def create_project( message=( f"Leader Server(id={leader_server.id.short()}) is not a " f"peer of this Server(id={this_server_id})" - ) + ), ) leader_server_peer = peer.ok() - else: - # for the leader server, as it does not have route information to itself - # we rely on the data scientist to provide the route - # the route is then validated by the leader - if project.leader_server_route is not None: - leader_server_peer = ( - project.leader_server_route.validate_with_context( - context=context - ) - ) - else: - return SyftError( - message=f"project {project}'s leader_server_route is None" + elif project.leader_server_route is not None: + leader_server_peer = ( + project.leader_server_route.validate_with_context( + context=context, ) + ) + else: + return SyftError( + message=f"project {project}'s leader_server_route is None", + ) project_obj.leader_server_peer = leader_server_peer @@ -132,7 +126,7 @@ def create_project( project_obj_store = result.ok() project_obj_store = self.add_signing_key_to_project( - context, project_obj_store + context, project_obj_store, ) return project_obj_store @@ -147,15 +141,14 @@ def create_project( roles=GUEST_ROLE_LEVEL, ) def add_event( - self, context: AuthedServiceContext, project_event: ProjectEvent + self, context: AuthedServiceContext, project_event: ProjectEvent, ) -> SyftSuccess | SyftError: """To add events to a projects""" - # Event object should be received from the leader of the project # retrieve the project object by server verify key project_obj = self.stash.get_by_uid( - context.server.verify_key, uid=project_event.project_id + context.server.verify_key, uid=project_event.project_id, ) if project_obj.is_err(): return SyftError(message=str(project_obj.err())) @@ -163,7 +156,7 @@ def add_event( project: Project = project_obj.ok() if project.state_sync_leader.verify_key == context.server.verify_key: return SyftError( - message="Project Events should be passed to leader by broadcast endpoint" + message="Project Events 
should be passed to leader by broadcast endpoint", ) if context.credentials != project.state_sync_leader.verify_key: return SyftError(message="Only the leader of the project can add events") @@ -181,7 +174,7 @@ def add_event( if result.is_err(): return SyftError(message=str(result.err())) return SyftSuccess( - message=f"Project event {project_event.id} added successfully " + message=f"Project event {project_event.id} added successfully ", ) @service_method( @@ -190,7 +183,7 @@ def add_event( roles=GUEST_ROLE_LEVEL, ) def broadcast_event( - self, context: AuthedServiceContext, project_event: ProjectEvent + self, context: AuthedServiceContext, project_event: ProjectEvent, ) -> SyftSuccess | SyftError: """To add events to a projects""" # Only the leader of the project could add events to the projects @@ -198,7 +191,7 @@ def broadcast_event( # The leader broadcasts the event to all the members of the project project_obj = self.stash.get_by_uid( - context.server.verify_key, uid=project_event.project_id + context.server.verify_key, uid=project_event.project_id, ) if project_obj.is_err(): @@ -210,7 +203,7 @@ def broadcast_event( if project.state_sync_leader.verify_key != context.server.verify_key: return SyftError( - message="Only the leader of the project can broadcast events" + message="Only the leader of the project can broadcast events", ) if project_event.seq_no is None: @@ -240,19 +233,19 @@ def broadcast_event( if peer.is_err(): return SyftError( message=f"Leader server does not have peer {member.name}-{member.id.short()}" - + " Kindly exchange routes with the peer" + + " Kindly exchange routes with the peer", ) peer = peer.ok() remote_client = peer.client_with_context(context=context) if remote_client.is_err(): return SyftError( message=f"Failed to create remote client for peer: " - f"{peer.id}. Error: {remote_client.err()}" + f"{peer.id}. 
Error: {remote_client.err()}", ) remote_client = remote_client.ok() event_result = remote_client.api.services.project.add_event( - project_event + project_event, ) if isinstance(event_result, SyftError): return event_result @@ -269,10 +262,9 @@ def broadcast_event( roles=GUEST_ROLE_LEVEL, ) def sync( - self, context: AuthedServiceContext, project_id: UID, seq_no: int + self, context: AuthedServiceContext, project_id: UID, seq_no: int, ) -> list[ProjectEvent] | SyftError: """To fetch unsynced events from the project""" - # Event object should be received from the leader of the project # retrieve the project object by server verify key @@ -283,7 +275,7 @@ def sync( project: Project = project_obj.ok() if project.state_sync_leader.verify_key != context.server.verify_key: return SyftError( - message="Project Events should be synced only with the leader" + message="Project Events should be synced only with the leader", ) if not project.has_permission(context.credentials): @@ -319,7 +311,7 @@ def get_all(self, context: AuthedServiceContext) -> list[Project] | SyftError: roles=GUEST_ROLE_LEVEL, ) def get_by_name( - self, context: AuthedServiceContext, name: str + self, context: AuthedServiceContext, name: str, ) -> Project | SyftError: result = self.stash.get_by_name(context.credentials, project_name=name) if result.is_err(): @@ -335,7 +327,7 @@ def get_by_name( roles=GUEST_ROLE_LEVEL, ) def get_by_uid( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> Project | SyftError: result = self.stash.get_by_uid( credentials=context.server.verify_key, @@ -348,14 +340,14 @@ def get_by_uid( return SyftError(message=f'Project(id="{uid}") does not exist') def add_signing_key_to_project( - self, context: AuthedServiceContext, project: Project + self, context: AuthedServiceContext, project: Project, ) -> Project | SyftError: # Automatically infuse signing key of user # requesting get_all() or creating the project object user_service = context.server.get_service("userservice") user = user_service.stash.get_by_verify_key( - credentials=context.credentials, verify_key=context.credentials + credentials=context.credentials, verify_key=context.credentials, ) if user.is_err(): return SyftError(message=str(user.err())) @@ -377,14 +369,16 @@ def check_for_project_request( """To check for project request event and create a message for the root user Args: + ---- project (Project): Project object project_event (ProjectEvent): Project event object context (AuthedServiceContext): Context of the server Returns: + ------- Union[SyftSuccess, SyftError]: SyftSuccess if message is created else SyftError - """ + """ if ( isinstance(project_event, ProjectRequest) and project_event.linked_request.server_uid == context.server.id diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index dcd258938b3..a2c5ed0f9bb 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -6,11 +6,13 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey +from ...store.document_store import ( + BaseUIDStoreStash, + PartitionKey, + PartitionSettings, 
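# The repeated `from ...store.document_store import X` lines above are folded
# into one parenthesized import — the shape ruff's isort rules (I001) emit when
# `force-single-line` is disabled (an assumption about this repo's ruff config;
# the patch itself does not show the settings). The trailing comma after each
# name keeps the block one-name-per-line, so adding an import later is a
# one-line diff.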
+ QueryKeys, + UIDPartitionKey, +) from ...types.uid import UID from ...util.telemetry import instrument from ..request.request import Request @@ -26,11 +28,11 @@ class ProjectStash(BaseUIDStoreStash): object_type = Project settings: PartitionSettings = PartitionSettings( - name=Project.__canonical_name__, object_type=Project + name=Project.__canonical_name__, object_type=Project, ) def get_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: VerifyKeyPartitionKey + self, credentials: SyftVerifyKey, verify_key: VerifyKeyPartitionKey, ) -> Result[list[Request], SyftError]: if isinstance(verify_key, str): verify_key = SyftVerifyKey.from_string(verify_key) @@ -41,13 +43,13 @@ def get_all_for_verify_key( ) def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[Project | None, str]: qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) return self.query_one(credentials=credentials, qks=qks) def get_by_name( - self, credentials: SyftVerifyKey, project_name: str + self, credentials: SyftVerifyKey, project_name: str, ) -> Result[Project | None, str]: qks = QueryKeys(qks=[NamePartitionKey.with_obj(project_name)]) return self.query_one(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/queue/base_queue.py b/packages/syft/src/syft/service/queue/base_queue.py index aed6f2244d4..3b488560869 100644 --- a/packages/syft/src/syft/service/queue/base_queue.py +++ b/packages/syft/src/syft/service/queue/base_queue.py @@ -1,14 +1,12 @@ # stdlib -from typing import Any -from typing import ClassVar +from typing import Any, ClassVar # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext from ...store.document_store import BaseStash from ...types.uid import UID -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from ..worker.worker_stash import WorkerStash diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 67d3aa44a90..df5f0250fb0 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -1,15 +1,14 @@ # stdlib import logging -from multiprocessing import Process import threading -from threading import Thread import time +from multiprocessing import Process +from threading import Thread from typing import Any # third party import psutil -from result import Err -from result import Ok +from result import Err, Ok # relative from ...serde.deserialize import _deserialize as deserialize @@ -20,18 +19,17 @@ from ...store.document_store import BaseStash from ...types.datetime import DateTime from ...types.uid import UID -from ..job.job_stash import Job -from ..job.job_stash import JobStatus -from ..response import SyftError -from ..response import SyftSuccess +from ..job.job_stash import Job, JobStatus +from ..response import SyftError, SyftSuccess from ..worker.worker_stash import WorkerStash -from .base_queue import AbstractMessageHandler -from .base_queue import BaseQueueManager -from .base_queue import QueueConfig -from .base_queue import QueueConsumer -from .base_queue import QueueProducer -from .queue_stash import QueueItem -from .queue_stash import Status +from .base_queue import ( + AbstractMessageHandler, + BaseQueueManager, + QueueConfig, + QueueConsumer, + QueueProducer, +) +from .queue_stash import QueueItem, Status logger = logging.getLogger(__name__) @@ -59,7 +57,7 @@ def 
run(self) -> None: def monitor(self) -> None: # Implement the monitoring logic here job = self.worker.job_stash.get_by_uid( - self.credentials, self.queue_item.job_id + self.credentials, self.queue_item.job_id, ).ok() if job and job.status == JobStatus.TERMINATING: self.terminate(job) @@ -210,7 +208,7 @@ def handle_message_multiprocessing( if isinstance(result, Ok): result = result.ok() if hasattr(result, "syft_action_data") and isinstance( - result.syft_action_data, Err + result.syft_action_data, Err, ): status = Status.ERRORED job_status = JobStatus.ERRORED @@ -307,7 +305,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: logger.info( f"Handling queue item: id={queue_item.id}, method={queue_item.method} " f"args={queue_item.args}, kwargs={queue_item.kwargs} " - f"service={queue_item.service}, as_thread={queue_config.thread_workers}" + f"service={queue_item.service}, as_thread={queue_config.thread_workers}", ) if queue_config.thread_workers: diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index 1c56b494b02..fd0bcbb5c32 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -7,11 +7,9 @@ from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftError -from ..service import AbstractService -from ..service import service_method +from ..service import AbstractService, service_method from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from .queue_stash import QueueItem -from .queue_stash import QueueStash +from .queue_stash import QueueItem, QueueStash @instrument @@ -30,7 +28,7 @@ def __init__(self, store: DocumentStore) -> None: roles=DATA_SCIENTIST_ROLE_LEVEL, ) def get_subjobs( - self, context: AuthedServiceContext, uid: UID + self, context: AuthedServiceContext, uid: UID, ) -> list[QueueItem] | SyftError: res = self.stash.get_by_parent_id(context.credentials, uid=uid) if res.is_err(): diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 8565ca6118f..ecfa56b222f 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -3,27 +3,26 @@ from typing import Any # third party -from result import Ok -from result import Result +from result import Ok, Result # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey +from ...store.document_store import ( + BaseStash, + DocumentStore, + PartitionKey, + PartitionSettings, + QueryKeys, + UIDPartitionKey, +) from ...store.linked_obj import LinkedObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess 
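# Nearly every multi-line call in this patch gains a trailing comma before the
# closing parenthesis — ruff's flake8-commas rule COM812 (missing-trailing-comma).
# A minimal, self-contained sketch of the effect; the helper below is
# hypothetical, not Syft code:
def partition_settings(name: str, object_type: type) -> dict:
    # With the magic trailing comma, formatters keep the call site exploded
    # one-argument-per-line, so appending an argument later is a one-line diff.
    return {"name": name, "object_type": object_type}

settings = partition_settings(
    name="QueueItem",
    object_type=dict,
)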
@serializable(canonical_name="Status", version=1) @@ -100,7 +99,7 @@ class APIEndpointQueueItem(QueueItem): class QueueStash(BaseStash): object_type = QueueItem settings: PartitionSettings = PartitionSettings( - name=QueueItem.__canonical_name__, object_type=QueueItem + name=QueueItem.__canonical_name__, object_type=QueueItem, ) def __init__(self, store: DocumentStore) -> None: @@ -136,21 +135,21 @@ def set_placeholder( return item def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[QueueItem | None, str]: qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) item = self.query_one(credentials=credentials, qks=qks) return item def pop( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[QueueItem | None, str]: item = self.get_by_uid(credentials=credentials, uid=uid) self.delete_by_uid(credentials=credentials, uid=uid) return item def pop_on_complete( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[QueueItem | None, str]: item = self.get_by_uid(credentials=credentials, uid=uid) if item.is_ok(): @@ -160,7 +159,7 @@ def pop_on_complete( return item def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, ) -> Result[SyftSuccess, str]: qk = UIDPartitionKey.with_obj(uid) result = super().delete(credentials=credentials, qk=qk) @@ -169,7 +168,7 @@ def delete_by_uid( return result def get_by_status( - self, credentials: SyftVerifyKey, status: Status + self, credentials: SyftVerifyKey, status: Status, ) -> Result[list[QueueItem], str]: qks = QueryKeys(qks=StatusPartitionKey.with_obj(status)) diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 4d6ba436387..22fac690728 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -1,23 +1,22 @@ # stdlib -from binascii import hexlify -from collections import defaultdict import itertools import logging import socketserver import sys import threading -from threading import Event import time +from binascii import hexlify +from collections import defaultdict +from threading import Event from time import sleep -from typing import Any -from typing import cast +from typing import Any, cast + +import zmq # third party from pydantic import field_validator from result import Result -import zmq -from zmq import Frame -from zmq import LINGER +from zmq import LINGER, Frame from zmq.error import ContextTerminated # relative @@ -28,25 +27,22 @@ from ...service.action.action_object import ActionObject from ...service.context import AuthedServiceContext from ...types.base import SyftBaseModel -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.uid import UID from ...util.util import get_queue_address -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from ..service import AbstractService -from ..worker.worker_pool import ConsumerState -from ..worker.worker_pool import SyftWorker +from ..worker.worker_pool import ConsumerState, SyftWorker from ..worker.worker_stash import WorkerStash -from .base_queue import AbstractMessageHandler -from .base_queue import QueueClient -from .base_queue import QueueClientConfig -from .base_queue 
import QueueConfig -from .base_queue import QueueConsumer -from .base_queue import QueueProducer -from .queue_stash import ActionQueueItem -from .queue_stash import QueueStash -from .queue_stash import Status +from .base_queue import ( + AbstractMessageHandler, + QueueClient, + QueueClientConfig, + QueueConfig, + QueueConsumer, + QueueProducer, +) +from .queue_stash import ActionQueueItem, QueueStash, Status # Producer/Consumer heartbeat interval (in seconds) HEARTBEAT_INTERVAL_SEC = 2 @@ -134,7 +130,7 @@ def reset_expiry(self) -> None: self.expiry_t.reset() def _syft_worker( - self, stash: WorkerStash, credentials: SyftVerifyKey + self, stash: WorkerStash, credentials: SyftVerifyKey, ) -> Result[SyftWorker | None, str]: return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id) @@ -173,7 +169,6 @@ def address(self) -> str: def post_init(self) -> None: """Initialize producer state.""" - self.services: dict[str, Service] = {} self.workers: dict[bytes, Worker] = {} self.waiting: list[Worker] = [] @@ -196,7 +191,7 @@ def close(self) -> None: if self.thread.is_alive(): logger.error( f"ZMQProducer message sending thread join timed out during closing. " - f"Queue name {self.queue_name}, " + f"Queue name {self.queue_name}, ", ) self.thread = None @@ -205,7 +200,7 @@ def close(self) -> None: if self.producer_thread.is_alive(): logger.error( f"ZMQProducer queue thread join timed out during closing. " - f"Queue name {self.queue_name}, " + f"Queue name {self.queue_name}, ", ) self.producer_thread = None @@ -226,7 +221,7 @@ def action_service(self) -> AbstractService: raise Exception(f"{self.auth_context} does not have a server.") def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: - """recursively check collections for unresolved action objects""" + """Recursively check collections for unresolved action objects""" if isinstance(arg, UID): arg = self.action_service.get(self.auth_context, arg).ok() return self.contains_unresolved_action_objects(arg, recursion=recursion + 1) @@ -245,14 +240,14 @@ def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bo if isinstance(arg, list): for elem in arg: value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 + elem, recursion=recursion + 1, ) if value: return True if isinstance(arg, dict): for elem in arg.values(): value = self.contains_unresolved_action_objects( - elem, recursion=recursion + 1 + elem, recursion=recursion + 1, ) if value: return True @@ -290,13 +285,13 @@ def read_items(self) -> None: if isinstance(item, ActionQueueItem): action = item.kwargs["action"] if self.contains_unresolved_action_objects( - action.args + action.args, ) or self.contains_unresolved_action_objects(action.kwargs): continue msg_bytes = serialize(item, to_bytes=True) worker_pool = item.worker_pool.resolve_with_context( - self.auth_context + self.auth_context, ) worker_pool = worker_pool.ok() service_name = worker_pool.name @@ -316,7 +311,7 @@ def read_items(self) -> None: res = self.queue_stash.update(item.syft_client_verify_key, item) if res.is_err(): logger.error( - f"Failed to update queue item={item} error={res.err()}" + f"Failed to update queue item={item} error={res.err()}", ) elif item.status == Status.PROCESSING: # Evaluate Retry condition here @@ -331,7 +326,7 @@ def read_items(self) -> None: res = self.queue_stash.update(item.syft_client_verify_key, item) if res.is_err(): logger.error( - f"Failed to update queue item={item} error={res.err()}" + f"Failed to update queue 
item={item} error={res.err()}", ) def run(self) -> None: @@ -377,16 +372,16 @@ def purge_workers(self) -> None: from ...service.worker.worker_service import WorkerService worker_service = cast( - WorkerService, self.auth_context.server.get_service(WorkerService) + WorkerService, self.auth_context.server.get_service(WorkerService), ) worker_service._delete(self.auth_context, syft_worker) def update_consumer_state_for_worker( - self, syft_worker_id: UID, consumer_state: ConsumerState + self, syft_worker_id: UID, consumer_state: ConsumerState, ) -> None: if self.worker_stash is None: logger.error( # type: ignore[unreachable] - f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}" + f"ZMQProducer worker stash not defined for {self.queue_name} - {self.id}", ) return @@ -449,7 +444,6 @@ def send_to_worker( If message is provided, sends that message. """ - if self.socket.closed: logger.warning("Socket is closed. Cannot send message.") return @@ -527,7 +521,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N worker_ready = hexlify(address) in self.workers worker = self.require_worker(address) - if QueueMsgProtocol.W_READY == command: + if command == QueueMsgProtocol.W_READY: service_name = data.pop(0).decode() syft_worker_id = data.pop(0).decode() if worker_ready: @@ -549,7 +543,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N worker.syft_worker_id = UID(syft_worker_id) self.worker_waiting(worker) - elif QueueMsgProtocol.W_HEARTBEAT == command: + elif command == QueueMsgProtocol.W_HEARTBEAT: if worker_ready: # If worker is ready then reset expiry # and add it to worker waiting list @@ -558,7 +552,7 @@ def process_worker(self, address: bytes, command: bytes, data: list[bytes]) -> N else: logger.info(f"Got heartbeat, but worker not ready. {worker}") self.delete_worker(worker, True) - elif QueueMsgProtocol.W_DISCONNECT == command: + elif command == QueueMsgProtocol.W_DISCONNECT: logger.info(f"Removing disconnected worker: {worker}") self.delete_worker(worker, False) else: @@ -579,7 +573,7 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: if worker.syft_worker_id is not None: self.update_consumer_state_for_worker( - worker.syft_worker_id, ConsumerState.DETACHED + worker.syft_worker_id, ConsumerState.DETACHED, ) @property @@ -651,7 +645,7 @@ def close(self) -> None: logger.error( f"ZMQConsumer thread join timed out during closing. " f"SyftWorker id {self.syft_worker_id}, " - f"service name {self.service_name}." + f"service name {self.service_name}.", ) self.thread = None self.poller.unregister(self.socket) @@ -747,11 +741,10 @@ def _run(self) -> None: self.reconnect_to_producer() else: logger.error(f"ZMQConsumer invalid command: {command}") - else: - if not self.is_producer_alive(): - logger.info("Producer check-alive timed out. Reconnecting.") - self.reconnect_to_producer() - self.set_producer_alive() + elif not self.is_producer_alive(): + logger.info("Producer check-alive timed out. Reconnecting.") + self.reconnect_to_producer() + self.set_producer_alive() if not self._stop.is_set(): self.send_heartbeat() @@ -802,7 +795,7 @@ def _set_worker_job(self, job_id: UID | None) -> None: ) if res.is_err(): logger.error( - f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}" + f"Failed to update consumer state for {self.service_name}-{self.id}, error={res.err()}", ) @property @@ -857,7 +850,6 @@ def add_producer( A queue can have at most one producer attached to it. 
""" - if port is None: if self.config.queue_port is None: self.config.queue_port = self._get_free_tcp_port(self.host) @@ -889,7 +881,6 @@ def add_consumer( A queue should have at least one producer attached to the group. """ - if address is None: address = get_queue_address(self.config.queue_port) @@ -914,14 +905,14 @@ def send_message( producer = self.producers.get(queue_name) if producer is None: return SyftError( - message=f"No producer attached for queue: {queue_name}. Please add a producer for it." + message=f"No producer attached for queue: {queue_name}. Please add a producer for it.", ) try: producer.send(message=message, worker=worker) except Exception as e: # stdlib return SyftError( - message=f"Failed to send message to: {queue_name} with error: {e}" + message=f"Failed to send message to: {queue_name} with error: {e}", ) return SyftSuccess( message=f"Successfully queued message to : {queue_name}", diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 3ec7c5ef184..43f8b72d03b 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -1,40 +1,37 @@ # stdlib -from collections.abc import Callable -from enum import Enum import hashlib import inspect import logging +from collections.abc import Callable +from enum import Enum from typing import Any # third party from pydantic import model_validator -from result import Err -from result import Ok -from result import Result +from result import Err, Ok, Result from typing_extensions import Self # relative from ...abstract_server import ServerSideType from ...client.api import APIRegistry from ...client.client import SyftClient -from ...custom_worker.config import DockerWorkerConfig -from ...custom_worker.config import WorkerConfig +from ...custom_worker.config import DockerWorkerConfig, WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable from ...serde.serialize import _serialize from ...server.credentials import SyftVerifyKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject from ...types.syncable_object import SyncableSyftObject -from ...types.transforms import TransformContext -from ...types.transforms import add_server_uid_for_key -from ...types.transforms import generate_id -from ...types.transforms import transform +from ...types.transforms import ( + TransformContext, + add_server_uid_for_key, + generate_id, + transform, +) from ...types.twin_object import TwinObject -from ...types.uid import LineageID -from ...types.uid import UID +from ...types.uid import UID, LineageID from ...util import options from ...util.colors import SURFACE from ...util.decorators import deprecated @@ -43,19 +40,13 @@ from ...util.util import prompt_warning_message from ..action.action_object import ActionObject from ..action.action_service import ActionService -from ..action.action_store import ActionObjectPermission -from ..action.action_store import ActionPermission +from ..action.action_store import ActionObjectPermission, ActionPermission from ..blob_storage.service import BlobStorageService -from ..code.user_code import UserCode -from ..code.user_code import UserCodeStatus -from ..code.user_code import UserCodeStatusCollection -from ..context import AuthedServiceContext 
-from ..context import ChangeContext -from ..job.job_stash import Job -from ..job.job_stash import JobStatus +from ..code.user_code import UserCode, UserCodeStatus, UserCodeStatusCollection +from ..context import AuthedServiceContext, ChangeContext +from ..job.job_stash import Job, JobStatus from ..notification.notifications import Notification -from ..response import SyftError -from ..response import SyftSuccess +from ..response import SyftError, SyftSuccess from ..user.user import UserView logger = logging.getLogger(__name__) @@ -113,7 +104,7 @@ class ActionStoreChange(Change): __repr_attrs__ = ["linked_obj", "apply_permission_type"] def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: try: action_service: ActionService = context.server.get_service(ActionService) # type: ignore[assignment] @@ -165,12 +156,12 @@ def _run( ) if apply: logger.debug( - "ADDING PERMISSION", requesting_permission_action_obj, id_action + "ADDING PERMISSION", requesting_permission_action_obj, id_action, ) action_store.add_permission(requesting_permission_action_obj) ( blob_storage_service.stash.add_permission( - requesting_permission_blob_obj + requesting_permission_blob_obj, ) if requesting_permission_blob_obj else None @@ -181,17 +172,17 @@ def _run( if ( requesting_permission_blob_obj and blob_storage_service.stash.has_permission( - requesting_permission_blob_obj + requesting_permission_blob_obj, ) ): blob_storage_service.stash.remove_permission( - requesting_permission_blob_obj + requesting_permission_blob_obj, ) else: return Err( SyftError( - message=f"No permission for approving_user_credentials {context.approving_user_credentials}" - ) + message=f"No permission for approving_user_credentials {context.approving_user_credentials}", + ), ) return Ok(SyftSuccess(message=f"{type(self)} Success")) except Exception as e: @@ -228,21 +219,21 @@ def _tag_required_for_dockerworkerconfig(self) -> Self: return self def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: try: worker_image_service = context.server.get_service("SyftWorkerImageService") service_context = context.to_service_ctx() result = worker_image_service.submit( - service_context, worker_config=self.config + service_context, worker_config=self.config, ) if isinstance(result, SyftError): return Err(result) result = worker_image_service.stash.get_by_worker_config( - service_context.credentials, config=self.config + service_context.credentials, config=self.config, ) if result.is_err(): @@ -268,7 +259,7 @@ def _run( build_success_message = build_result.message build_success = SyftSuccess( - message=f"Build result: {build_success_message}" + message=f"Build result: {build_success_message}", ) if IN_KUBERNETES and not worker_image.is_prebuilt: @@ -284,8 +275,8 @@ def _run( return Ok( SyftSuccess( - message=f"{build_success}\nPush result: {push_result.message}" - ) + message=f"{build_success}\nPush result: {push_result.message}", + ), ) return Ok(build_success) @@ -318,10 +309,9 @@ class CreateCustomWorkerPoolChange(Change): __repr_attrs__ = ["pool_name", "num_workers", "image_uid"] def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: - """ - This function is run when the DO approves (apply=True) + """This function is run when the DO approves (apply=True) or deny (apply=False) the request. 
""" # TODO: refactor the returned Err(SyftError) or Ok(SyftSuccess) to just @@ -333,7 +323,7 @@ def _run( if self.config is not None: result = worker_pool_service.image_stash.get_by_worker_config( - service_context.credentials, self.config + service_context.credentials, self.config, ) if result.is_err(): return Err(SyftError(message=f"{result.err()}")) @@ -357,8 +347,8 @@ def _run( else: return Err( SyftError( - message=f"Request to create a worker pool with name {self.name} denied" - ) + message=f"Request to create a worker pool with name {self.name} denied", + ), ) def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: @@ -440,7 +430,7 @@ def _repr_html_(self) -> Any: if self.code and len(self.code.output_readers) > 0: # owner_names = ["canada", "US"] owners_string = " and ".join( - [f"{x}" for x in self.code.output_reader_names] # type: ignore + [f"{x}" for x in self.code.output_reader_names], # type: ignore ) shared_with_line += ( f"

Custom Policy: " @@ -525,7 +515,7 @@ def code_id(self) -> UID: if isinstance(change, UserCodeStatusChange): return change.linked_user_code.object_uid return SyftError( - message="This type of request does not have code associated with it." + message="This type of request does not have code associated with it.", ) @property @@ -534,7 +524,7 @@ def codes(self) -> Any: if isinstance(change, UserCodeStatusChange): return change.codes return SyftError( - message="This type of request does not have code associated with it." + message="This type of request does not have code associated with it.", ) def get_user_code(self, context: AuthedServiceContext) -> UserCode | None: @@ -549,7 +539,7 @@ def code(self) -> UserCode | SyftError: if isinstance(change, UserCodeStatusChange): return change.code return SyftError( - message="This type of request does not have code associated with it." + message="This type of request does not have code associated with it.", ) @property @@ -602,7 +592,7 @@ def approve( if self.is_l0_deployment: return SyftError( - message="This request is a low-side request. Please sync your results to approve." + message="This request is a low-side request. Please sync your results to approve.", ) # TODO: Refactor so that object can also be passed to generate warnings if api.connection: @@ -615,7 +605,7 @@ def approve( if is_code_request and len(self.codes) > 1 and not approve_nested: return SyftError( - message="Multiple codes detected, please use approve_nested=True" + message="Multiple codes detected, please use approve_nested=True", ) if metadata and metadata.server_side_type == ServerSideType.HIGH_SIDE.value: @@ -640,7 +630,9 @@ def deny(self, reason: str) -> SyftSuccess | SyftError: """Denies the particular request. Args: + ---- reason (str): Reason for which the request has been denied. + """ api = self._get_api() if isinstance(api, SyftError): @@ -650,7 +642,7 @@ def deny(self, reason: str) -> SyftSuccess | SyftError: if self.status == RequestStatus.APPROVED: prompt_warning_message( "This request already has results published to the data scientist. " - "They will still be able to access those results." + "They will still be able to access those results.", ) result = api.code.update(id=self.code_id, l0_deny_reason=reason) if isinstance(result, SyftError): @@ -673,7 +665,7 @@ def get_is_l0_deployment(self, context: AuthedServiceContext) -> bool: def approve_with_client(self, client: SyftClient) -> Result[SyftSuccess, SyftError]: if self.is_l0_deployment: return SyftError( - message="This request is a low-side request. Please sync your results to approve." + message="This request is a low-side request. 
Please sync your results to approve.", ) print(f"Approving request for datasite {client.name}") @@ -755,7 +747,7 @@ def _create_action_object_for_deposited_result( existing_job = api.services.job.get_by_result_id(result.id.id) if existing_job is not None: return SyftError( - message=f"This ActionObject is already the result of Job {existing_job.id}" + message=f"This ActionObject is already the result of Job {existing_job.id}", ) action_object = result else: @@ -767,7 +759,7 @@ def _create_action_object_for_deposited_result( # Ensure ActionObject exists on this server action_object_is_from_this_server = isinstance( - api.services.action.exists(action_object.id.id), SyftSuccess + api.services.action.exists(action_object.id.id), SyftSuccess, ) if ( action_object.syft_blob_storage_entry_id is None @@ -781,7 +773,7 @@ def _create_action_object_for_deposited_result( return action_object def _create_output_history_for_deposited_result( - self, job: Job, result: Any + self, job: Job, result: Any, ) -> SyftSuccess | SyftError: code = self.code if isinstance(code, SyftError): @@ -810,21 +802,22 @@ def deposit_result( log_stdout: str = "", log_stderr: str = "", ) -> Job | SyftError: - """ - Adds a result to this Request: + """Adds a result to this Request: - Create an ActionObject from the result (if not already an ActionObject) - Ensure ActionObject exists on this server - Create Job with new result and logs - Update the output history Args: + ---- result (Any): ActionObject or any object to be saved as an ActionObject. logs (str | None, optional): Optional logs to be saved with the Job. Defaults to None. Returns: + ------- Job | SyftError: Job object if successful, else SyftError. - """ + """ # TODO check if this is a low-side request. If not, SyftError api = self._get_api() @@ -837,7 +830,7 @@ def deposit_result( if not self.is_l0_deployment: return SyftError( message="deposit_result is only available for low side code requests. " - "Please use request.approve() instead." 
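# Several docstrings in these hunks gain dashed underlines beneath `Args:` and
# `Returns:` plus a blank line after the last section — pydocstyle's numpy
# convention, which ruff enforces as D406/D407/D413 (the rule mapping is an
# assumption; the patch does not name the rules). A minimal sketch of the
# resulting shape, on a hypothetical function:
def deny(reason: str) -> None:
    """Deny a request.

    Args:
    ----
    reason (str): Reason for which the request has been denied.

    """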
+ "Please use request.approve() instead.", ) # Create ActionObject @@ -874,7 +867,7 @@ def accept_by_depositing_result(self, result: Any, force: bool = False) -> Any: pass def get_sync_dependencies( - self, context: AuthedServiceContext + self, context: AuthedServiceContext, ) -> list[UID] | SyftError: dependencies = [] code_id = self.code_id @@ -923,7 +916,7 @@ def hash_changes(context: TransformContext) -> TransformContext: changes = context.output["changes"] time_hash = hashlib.sha256( - _serialize(request_time.utc_timestamp, to_bytes=True) + _serialize(request_time.utc_timestamp, to_bytes=True), ).digest() key_hash = hashlib.sha256(bytes(key.verify_key)).digest() changes_hash = hashlib.sha256(_serialize(changes, to_bytes=True)).digest() @@ -943,7 +936,7 @@ def add_request_time(context: TransformContext) -> TransformContext: def check_requesting_user_verify_key(context: TransformContext) -> TransformContext: if context.output and context.server and context.obj: if context.obj.requesting_user_verify_key and context.server.is_root( - context.credentials + context.credentials, ): context.output["requesting_user_verify_key"] = ( context.obj.requesting_user_verify_key @@ -1013,7 +1006,7 @@ def __repr_syft_nested__(self) -> str: return f"Mutate {self.attr_name} to {self.value}" def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: if self.linked_obj is None: return Err(SyftError(message=f"{self}'s linked object is None")) @@ -1069,13 +1062,13 @@ class EnumMutation(ObjectMutation): def valid(self) -> SyftSuccess | SyftError: if self.match_type and not isinstance(self.value, self.enum_type): return SyftError( - message=f"{type(self.value)} must be of type: {self.enum_type}" + message=f"{type(self.value)} must be of type: {self.enum_type}", ) return SyftSuccess(message=f"{type(self)} valid") @staticmethod def from_obj( - linked_obj: LinkedObject, attr_name: str, value: Enum | None = None + linked_obj: LinkedObject, attr_name: str, value: Enum | None = None, ) -> "EnumMutation": enum_type = type_for_field(linked_obj.object_type, attr_name) return EnumMutation( @@ -1087,7 +1080,7 @@ def from_obj( ) def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: try: valid = self.valid @@ -1185,13 +1178,12 @@ def __repr_syft_nested__(self) -> str: msg += "to permission RequestStatus.APPROVED." if self.code.nested_codes is None or self.code.nested_codes == {}: # type: ignore msg += " No nested requests" + elif self.nested_solved: + # else: + msg += "

This change requests the following nested function calls:
" + msg += self.nested_repr() else: - if self.nested_solved: - # else: - msg += "

This change requests the following nested function calls:
" - msg += self.nested_repr() - else: - msg += " Nested Requests not resolved" + msg += " Nested Requests not resolved" return msg def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: @@ -1223,7 +1215,7 @@ def valid(self) -> SyftSuccess | SyftError: if self.match_type and not isinstance(self.value, UserCodeStatus): # TODO: fix the mypy issue return SyftError( # type: ignore[unreachable] - message=f"{type(self.value)} must be of type: {UserCodeStatus}" + message=f"{type(self.value)} must be of type: {UserCodeStatus}", ) return SyftSuccess(message=f"{type(self)} valid") @@ -1254,7 +1246,7 @@ def mutate( return res def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: try: valid = self.valid @@ -1318,16 +1310,16 @@ def mutate( undo: bool, ) -> UserCodeStatusCollection | SyftError: return SyftError( - message="Synced UserCodes status is computed, and cannot be updated manually." + message="Synced UserCodes status is computed, and cannot be updated manually.", ) def _run( - self, context: ChangeContext, apply: bool + self, context: ChangeContext, apply: bool, ) -> Result[SyftSuccess, SyftError]: return Ok( SyftError( - message="Synced UserCodes status is computed, and cannot be updated manually." - ) + message="Synced UserCodes status is computed, and cannot be updated manually.", + ), ) def link(self) -> Any: # type: ignore diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index db152f98597..605efc27216 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -9,29 +9,30 @@ from ...types.uid import UID from ...util.telemetry import instrument from ..context import AuthedServiceContext -from ..notification.email_templates import RequestEmailTemplate -from ..notification.email_templates import RequestUpdateEmailTemplate -from ..notification.notification_service import CreateNotification -from ..notification.notification_service import NotificationService +from ..notification.email_templates import ( + RequestEmailTemplate, + RequestUpdateEmailTemplate, +) +from ..notification.notification_service import CreateNotification, NotificationService from ..notification.notifications import Notification from ..notifier.notifier_enums import NOTIFIERS -from ..response import SyftError -from ..response import SyftSuccess -from ..service import AbstractService -from ..service import SERVICE_TO_TYPES -from ..service import TYPE_TO_SERVICE -from ..service import service_method +from ..response import SyftError, SyftSuccess +from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method from ..user.user import UserView -from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from ..user.user_roles import GUEST_ROLE_LEVEL +from ..user.user_roles import ( + ADMIN_ROLE_LEVEL, + DATA_SCIENTIST_ROLE_LEVEL, + GUEST_ROLE_LEVEL, +) from ..user.user_service import UserService -from .request import Change -from .request import Request -from .request import RequestInfo -from .request import RequestInfoFilter -from .request import RequestStatus -from .request import SubmitRequest +from .request import ( + Change, + Request, + RequestInfo, + RequestInfoFilter, + RequestStatus, + SubmitRequest, +) from .request_stash import RequestStash @@ -62,7 +63,7 @@ def submit( link = 
diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py
index db152f98597..605efc27216 100644
--- a/packages/syft/src/syft/service/request/request_service.py
+++ b/packages/syft/src/syft/service/request/request_service.py
@@ -9,29 +9,30 @@
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from ..context import AuthedServiceContext
-from ..notification.email_templates import RequestEmailTemplate
-from ..notification.email_templates import RequestUpdateEmailTemplate
-from ..notification.notification_service import CreateNotification
-from ..notification.notification_service import NotificationService
+from ..notification.email_templates import (
+    RequestEmailTemplate,
+    RequestUpdateEmailTemplate,
+)
+from ..notification.notification_service import CreateNotification, NotificationService
 from ..notification.notifications import Notification
 from ..notifier.notifier_enums import NOTIFIERS
-from ..response import SyftError
-from ..response import SyftSuccess
-from ..service import AbstractService
-from ..service import SERVICE_TO_TYPES
-from ..service import TYPE_TO_SERVICE
-from ..service import service_method
+from ..response import SyftError, SyftSuccess
+from ..service import SERVICE_TO_TYPES, TYPE_TO_SERVICE, AbstractService, service_method
 from ..user.user import UserView
-from ..user.user_roles import ADMIN_ROLE_LEVEL
-from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
-from ..user.user_roles import GUEST_ROLE_LEVEL
+from ..user.user_roles import (
+    ADMIN_ROLE_LEVEL,
+    DATA_SCIENTIST_ROLE_LEVEL,
+    GUEST_ROLE_LEVEL,
+)
 from ..user.user_service import UserService
-from .request import Change
-from .request import Request
-from .request import RequestInfo
-from .request import RequestInfoFilter
-from .request import RequestStatus
-from .request import SubmitRequest
+from .request import (
+    Change,
+    Request,
+    RequestInfo,
+    RequestInfoFilter,
+    RequestStatus,
+    SubmitRequest,
+)
 from .request_stash import RequestStash

@@ -62,7 +63,7 @@ def submit(
         link = LinkedObject.with_context(request, context=context)
         admin_verify_key = context.server.get_service_method(
-            UserService.admin_verify_key
+            UserService.admin_verify_key,
         )
         root_verify_key = admin_verify_key()
@@ -83,7 +84,7 @@ def submit(
                     return request
                 else:
                     return SyftError(
-                        message=f"Failed to send notification: {result.err()}"
+                        message=f"Failed to send notification: {result.err()}",
                     )
             return request
@@ -96,10 +97,10 @@ def submit(
             raise e

     @service_method(
-        path="request.get_by_uid", name="get_by_uid", roles=DATA_SCIENTIST_ROLE_LEVEL
+        path="request.get_by_uid", name="get_by_uid", roles=DATA_SCIENTIST_ROLE_LEVEL,
     )
     def get_by_uid(
-        self, context: AuthedServiceContext, uid: UID
+        self, context: AuthedServiceContext, uid: UID,
     ) -> Request | None | SyftError:
         result = self.stash.get_by_uid(context.credentials, uid)
         if result.is_err():
@@ -107,7 +108,7 @@ def get_by_uid(
         return result.ok()

     @service_method(
-        path="request.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL
+        path="request.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL,
     )
     def get_all(self, context: AuthedServiceContext) -> list[Request] | SyftError:
         result = self.stash.get_all(context.credentials)
@@ -125,14 +126,13 @@ def get_all_info(
         page_size: int | None = 0,
     ) -> list[list[RequestInfo]] | list[RequestInfo] | SyftError:
         """Get the information of all requests"""
-
         result = self.stash.get_all(context.credentials)
         if result.is_err():
             return SyftError(message=result.err())

         method = context.server.get_service_method(UserService.get_by_verify_key)
         get_message = context.server.get_service_method(
-            NotificationService.filter_by_obj
+            NotificationService.filter_by_obj,
         )

         requests: list[RequestInfo] = []
@@ -154,13 +154,13 @@ def get_all_info(

     @service_method(path="request.add_changes", name="add_changes")
     def add_changes(
-        self, context: AuthedServiceContext, uid: UID, changes: list[Change]
+        self, context: AuthedServiceContext, uid: UID, changes: list[Change],
     ) -> Request | SyftError:
         result = self.stash.get_by_uid(credentials=context.credentials, uid=uid)

         if result.is_err():
             return SyftError(
-                message=f"Failed to retrieve request with uid: {uid}. Error: {result.err()}"
+                message=f"Failed to retrieve request with uid: {uid}. Error: {result.err()}",
             )

         request = result.ok()
@@ -178,7 +178,7 @@ def filter_all_info(
         """Get a Dataset"""
         result = self.get_all_info(context)
         requests = list(
-            filter(lambda res: (request_filter.name in res.user.name), result)
+            filter(lambda res: (request_filter.name in res.user.name), result),
         )

         # If chunk size is defined, then split list into evenly sized chunks
@@ -210,17 +210,17 @@ def apply(
             result = request.apply(context=context)

             filter_by_obj = context.server.get_service_method(
-                NotificationService.filter_by_obj
+                NotificationService.filter_by_obj,
             )
             request_notification = filter_by_obj(context=context, obj_uid=uid)

             link = LinkedObject.with_context(request, context=context)
-            if not request.get_status(context) == RequestStatus.PENDING:
+            if request.get_status(context) != RequestStatus.PENDING:
                 if request_notification is not None and not isinstance(
-                    request_notification, SyftError
+                    request_notification, SyftError,
                 ):
                     mark_as_read = context.server.get_service_method(
-                        NotificationService.mark_as_read
+                        NotificationService.mark_as_read,
                     )
                     mark_as_read(context=context, uid=request_notification.id)
@@ -233,7 +233,7 @@ def apply(
                         email_template=RequestUpdateEmailTemplate,
                     )
                     send_notification = context.server.get_service_method(
-                        NotificationService.send
+                        NotificationService.send,
                     )
                     send_notification(context=context, notification=notification)
@@ -245,12 +245,12 @@ def apply(

     @service_method(path="request.undo", name="undo")
     def undo(
-        self, context: AuthedServiceContext, uid: UID, reason: str
+        self, context: AuthedServiceContext, uid: UID, reason: str,
     ) -> SyftSuccess | SyftError:
         result = self.stash.get_by_uid(credentials=context.credentials, uid=uid)
         if result.is_err():
             return SyftError(
-                message=f"Failed to update request: {uid} with error: {result.err()}"
+                message=f"Failed to update request: {uid} with error: {result.err()}",
             )

         request = result.ok()
@@ -262,7 +262,7 @@ def undo(

         if result.is_err():
             return SyftError(
-                message=f"Failed to undo Request: <{uid}> with error: {result.err()}"
+                message=f"Failed to undo Request: <{uid}> with error: {result.err()}",
             )

         link = LinkedObject.with_context(request, context=context)
@@ -283,13 +283,13 @@ def undo(
         return SyftSuccess(message=f"Request {uid} successfully denied !")

     def save(
-        self, context: AuthedServiceContext, request: Request
+        self, context: AuthedServiceContext, request: Request,
     ) -> Request | SyftError:
         result = self.stash.update(context.credentials, request)
         if result.is_ok():
             return result.ok()
         return SyftError(
-            message=f"Failed to update Request: <{request.id}>. Error: {result.err()}"
+            message=f"Failed to update Request: <{request.id}>. Error: {result.err()}",
         )

     @service_method(
@@ -297,7 +297,7 @@ def save(
         name="delete_by_uid",
     )
     def delete_by_uid(
-        self, context: AuthedServiceContext, uid: UID
+        self, context: AuthedServiceContext, uid: UID,
     ) -> SyftSuccess | SyftError:
         """Delete the request with the given uid."""
         result = self.stash.delete_by_uid(context.credentials, uid)
@@ -328,7 +328,7 @@ def set_tags(

     @service_method(path="request.get_by_usercode_id", name="get_by_usercode_id")
     def get_by_usercode_id(
-        self, context: AuthedServiceContext, usercode_id: UID
+        self, context: AuthedServiceContext, usercode_id: UID,
     ) -> list[Request] | SyftError:
         result = self.stash.get_by_usercode_id(context.credentials, usercode_id)
         if result.is_err():
diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py
index b56bd6932e1..0d64c20943f 100644
--- a/packages/syft/src/syft/service/request/request_stash.py
+++ b/packages/syft/src/syft/service/request/request_stash.py
@@ -1,23 +1,24 @@
 # stdlib

 # third party
-from result import Ok
-from result import Result
+from result import Ok, Result

 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
+from ...store.document_store import (
+    BaseUIDStoreStash,
+    PartitionKey,
+    PartitionSettings,
+    QueryKeys,
+)
 from ...types.datetime import DateTime
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from .request import Request

 RequestingUserVerifyKeyPartitionKey = PartitionKey(
-    key="requesting_user_verify_key", type_=SyftVerifyKey
+    key="requesting_user_verify_key", type_=SyftVerifyKey,
 )
 OrderByRequestTimeStampPartitionKey = PartitionKey(key="request_time", type_=DateTime)

@@ -28,7 +29,7 @@ class RequestStash(BaseUIDStoreStash):
     object_type = Request
     settings: PartitionSettings = PartitionSettings(
-        name=Request.__canonical_name__, object_type=Request
+        name=Request.__canonical_name__, object_type=Request,
     )

     def get_all_for_verify_key(
@@ -46,7 +47,7 @@ def get_all_for_verify_key(
         )

     def get_by_usercode_id(
-        self, credentials: SyftVerifyKey, user_code_id: UID
+        self, credentials: SyftVerifyKey, user_code_id: UID,
     ) -> Result[list[Request], str]:
         query = self.get_all(credentials=credentials)
         if query.is_err():
diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py
index ebecf9e2fcb..2ec9ba4552c 100644
--- a/packages/syft/src/syft/service/response.py
+++ b/packages/syft/src/syft/service/response.py
@@ -38,13 +38,13 @@ def __getattr__(self, name: str) -> Any:
             return super().__getattr__(name)
         display(self)
         raise AttributeError(
-            f"You have tried accessing `{name}` on a {type(self).__name__} with message: {self.message}"
+            f"You have tried accessing `{name}` on a {type(self).__name__} with message: {self.message}",
         )

     def __bool__(self) -> bool:
         return self._bool

-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, SyftResponseMessage):
             return (
                 self.message == other.message
@@ -155,10 +155,10 @@ def format_traceback(etype: Any, evalue: Any, tb: Any, tb_offset: Any) -> str:

 def syft_exception_handler(
-    shell: Any, etype: Any, evalue: Any, tb: Any, tb_offset: Any = None
+    shell: Any, etype: Any, evalue: Any, tb: Any, tb_offset: Any = None,
 ) -> None:
     template = evalue.format_traceback(
-        etype=etype, evalue=evalue, tb=tb, tb_offset=tb_offset
+        etype=etype, evalue=evalue, tb=tb, tb_offset=tb_offset,
     )
     sys.stderr.write(template)
@@ -167,7 +167,7 @@ def syft_exception_handler(
     # third party
     from IPython import get_ipython

-    get_ipython().set_custom_exc((SyftException,), syft_exception_handler)  # noqa: F821
+    get_ipython().set_custom_exc((SyftException,), syft_exception_handler)
 except Exception:
     pass  # nosec
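Beyond the commas, the response.py hunk retypes `__eq__`: annotating `other` as `object` instead of `Any` matches the signature of `object.__eq__`, so mypy can actually check the body instead of silencing it. A sketch with a hypothetical class, returning `NotImplemented` for unrelated types as is conventional:

class Message:
    def __init__(self, text: str) -> None:
        self.text = text

    def __eq__(self, other: object) -> bool:
        # Narrow from `object` before touching attributes; defer to
        # the other operand when the types are unrelated.
        if not isinstance(other, Message):
            return NotImplemented
        return self.text == other.text

assert Message("hi") == Message("hi")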
diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py
index 76a61689eaa..cbdf311647b 100644
--- a/packages/syft/src/syft/service/service.py
+++ b/packages/syft/src/syft/service/service.py
@@ -1,49 +1,47 @@
 # future
 from __future__ import annotations

+import functools
+import inspect
+import logging
+
 # stdlib
 from collections import defaultdict
 from collections.abc import Callable
 from copy import deepcopy
-import functools
 from functools import partial
-import inspect
 from inspect import Parameter
-import logging
-from typing import Any
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 # third party
-from result import Ok
-from result import OkErr
+from result import Ok, OkErr
 from typing_extensions import Self

 # relative
 from ..abstract_server import AbstractServer
 from ..protocol.data_protocol import migrate_args_and_kwargs
-from ..serde.lib_permissions import CMPCRUDPermission
-from ..serde.lib_permissions import CMPPermission
-from ..serde.lib_service_registry import CMPBase
-from ..serde.lib_service_registry import CMPClass
-from ..serde.lib_service_registry import CMPFunction
-from ..serde.lib_service_registry import action_execute_registry_libs
+from ..serde.lib_permissions import CMPCRUDPermission, CMPPermission
+from ..serde.lib_service_registry import (
+    CMPBase,
+    CMPClass,
+    CMPFunction,
+    action_execute_registry_libs,
+)
 from ..serde.serializable import serializable
-from ..serde.signature import Signature
-from ..serde.signature import signature_remove_context
-from ..serde.signature import signature_remove_self
+from ..serde.signature import Signature, signature_remove_context, signature_remove_self
 from ..server.credentials import SyftVerifyKey
 from ..store.document_store import DocumentStore
 from ..store.linked_obj import LinkedObject
-from ..types.syft_object import SYFT_OBJECT_VERSION_1
-from ..types.syft_object import SyftBaseObject
-from ..types.syft_object import SyftObject
-from ..types.syft_object import attach_attribute_to_syft_object
+from ..types.syft_object import (
+    SYFT_OBJECT_VERSION_1,
+    SyftBaseObject,
+    SyftObject,
+    attach_attribute_to_syft_object,
+)
 from ..types.uid import UID
-from .context import AuthedServiceContext
-from .context import ChangeContext
+from .context import AuthedServiceContext, ChangeContext
 from .response import SyftError
-from .user.user_roles import DATA_OWNER_ROLE_LEVEL
-from .user.user_roles import ServiceRole
+from .user.user_roles import DATA_OWNER_ROLE_LEVEL, ServiceRole
 from .warnings import APIEndpointWarning

 logger = logging.getLogger(__name__)
@@ -175,7 +173,7 @@ def from_user(cls, credentials: SyftVerifyKey) -> Self:
                 k: lib_config
                 for k, lib_config in LibConfigRegistry.get_registered_configs().items()
                 if lib_config.has_permission(credentials)
-            }
+            },
         )

     def __contains__(self, path: str) -> bool:
@@ -201,7 +199,7 @@ def from_role(cls, user_service_role: ServiceRole) -> Self:
                 k: service_config
                 for k, service_config in ServiceConfigRegistry.get_registered_configs().items()
                 if service_config.has_permission(user_service_role)
-            }
+            },
         )

     def __contains__(self, path: str) -> bool:
@@ -250,7 +248,7 @@ def deconstruct_param(param: inspect.Parameter) -> dict[str, Any]:
     param_type = param.annotation
     if not hasattr(param_type, "__signature__"):
         raise Exception(
-            f"Type {param_type} needs __signature__. Or code changed to support backup init"
+            f"Type {param_type} needs __signature__. Or code changed to support backup init",
        )
     signature = param_type.__signature__
     sub_mapping = {}
@@ -326,10 +324,7 @@ def expand_signature(signature: Signature, autosplat: list[str]) -> Signature:
         )

     return Signature(
-        **{
-            "parameters": new_params,
-            "return_annotation": signature.return_annotation,
-        }
+        parameters=new_params, return_annotation=signature.return_annotation,
     )

@@ -360,7 +355,7 @@ def _decorator(self: Any, *args: Any, **kwargs: Any) -> Callable:

             if communication_protocol:
                 args, kwargs = migrate_args_and_kwargs(
-                    args=args, kwargs=kwargs, to_latest_protocol=True
+                    args=args, kwargs=kwargs, to_latest_protocol=True,
                 )
             if autosplat is not None and len(autosplat) > 0:
                 args, kwargs = reconstruct_args_kwargs(
@@ -439,7 +434,7 @@ def add_transform(

     @classmethod
     def get_transform(
-        cls, type_from: type[SyftObject], type_to: type[SyftObject]
+        cls, type_from: type[SyftObject], type_to: type[SyftObject],
     ) -> Callable:
         klass_from = type_from.__canonical_name__
         version_from = type_from.__version__
@@ -485,11 +480,11 @@ def from_api_or_context(
     if func_or_path not in user_config_registry:
         if ServiceConfigRegistry.path_exists(func_or_path):
             return SyftError(
-                message=f"As a `{server_context.role}` you have has no access to: {func_or_path}"
+                message=f"As a `{server_context.role}` you have has no access to: {func_or_path}",
             )
         else:
             return SyftError(
-                message=f"API call not in registered services: {func_or_path}"
+                message=f"API call not in registered services: {func_or_path}",
             )

     _private_api_path = user_config_registry.private_path_for(func_or_path)
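The `expand_signature` hunk above also applies what appears to be flake8-pie's PIE804 fix: a `**{...}` literal whose keys are all valid identifiers becomes plain keyword arguments. A minimal sketch with a hypothetical function:

def connect(host: str, port: int) -> str:
    return f"{host}:{port}"

# Before: connect(**{"host": "localhost", "port": 8080})
addr = connect(host="localhost", port=8080)  # equivalent, but type-checkable
print(addr)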
diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py
index 67720658c80..7332986da30 100644
--- a/packages/syft/src/syft/service/settings/settings.py
+++ b/packages/syft/src/syft/service/settings/settings.py
@@ -1,32 +1,30 @@
 # stdlib
-from collections.abc import Callable
 import logging
+from collections.abc import Callable
 from typing import Any

 # third party
-from pydantic import field_validator
-from pydantic import model_validator
+from pydantic import field_validator, model_validator
 from typing_extensions import Self

 # relative
-from ...abstract_server import ServerSideType
-from ...abstract_server import ServerType
+from ...abstract_server import ServerSideType, ServerType
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
 from ...service.worker.utils import DEFAULT_WORKER_POOL_NAME
 from ...types.syft_migration import migrate
-from ...types.syft_object import PartialSyftObject
-from ...types.syft_object import SYFT_OBJECT_VERSION_1
-from ...types.syft_object import SYFT_OBJECT_VERSION_2
-from ...types.syft_object import SYFT_OBJECT_VERSION_3
-from ...types.syft_object import SyftObject
-from ...types.transforms import drop
-from ...types.transforms import make_set_default
+from ...types.syft_object import (
+    SYFT_OBJECT_VERSION_1,
+    SYFT_OBJECT_VERSION_2,
+    SYFT_OBJECT_VERSION_3,
+    PartialSyftObject,
+    SyftObject,
+)
+from ...types.transforms import drop, make_set_default
 from ...types.uid import UID
 from ...util import options
 from ...util.colors import SURFACE
-from ...util.misc_objs import HTMLObject
-from ...util.misc_objs import MarkdownDescription
+from ...util.misc_objs import HTMLObject, MarkdownDescription
 from ...util.schema import DEFAULT_WELCOME_MSG

 logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ class PwdTokenResetConfig(SyftObject):
     def validate_char_types(self) -> Self:
         if not self.ascii and not self.numbers:
             raise ValueError(
-                "Invalid config, at least one of the ascii/number options must be true."
+                "Invalid config, at least one of the ascii/number options must be true.",
             )

         return self
@@ -138,7 +136,7 @@ class ServerSettingsV1(SyftObject):
     eager_execution_enabled: bool = False
     default_worker_pool: str = DEFAULT_WORKER_POOL_NAME
     welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject(
-        text=DEFAULT_WELCOME_MSG
+        text=DEFAULT_WELCOME_MSG,
     )

@@ -171,7 +169,7 @@ class ServerSettingsV2(SyftObject):
     eager_execution_enabled: bool = False
     default_worker_pool: str = DEFAULT_WORKER_POOL_NAME
     welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject(
-        text=DEFAULT_WELCOME_MSG
+        text=DEFAULT_WELCOME_MSG,
     )
     notifications_enabled: bool
@@ -205,7 +203,7 @@ class ServerSettings(SyftObject):
     eager_execution_enabled: bool = False
     default_worker_pool: str = DEFAULT_WORKER_POOL_NAME
     welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject(
-        text=DEFAULT_WELCOME_MSG
+        text=DEFAULT_WELCOME_MSG,
     )
     notifications_enabled: bool
     pwd_token_config: PwdTokenResetConfig = PwdTokenResetConfig()
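settings.py additionally shows isort's ordering within a section: plain `import x` statements sort ahead of `from x import y` statements, which is why `import logging` moves above the `collections.abc` import. A small sketch of the resulting layout:

# Before:
#   from collections.abc import Callable
#   import logging
# After (plain imports first within the stdlib section):
import logging
from collections.abc import Callable

logger = logging.getLogger(__name__)
on_event: Callable[[str], None] = logger.info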
diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py
index b54abacd078..2f3d4d4f947 100644
--- a/packages/syft/src/syft/service/settings/settings_service.py
+++ b/packages/syft/src/syft/service/settings/settings_service.py
@@ -2,9 +2,7 @@
 from string import Template

 # third party
-from result import Err
-from result import Ok
-from result import Result
+from result import Err, Ok, Result

 # relative
 from ...abstract_server import ServerSideType
@@ -12,25 +10,16 @@
 from ...store.document_store import DocumentStore
 from ...util.assets import load_png_base64
 from ...util.experimental_flags import flags
-from ...util.misc_objs import HTMLObject
-from ...util.misc_objs import MarkdownDescription
+from ...util.misc_objs import HTMLObject, MarkdownDescription
 from ...util.notebook_ui.styles import FONT_CSS
-from ...util.schema import DO_COMMANDS
-from ...util.schema import DS_COMMANDS
-from ...util.schema import GUEST_COMMANDS
-from ..context import AuthedServiceContext
-from ..context import UnauthedServiceContext
+from ...util.schema import DO_COMMANDS, DS_COMMANDS, GUEST_COMMANDS
+from ..context import AuthedServiceContext, UnauthedServiceContext
 from ..notifier.notifier_enums import EMAIL_TYPES
-from ..response import SyftError
-from ..response import SyftSuccess
-from ..service import AbstractService
-from ..service import service_method
-from ..user.user_roles import ADMIN_ROLE_LEVEL
-from ..user.user_roles import GUEST_ROLE_LEVEL
-from ..user.user_roles import ServiceRole
+from ..response import SyftError, SyftSuccess
+from ..service import AbstractService, service_method
+from ..user.user_roles import ADMIN_ROLE_LEVEL, GUEST_ROLE_LEVEL, ServiceRole
 from ..warnings import HighSideCRUDWarning
-from .settings import ServerSettings
-from .settings import ServerSettingsUpdate
+from .settings import ServerSettings, ServerSettingsUpdate
 from .settings_stash import SettingsStash

@@ -46,7 +35,6 @@ def __init__(self, store: DocumentStore) -> None:
     @service_method(path="settings.get", name="get")
     def get(self, context: UnauthedServiceContext) -> Result[Ok, Err]:
         """Get Settings"""
-
         result = self.stash.get_all(context.server.signing_key.verify_key)
         if result.is_ok():
             settings = result.ok()
@@ -60,7 +48,7 @@ def get(self, context: UnauthedServiceContext) -> Result[Ok, Err]:

     @service_method(path="settings.set", name="set")
     def set(
-        self, context: AuthedServiceContext, settings: ServerSettings
+        self, context: AuthedServiceContext, settings: ServerSettings,
     ) -> Result[Ok, Err]:
         """Set a new the Server Settings"""
         result = self.stash.set(context.credentials, settings)
@@ -76,7 +64,7 @@ def set(
         roles=ADMIN_ROLE_LEVEL,
     )
     def update(
-        self, context: AuthedServiceContext, settings: ServerSettingsUpdate
+        self, context: AuthedServiceContext, settings: ServerSettingsUpdate,
     ) -> Result[SyftSuccess, SyftError]:
         res = self._update(context, settings)
         if res.is_ok():
@@ -84,18 +72,18 @@ def update(
                 message=(
                     "Settings updated successfully. "
                     + "You must call .refresh() to sync your client with the changes."
-                )
+                ),
             )
         else:
             return SyftError(message=res.err())

     def _update(
-        self, context: AuthedServiceContext, settings: ServerSettingsUpdate
+        self, context: AuthedServiceContext, settings: ServerSettingsUpdate,
     ) -> Result[Ok, Err]:
-        """
-        Update the Server Settings using the provided values.
+        """Update the Server Settings using the provided values.

         Args:
+        ----
             name: Optional[str]
                 Server name
             organization: Optional[str]
@@ -111,18 +99,21 @@ def _update(
             association_request_auto_approval: Optional[bool]

         Returns:
+        -------
             Result[SyftSuccess, SyftError]: A result indicating the success or failure of the update operation.

         Example:
+        -------
         >>> server_client.update(name='foo', organization='bar', description='baz', signup_enabled=True)
         SyftSuccess: Settings updated successfully.
+
         """
         result = self.stash.get_all(context.credentials)
         if result.is_ok():
             current_settings = result.ok()
             if len(current_settings) > 0:
                 new_settings = current_settings[0].model_copy(
-                    update=settings.to_dict(exclude_empty=True)
+                    update=settings.to_dict(exclude_empty=True),
                 )

                 notifier_service = context.server.get_service("notifierservice")
@@ -130,14 +121,14 @@ def _update(
                 if settings.notifications_enabled is True:
                     if not notifier_service.settings(context):
                         return SyftError(
-                            message="Create notification settings using enable_notifications from user_service"
+                            message="Create notification settings using enable_notifications from user_service",
                         )
                     notifier_service = context.server.get_service("notifierservice")
                     result = notifier_service.set_notifier_active_to_true(context)
                 elif settings.notifications_enabled is False:
                     if not notifier_service.settings(context):
                         return SyftError(
-                            message="Create notification settings using enable_notifications from user_service"
+                            message="Create notification settings using enable_notifications from user_service",
                         )
                     notifier_service = context.server.get_service("notifierservice")
                     result = notifier_service.set_notifier_active_to_false(context)
@@ -154,12 +145,12 @@ def _update(
         roles=ADMIN_ROLE_LEVEL,
     )
     def set_server_side_type_dangerous(
-        self, context: AuthedServiceContext, server_side_type: str
+        self, context: AuthedServiceContext, server_side_type: str,
     ) -> Result[SyftSuccess, SyftError]:
         side_type_options = [e.value for e in ServerSideType]
         if server_side_type not in side_type_options:
             return SyftError(
-                message=f"Not a valid server_side_type, please use one of the options from: {side_type_options}"
+                message=f"Not a valid server_side_type, please use one of the options from: {side_type_options}",
             )

         result = self.stash.get_all(context.credentials)
@@ -174,7 +165,7 @@ def set_server_side_type_dangerous(
                 message=(
                     "Settings updated successfully. "
                     + "You must call .refresh() to sync your client with the changes."
-                )
+                ),
             )
         else:
             return SyftError(message=update_result.err())
@@ -225,7 +216,7 @@ def disable_notifications(
         warning=HighSideCRUDWarning(confirmation=True),
     )
     def allow_guest_signup(
-        self, context: AuthedServiceContext, enable: bool
+        self, context: AuthedServiceContext, enable: bool,
     ) -> SyftSuccess | SyftError:
         """Enable/Disable Registration for Data Scientist or Guest Users."""
         flags.CAN_REGISTER = enable
@@ -247,7 +238,7 @@ def allow_guest_signup(
     #     warning=HighSideCRUDWarning(confirmation=True),
     # )
     def enable_eager_execution(
-        self, context: AuthedServiceContext, enable: bool
+        self, context: AuthedServiceContext, enable: bool,
     ) -> SyftSuccess | SyftError:
         """Enable/Disable eager execution."""
         settings = ServerSettingsUpdate(eager_execution_enabled=enable)
@@ -262,7 +253,7 @@ def enable_eager_execution(

     @service_method(path="settings.set_email_rate_limit", name="set_email_rate_limit")
     def set_email_rate_limit(
-        self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int
+        self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int,
     ) -> SyftSuccess | SyftError:
         notifier_service = context.server.get_service("notifierservice")
         return notifier_service.set_email_rate_limit(context, email_type, daily_limit)
@@ -272,7 +263,7 @@ def set_email_rate_limit(
         name="allow_association_request_auto_approval",
     )
     def allow_association_request_auto_approval(
-        self, context: AuthedServiceContext, enable: bool
+        self, context: AuthedServiceContext, enable: bool,
     ) -> SyftSuccess | SyftError:
         new_settings = ServerSettingsUpdate(association_request_auto_approval=enable)
         result = self._update(context, settings=new_settings)
@@ -281,7 +272,7 @@ def allow_association_request_auto_approval(
         message = "enabled" if enable else "disabled"

         return SyftSuccess(
-            message="Association request auto-approval successfully " + message
+            message="Association request auto-approval successfully " + message,
         )

     @service_method(
@@ -296,7 +287,7 @@ def welcome_preview(
     ) -> MarkdownDescription | HTMLObject | SyftError:
         if not markdown and not html or markdown and html:
             return SyftError(
-                message="Invalid markdown/html fields. You must set one of them."
+                message="Invalid markdown/html fields. You must set one of them.",
             )

         welcome_msg = None
@@ -319,7 +310,7 @@ def welcome_customize(
     ) -> SyftSuccess | SyftError:
         if not markdown and not html or markdown and html:
             return SyftError(
-                message="Invalid markdown/html fields. You must set one of them."
+                message="Invalid markdown/html fields. You must set one of them.",
             )

         welcome_msg = None
diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py
index 52c134274f7..443a18f9b51 100644
--- a/packages/syft/src/syft/service/settings/settings_stash.py
+++ b/packages/syft/src/syft/service/settings/settings_stash.py
@@ -6,10 +6,12 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
-from ...store.document_store import DocumentStore
-from ...store.document_store import PartitionKey
-from ...store.document_store import PartitionSettings
+from ...store.document_store import (
+    BaseUIDStoreStash,
+    DocumentStore,
+    PartitionKey,
+    PartitionSettings,
+)
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from ..action.action_permissions import ActionObjectPermission
@@ -24,7 +26,7 @@ class SettingsStash(BaseUIDStoreStash):
     object_type = ServerSettings
     settings: PartitionSettings = PartitionSettings(
-        name=ServerSettings.__canonical_name__, object_type=ServerSettings
+        name=ServerSettings.__canonical_name__, object_type=ServerSettings,
     )

     def __init__(self, store: DocumentStore) -> None:
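The `_update` docstring rewrite above follows pydocstyle's NumPy-style rules: the summary joins the opening quotes, section headers gain dashed underlines, and a blank line precedes the closing quotes. A sketch of the resulting shape on a hypothetical function:

def update_name(name: str | None = None) -> bool:
    """Update the server name if one is provided.

    Args:
    ----
        name: optional new server name.

    Returns:
    -------
        True when a name was given and applied.

    """
    return name is not None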
diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py
index 9465f6f10f9..d393e4b4dc4 100644
--- a/packages/syft/src/syft/service/sync/diff_state.py
+++ b/packages/syft/src/syft/service/sync/diff_state.py
@@ -1,23 +1,17 @@
 # stdlib
-from collections.abc import Callable
-from collections.abc import Collection
-from collections.abc import Iterable
-from dataclasses import dataclass
 import enum
 import html
 import logging
 import operator
 import textwrap
-from typing import Any
-from typing import ClassVar
-from typing import Literal
-from typing import TYPE_CHECKING
+from collections.abc import Callable, Collection, Iterable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

 # third party
 import pandas as pd
 from rich import box
-from rich.console import Console
-from rich.console import Group
+from rich.console import Console, Group
 from rich.markdown import Markdown
 from rich.padding import Padding
 from rich.panel import Panel
@@ -26,37 +20,30 @@
 # relative
 from ...client.api import APIRegistry
 from ...client.client import SyftClient
-from ...client.sync_decision import SyncDecision
-from ...client.sync_decision import SyncDirection
+from ...client.sync_decision import SyncDecision, SyncDirection
 from ...server.credentials import SyftVerifyKey
 from ...types.datetime import DateTime
-from ...types.syft_object import SYFT_OBJECT_VERSION_1
-from ...types.syft_object import SyftObject
-from ...types.syft_object import short_uid
+from ...types.syft_object import SYFT_OBJECT_VERSION_1, SyftObject, short_uid
 from ...types.syncable_object import SyncableSyftObject
-from ...types.uid import LineageID
-from ...types.uid import UID
+from ...types.uid import UID, LineageID
 from ...util import options
 from ...util.colors import SURFACE
-from ...util.notebook_ui.components.sync import Label
-from ...util.notebook_ui.components.sync import SyncTableObject
+from ...util.notebook_ui.components.sync import Label, SyncTableObject
 from ...util.notebook_ui.icons import Icon
-from ...util.notebook_ui.styles import FONT_CSS
-from ...util.notebook_ui.styles import ITABLES_CSS
+from ...util.notebook_ui.styles import FONT_CSS, ITABLES_CSS
 from ..action.action_object import ActionObject
-from ..action.action_permissions import ActionObjectPermission
-from ..action.action_permissions import ActionPermission
-from ..action.action_permissions import StoragePermission
+from ..action.action_permissions import (
+    ActionObjectPermission,
+    ActionPermission,
+    StoragePermission,
+)
 from ..api.api import TwinAPIEndpoint
-from ..code.user_code import UserCode
-from ..code.user_code import UserCodeStatusCollection
-from ..job.job_stash import Job
-from ..job.job_stash import JobType
+from ..code.user_code import UserCode, UserCodeStatusCollection
+from ..job.job_stash import Job, JobType
 from ..log.log import SyftLog
 from ..output.output_service import ExecutionOutput
 from ..request.request import Request
-from ..response import SyftError
-from ..response import SyftSuccess
+from ..response import SyftError, SyftSuccess
 from ..user.user import UserView
 from .sync_state import SyncState

@@ -64,8 +51,7 @@

 if TYPE_CHECKING:
     # relative
-    from .resolve_widget import PaginatedResolveWidget
-    from .resolve_widget import ResolveWidget
+    from .resolve_widget import PaginatedResolveWidget, ResolveWidget

 sketchy_tab = "‎ " * 4
@@ -163,7 +149,7 @@ def recursive_attr_repr(value_attr: list | dict | bytes, num_tabs: int = 0) -> s
     elif isinstance(value_attr, dict):
         dict_repr = "{\n"
         for key, elem in value_attr.items():
-            dict_repr += f"{sketchy_tab * new_num_tabs}{key}: {str(elem)}\n"
+            dict_repr += f"{sketchy_tab * new_num_tabs}{key}: {elem!s}\n"
         dict_repr += "}"
         return dict_repr

@@ -175,7 +161,7 @@ def recursive_attr_repr(value_attr: list | dict | bytes, num_tabs: int = 0) -> s
     if isinstance(value_attr, UID):
         value_attr = short_uid(value_attr)  # type: ignore
-    return f"{sketchy_tab*num_tabs}{str(value_attr)}"
+    return f"{sketchy_tab*num_tabs}{value_attr!s}"


 class ObjectDiff(SyftObject):  # StateTuple (compare 2 objects)
@@ -348,7 +334,7 @@ def repr_attr_diffstatus_dict(self) -> dict:
             if value_low is None or value_high is None:
                 res[attr] = DiffStatus.NEW
             elif isinstance(value_low, pd.DataFrame) and isinstance(
-                value_high, pd.DataFrame
+                value_high, pd.DataFrame,
             ):
                 res[attr] = (
                     DiffStatus.MODIFIED
@@ -516,7 +502,7 @@ def _repr_html_(self) -> str:
         return base_str + attr_text

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}[{self.obj_type.__name__}](#{str(self.object_id)})"
+        return f"{self.__class__.__name__}[{self.obj_type.__name__}](#{self.object_id!s})"


 def _wrap_text(text: str, width: int, indent: int = 4) -> str:
@@ -530,11 +516,11 @@ def _wrap_text(text: str, width: int, indent: int = 4) -> str:
                 break_long_words=False,
                 replace_whitespace=False,
                 subsequent_indent=" " * indent,
-            )
+            ),
         )
         for line in text.splitlines()
         if line.strip() != ""
-    ]
+    ],
 )

@@ -689,7 +675,7 @@ def is_unchanged(self) -> bool:
         return self.status == "SAME"

     def get_dependents(
-        self, include_roots: bool = False, include_batch_root: bool = True
+        self, include_roots: bool = False, include_batch_root: bool = True,
     ) -> list[ObjectDiff]:
         return self.walk_graph(
             deps=self.dependents,
@@ -747,14 +733,13 @@ def is_skipped(self) -> bool:
     def create_new_resolved_states(
         self,
     ) -> tuple["ResolvedSyncState", "ResolvedSyncState"]:
-        """
-        Returns new ResolvedSyncState objects for the source and target servers
+        """Returns new ResolvedSyncState objects for the source and target servers
         """
         resolved_state_low = ResolvedSyncState(
-            server_uid=self.low_server_uid, alias="low"
+            server_uid=self.low_server_uid, alias="low",
         )
         resolved_state_high = ResolvedSyncState(
-            server_uid=self.high_server_uid, alias="high"
+            server_uid=self.high_server_uid, alias="high",
         )

         # Return source, target
@@ -778,7 +763,7 @@ def from_dependencies(
         sync_direction: SyncDirection,
     ) -> "ObjectDiffBatch":
         def _build_hierarchy_helper(
-            uid: UID, level: int = 0, visited: set | None = None
+            uid: UID, level: int = 0, visited: set | None = None,
         ) -> list:
             visited = visited if visited is not None else set()

@@ -804,7 +789,7 @@ def _build_hierarchy_helper(
                         uid=dep_uid,
                         level=level + 1,
                         visited=visited | set(deps) - {dep_uid},
-                    )
+                    ),
                 )
             return result

@@ -847,8 +832,8 @@ def _repr_html_(self) -> str:
         except Exception as _:
             return SyftError(
                 message=html.escape(
-                    "Could not render batch, please use resolve() instead."
-                )
+                    "Could not render batch, please use resolve() instead.",
+                ),
             )._repr_html_()

         return f"""
@@ -925,15 +910,15 @@ def __repr__(self) -> Any:
         except Exception as _:
             return SyftError(
                 message=html.escape(
-                    "Could not render batch, please use resolve() instead."
-                )
+                    "Could not render batch, please use resolve() instead.",
+                ),
             )._repr_html_()

     def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
         return ""  # Turns off the _repr_markdown_ of SyftObject

     def _get_visual_hierarchy(
-        self, server: ObjectDiff, visited: set[UID] | None = None
+        self, server: ObjectDiff, visited: set[UID] | None = None,
     ) -> dict[ObjectDiff, dict]:
         visited = visited if visited is not None else set()
         visited.add(server.object_id)
@@ -941,7 +926,7 @@ def _get_visual_hierarchy(
         _, child_types_map = self.visual_hierarchy
         child_types = child_types_map.get(server.obj_type, [])
         dep_ids = self.dependencies.get(server.object_id, []) + self.dependents.get(
-            server.object_id, []
+            server.object_id, [],
         )

         result = {}
@@ -975,7 +960,7 @@ def visual_root(self) -> ObjectDiff:

     @property
     def user_code_high(self) -> UserCode | None:
-        """return the user code of the high side of this batch, if it exists"""
+        """Return the user code of the high side of this batch, if it exists"""
         user_code_diff = self.user_code_diff
         if user_code_diff is not None and isinstance(user_code_diff.high_obj, UserCode):
             return user_code_diff.high_obj
@@ -983,7 +968,7 @@ def user_code_high(self) -> UserCode | None:

     @property
     def user_code_diff(self) -> ObjectDiff | None:
-        """return the main user code diff of the high side of this batch, if it exists"""
+        """Return the main user code diff of the high side of this batch, if it exists"""
         user_code_diffs: list[ObjectDiff] = [
             diff
             for diff in self.get_dependencies(include_roots=True)
@@ -1091,8 +1076,7 @@ def from_batch(self, batch: ObjectDiffBatch) -> Any:

 @dataclass
 class ServerDiffFilter:
-    """
-    Filter to apply to a ServerDiff object to determine if it should be included in a batch.
+    """Filter to apply to a ServerDiff object to determine if it should be included in a batch.

     Checks for `property op value` , where
         property: FilterProperty - property to filter on
@@ -1144,7 +1128,7 @@ class ServerDiff(SyftObject):
     include_ignored: bool = False

     def resolve(
-        self, build_state: bool = True
+        self, build_state: bool = True,
     ) -> "PaginatedResolveWidget | SyftSuccess":
         if len(self.batches) == 0:
             return SyftSuccess(message="No batches to resolve")
@@ -1177,7 +1161,7 @@ def ignored_changes(self) -> list[IgnoredBatchView]:
         for ignored_batch in self.ignored_batches:
             other_batches = [b for b in self.all_batches if b is not ignored_batch]
             result.append(
-                IgnoredBatchView(batch=ignored_batch, other_batches=other_batches)
+                IgnoredBatchView(batch=ignored_batch, other_batches=other_batches),
             )
         return result

@@ -1273,14 +1257,12 @@ def from_sync_state(

     @staticmethod
     def apply_previous_ignore_state(
-        batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int]
+        batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int],
     ) -> None:
-        """
-        Loop through all ignored batches in syncstate. If batch did not change, set to ignored
+        """Loop through all ignored batches in syncstate. If batch did not change, set to ignored
         If another batch needs to exist in order to accept that changed batch: also unignore
         e.g. if a job changed, also unignore the usercode
         """
-
         for root_id, batch_hash in previously_ignored_batches.items():
             for batch in batches:
                 if batch.root_id == root_id:
@@ -1289,7 +1271,7 @@ def apply_previous_ignore_state(
                     else:
                         logger.debug(
                             f"""A batch with type {batch.root_type.__name__} was previously ignored but has changed
-It will be available for review again."""
+It will be available for review again.""",
                         )
                         # batch has changed, so unignore
                         batch.decision = None
@@ -1309,11 +1291,11 @@ def apply_previous_ignore_state(

     @staticmethod
     def dependencies_from_states(
-        low_state: SyncState, high_state: SyncState
+        low_state: SyncState, high_state: SyncState,
     ) -> dict[UID, list[UID]]:
         dependencies = {}
         all_parents = set(low_state.dependencies.keys()) | set(
-            high_state.dependencies.keys()
+            high_state.dependencies.keys(),
         )
         for parent in all_parents:
             low_deps = low_state.dependencies.get(parent, [])
@@ -1387,7 +1369,7 @@ def _sort_batches(hierarchies: list[ObjectDiffBatch]) -> list[ObjectDiffBatch]:
             key=lambda x: (
                 hierarchy_order.index(x.root.obj_type),
                 x.root.object_id,
-            )
+            ),
         )

     # sorted = sorted groups + without_usercode
@@ -1456,10 +1438,9 @@ def is_same(self) -> bool:
         return all(object_diff.status == "SAME" for object_diff in self.diffs)

     def _apply_filters(
-        self, filters: list[ServerDiffFilter], inplace: bool = True
+        self, filters: list[ServerDiffFilter], inplace: bool = True,
     ) -> Self:
-        """
-        Apply filters to the ServerDiff object and return a new ServerDiff object
+        """Apply filters to the ServerDiff object and return a new ServerDiff object
         """
         batches = self.all_batches
         for filter in filters:
@@ -1497,29 +1478,29 @@ def _filter(
         new_filters = []
         if user_email is not None:
             new_filters.append(
-                ServerDiffFilter(FilterProperty.USER, user_email, operator.eq)
+                ServerDiffFilter(FilterProperty.USER, user_email, operator.eq),
             )
         if not include_ignored:
             new_filters.append(
-                ServerDiffFilter(FilterProperty.IGNORED, True, operator.ne)
+                ServerDiffFilter(FilterProperty.IGNORED, True, operator.ne),
             )
         if not include_same:
             new_filters.append(
-                ServerDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)
+                ServerDiffFilter(FilterProperty.STATUS, "SAME", operator.ne),
             )
         if include_types is not None:
             include_types_ = {
                 t.__name__ if isinstance(t, type) else t for t in include_types
             }
             new_filters.append(
-                ServerDiffFilter(FilterProperty.TYPE, include_types_, operator.contains)
+                ServerDiffFilter(FilterProperty.TYPE, include_types_, operator.contains),
             )
         if exclude_types:
             for exclude_type in exclude_types:
                 if isinstance(exclude_type, type):
                     exclude_type = exclude_type.__name__
                 new_filters.append(
-                    ServerDiffFilter(FilterProperty.TYPE, exclude_type, operator.ne)
+                    ServerDiffFilter(FilterProperty.TYPE, exclude_type, operator.ne),
                 )

         return self._apply_filters(new_filters, inplace=inplace)
@@ -1562,7 +1543,7 @@ def from_batch_decision(
                     and diff.object_type != "ExecutionOutput"
                 ):
                     raise ValueError(
-                        "share_to_user is required to share private data"
+                        "share_to_user is required to share private data",
                     )
                 else:
                     new_permissions_low_side = [
@@ -1570,7 +1551,7 @@ def from_batch_decision(
                             uid=diff.object_id,
                             permission=ActionPermission.READ,
                             credentials=share_to_user,
-                        )
+                        ),
                     ]

         # storage permissions
@@ -1581,12 +1562,12 @@ def from_batch_decision(
             if not mockify:
                 new_storage_permissions.append(
                     StoragePermission(
-                        uid=diff.object_id, server_uid=diff.low_server_uid
-                    )
+                        uid=diff.object_id, server_uid=diff.low_server_uid,
+                    ),
                 )
         elif sync_direction == SyncDirection.LOW_TO_HIGH:
             new_storage_permissions.append(
-                StoragePermission(uid=diff.object_id, server_uid=diff.high_server_uid)
+                StoragePermission(uid=diff.object_id, server_uid=diff.high_server_uid),
             )

         return cls(
@@ -1618,7 +1599,7 @@ def from_client(cls, client: SyftClient) -> "ResolvedSyncState":
         alias: str = client.metadata.server_side_type  # type: ignore
         if alias not in ["low", "high"]:
             raise ValueError(
-                "can only create resolved sync state for high, low side deployments"
+                "can only create resolved sync state for high, low side deployments",
             )
         return cls(server_uid=client.id, alias=alias)

@@ -1670,11 +1651,11 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
         if self.alias == "low":
             self.new_permissions.extend(sync_instruction.new_permissions_lowside)
             self.new_storage_permissions.extend(
-                sync_instruction.new_storage_permissions_lowside
+                sync_instruction.new_storage_permissions_lowside,
             )
         elif self.alias == "high":
             self.new_storage_permissions.extend(
-                sync_instruction.new_storage_permissions_highside
+                sync_instruction.new_storage_permissions_highside,
             )
         else:
             raise ValueError("Invalid alias")
@@ -1722,7 +1703,7 @@ def display_diff_hierarchy(diff_hierarchy: list[tuple[ObjectDiff, int]]) -> None
         low_side_panel.title = "Low side"
         low_side_panel.title_align = "left"
         high_side_panel = display_diff_object(
-            diff.high_state if diff.high_obj is not None else None
+            diff.high_state if diff.high_obj is not None else None,
         )
         high_side_panel.title = "High side"
         high_side_panel.title_align = "left"
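Two more one-line simplifications recur in the diff_state.py hunks: RUF010 rewrites f"{str(x)}" as f"{x!s}", and flake8-simplify's SIM201 turns `not a == b` into `a != b` (the same fix seen in request_service.py earlier). Illustrative values only:

status = "PENDING"
label = f"status={status!s}"  # formerly f"status={str(status)}"
if status != "DONE":  # formerly: if not status == "DONE":
    print(label)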
diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py
index 39d0ff08b7b..cb68e51b629 100644
--- a/packages/syft/src/syft/service/sync/resolve_widget.py
+++ b/packages/syft/src/syft/service/sync/resolve_widget.py
@@ -1,42 +1,39 @@
 # stdlib
-from collections.abc import Callable
-from enum import Enum
-from enum import auto
-from functools import partial
 import html
 import secrets
+from collections.abc import Callable
+from enum import Enum, auto
+from functools import partial
 from typing import Any
 from uuid import uuid4

+import ipywidgets as widgets
+
 # third party
 from IPython import display
-import ipywidgets as widgets
-from ipywidgets import Button
-from ipywidgets import Checkbox
-from ipywidgets import HBox
-from ipywidgets import HTML
-from ipywidgets import Layout
-from ipywidgets import VBox
+from ipywidgets import HTML, Button, Checkbox, HBox, Layout, VBox

 # relative
 from ...client.sync_decision import SyncDirection
 from ...types.uid import UID
-from ...util.notebook_ui.components.sync import Alert
-from ...util.notebook_ui.components.sync import CopyIDButton
-from ...util.notebook_ui.components.sync import MainDescription
-from ...util.notebook_ui.components.sync import SyncWidgetHeader
-from ...util.notebook_ui.components.sync import TypeLabel
-from ...util.notebook_ui.components.tabulator_template import build_tabulator_table
-from ...util.notebook_ui.components.tabulator_template import highlight_single_row
-from ...util.notebook_ui.components.tabulator_template import update_table_cell
+from ...util.notebook_ui.components.sync import (
+    Alert,
+    CopyIDButton,
+    MainDescription,
+    SyncWidgetHeader,
+    TypeLabel,
+)
+from ...util.notebook_ui.components.tabulator_template import (
+    build_tabulator_table,
+    highlight_single_row,
+    update_table_cell,
+)
 from ...util.notebook_ui.styles import CSS_CODE
 from ..action.action_object import ActionObject
 from ..api.api import TwinAPIEndpoint
 from ..log.log import SyftLog
-from ..response import SyftError
-from ..response import SyftSuccess
-from .diff_state import ObjectDiff
-from .diff_state import ObjectDiffBatch
+from ..response import SyftError, SyftSuccess
+from .diff_state import ObjectDiff, ObjectDiffBatch

 # Standard div Jupyter Lab uses for notebook outputs
 # This is needed to use alert styles from SyftSuccess and SyftError
@@ -175,17 +172,17 @@ def build(self) -> widgets.HBox:
             target_side = "Low side"

         html_from = create_diff_html(
-            f"From {source_side} (new values)", from_properties, self.statuses
+            f"From {source_side} (new values)", from_properties, self.statuses,
         )
         html_to = create_diff_html(
-            f"To {target_side} (old values)", to_properties, self.statuses
+            f"To {target_side} (old values)", to_properties, self.statuses,
         )

         widget_from = widgets.HTML(
-            value=html_from, layout=widgets.Layout(width="50%", overflow="auto")
+            value=html_from, layout=widgets.Layout(width="50%", overflow="auto"),
         )
         widget_to = widgets.HTML(
-            value=html_to, layout=widgets.Layout(width="50%", overflow="auto")
+            value=html_to, layout=widgets.Layout(width="50%", overflow="auto"),
         )

         css_accordion = """