Skip to content

Commit

Permalink
Merge branch 'dev' into test_upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
teo-milea authored Sep 23, 2024
2 parents 145c0dc + f5c0e8c commit 8e9f2ec
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 49 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pr-tests-stack.yml
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,6 @@ jobs:
tox -e migration.scenarios.test
pr-tests-migrations-k8s:
if: false # skipping this job for now
strategy:
max-parallel: 99
matrix:
Expand Down
10 changes: 8 additions & 2 deletions packages/syft/src/syft/client/datasite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,14 @@ def get_migration_data(self, include_blobs: bool = True) -> MigrationData:

return res

def load_migration_data(self, path: str | Path) -> SyftSuccess:
migration_data = MigrationData.from_file(path)
def load_migration_data(
self, path_or_data: str | Path | MigrationData
) -> SyftSuccess:
if isinstance(path_or_data, MigrationData):
migration_data = path_or_data
else:
migration_data = MigrationData.from_file(path_or_data)

migration_data._set_obj_location_(self.id, self.verify_key)

if self.id != migration_data.server_uid:
Expand Down
8 changes: 4 additions & 4 deletions packages/syft/src/syft/service/api/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def set(
public_message="An API endpoint already exists at the given path."
)

result = self.stash.upsert(context.credentials, endpoint=new_endpoint).unwrap()
result = self.stash.upsert(context.credentials, obj=new_endpoint).unwrap()
action_obj = ActionObject.from_obj(
id=new_endpoint.action_object_id,
syft_action_data=CustomEndpointActionObject(endpoint_id=result.id),
Expand Down Expand Up @@ -157,7 +157,7 @@ def update(
endpoint.mock_function.view_access = view_access

# save changes
self.stash.upsert(context.credentials, endpoint=endpoint).unwrap()
self.stash.upsert(context.credentials, obj=endpoint).unwrap()
return SyftSuccess(message="Endpoint successfully updated.")

@service_method(
Expand Down Expand Up @@ -218,7 +218,7 @@ def set_state(
if mock and api_endpoint.mock_function:
api_endpoint.mock_function.state = state

self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap()
self.stash.upsert(context.credentials, obj=api_endpoint).unwrap()
return SyftSuccess(message=f"APIEndpoint {api_path} state updated.")

@service_method(
Expand Down Expand Up @@ -248,7 +248,7 @@ def set_settings(
if mock and api_endpoint.mock_function:
api_endpoint.mock_function.settings = settings

self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap()
self.stash.upsert(context.credentials, obj=api_endpoint).unwrap()
return SyftSuccess(message=f"APIEndpoint {api_path} settings updated.")

@service_method(
Expand Down
19 changes: 0 additions & 19 deletions packages/syft/src/syft/service/api/api_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,3 @@ def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool:
return True
except NotFoundException:
return False

@as_result(StashException)
def upsert(
self,
credentials: SyftVerifyKey,
endpoint: TwinAPIEndpoint,
has_permission: bool = False,
) -> TwinAPIEndpoint:
"""Upsert an endpoint."""
exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap()

if exists:
super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap()

return (
super()
.set(credentials=credentials, obj=endpoint, ignore_duplicates=False)
.unwrap()
)
17 changes: 9 additions & 8 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,23 +642,24 @@ def allowed_ids_only(
public_message=f"Invalid server type for code submission: {context.server.server_type}"
)

server_identity = ServerIdentity(
server_name=context.server.name,
server_id=context.server.id,
verify_key=context.server.signing_key.verify_key,
)
allowed_inputs = allowed_inputs.get(server_identity, {})
allowed_inputs_for_server = None
for identity, inputs in allowed_inputs.items():
if identity.server_id == context.server.id:
allowed_inputs_for_server = inputs
break
if allowed_inputs_for_server is None:
allowed_inputs_for_server = {}

filtered_kwargs = {}
for key in allowed_inputs.keys():
for key in allowed_inputs_for_server.keys():
if key in kwargs:
value = kwargs[key]
uid = value

if not isinstance(uid, UID):
uid = getattr(value, "id", None)

if uid != allowed_inputs[key]:
if uid != allowed_inputs_for_server[key]:
raise SyftException(
public_message=f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
)
Expand Down
49 changes: 34 additions & 15 deletions packages/syft/src/syft/store/db/stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
"""
Decorator to inject a session into the function kwargs if it is not provided.
Make sure to pass session as a keyword argument to the function.
TODO: This decorator is a temporary fix, we want to move to a DI approach instead:
move db connection and session to context, and pass context to all stash methods.
"""
Expand All @@ -87,8 +89,9 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any:
if inject_session and kwargs.get("session") is None:
with self.sessionmaker() as session:
kwargs["session"] = session
return func(self, *args, **kwargs)
with session.begin():
kwargs["session"] = session
return func(self, *args, **kwargs)
return func(self, *args, **kwargs)

return wrapper # type: ignore
Expand Down Expand Up @@ -369,11 +372,13 @@ def set(
uid = obj.id

# check if the object already exists
if self.exists(credentials, uid) or not self.is_unique(obj):
if self.exists(credentials, uid, session=session) or not self.is_unique(
obj, session=session
):
if ignore_duplicates:
return obj
unique_fields_str = ", ".join(self.unique_fields)
raise StashException(
raise UniqueConstraintException(
public_message=f"Duplication Key Error for {obj}.\n"
f"The fields that should be unique are {unique_fields_str}."
)
Expand All @@ -399,7 +404,6 @@ def set(
raise StashException(
f"Error serializing object: {e}. Some fields are invalid."
)

# create the object with the permissions
stmt = self.table.insert().values(
id=uid,
Expand All @@ -408,7 +412,6 @@ def set(
storage_permissions=storage_permissions,
)
session.execute(stmt)
session.commit()
return self.get_by_uid(credentials, uid, session=session).unwrap()

@as_result(ValidationError, AttributeError)
Expand Down Expand Up @@ -462,7 +465,7 @@ def update(
).unwrap()

# TODO has_permission is not used
if not self.is_unique(obj):
if not self.is_unique(obj, session=session):
raise UniqueConstraintException(
f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}"
)
Expand All @@ -483,14 +486,12 @@ def update(
f"Error serializing object: {e}. Some fields are invalid."
)
stmt = stmt.values(fields=fields)

result = session.execute(stmt)
session.commit()
if result.rowcount == 0:
raise NotFoundException(
f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
)
return self.get_by_uid(credentials, obj.id).unwrap()
return self.get_by_uid(credentials, obj.id, session=session).unwrap()

@as_result(StashException, NotFoundException)
@with_session
Expand All @@ -510,7 +511,6 @@ def delete_by_uid(
session=session,
)
result = session.execute(stmt)
session.commit()
if result.rowcount == 0:
raise NotFoundException(
f"{self.object_type.__name__}: {uid} not found or no permission to delete."
Expand Down Expand Up @@ -649,8 +649,6 @@ def add_permission(
stmt = self.table.update().where(self.table.c.id == permission.uid)
stmt = stmt.values(permissions=list(existing_permissions))
session.execute(stmt)
session.commit()

return None

@as_result(NotFoundException)
Expand Down Expand Up @@ -685,7 +683,6 @@ def remove_permission(
.values(permissions=list(permissions))
)
session.execute(stmt)
session.commit()
return None

@with_session
Expand Down Expand Up @@ -842,7 +839,6 @@ def remove_storage_permission(
.values(storage_permissions=[str(uid) for uid in permissions])
)
session.execute(stmt)
session.commit()
return None

@as_result(StashException)
Expand All @@ -857,3 +853,26 @@ def _get_storage_permissions_for_uid(
if result is None:
raise NotFoundException(f"No storage permissions found for uid: {uid}")
return {UID(uid) for uid in result.storage_permissions}

@with_session
@as_result(StashException)
def upsert(
self,
credentials: SyftVerifyKey,
obj: StashT,
session: Session = None,
) -> StashT:
"""Insert or update an object in the stash if it already exists.
Atomic operation when using the same session for both operations.
"""

try:
return self.set(
credentials=credentials,
obj=obj,
session=session,
).unwrap()
except UniqueConstraintException:
return self.update(
credentials=credentials, obj=obj, session=session
).unwrap()
24 changes: 24 additions & 0 deletions packages/syft/tests/syft/stores/base_stash_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,30 @@ def test_basestash_update(
assert retrieved == updated_obj


def test_basestash_upsert(
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
) -> None:
base_stash.set(root_verify_key, mock_object).unwrap()

updated_obj = mock_object.copy()
updated_obj.name = faker.name()

retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap()
assert retrieved == updated_obj

updated_obj.id = UID()

with pytest.raises(StashException):
# fails because the name should be unique
base_stash.upsert(root_verify_key, updated_obj).unwrap()

updated_obj.name = faker.name()

retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap()
assert retrieved == updated_obj
assert len(base_stash.get_all(root_verify_key).unwrap()) == 2


def test_basestash_cannot_update_non_existent(
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
) -> None:
Expand Down

0 comments on commit 8e9f2ec

Please sign in to comment.