Skip to content

Commit

Permalink
Merge pull request #9275 from OpenMined/remove_get_service_method
Browse files Browse the repository at this point in the history
Remove get service method and get service
  • Loading branch information
eelcovdw authored Sep 9, 2024
2 parents 56fd893 + 7b8bef8 commit 3e189db
Show file tree
Hide file tree
Showing 47 changed files with 282 additions and 517 deletions.
2 changes: 2 additions & 0 deletions packages/syft/src/syft/abstract_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

if TYPE_CHECKING:
# relative
from .server.service_registry import ServiceRegistry
from .service.service import AbstractService


Expand Down Expand Up @@ -39,6 +40,7 @@ class AbstractServer:
server_type: ServerType | None
server_side_type: ServerSideType | None
in_memory_workers: bool
services: "ServiceRegistry"

def get_service(self, path_or_func: str | Callable) -> "AbstractService":
raise NotImplementedError
5 changes: 3 additions & 2 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,9 @@ def register(self, new_user: UserCreate) -> SyftSigningKey | None:
)
else:
service_context = ServerServiceContext(server=self.server)
method = self.server.get_service_method(UserService.register)
response = method(context=service_context, new_user=new_user)
response = self.server.services.user.register(
context=service_context, new_user=new_user
)
response = post_process_result(response, unwrap_on_success=False)
return response

Expand Down
14 changes: 6 additions & 8 deletions packages/syft/src/syft/server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def _get_server_connection(peer_uid: UID) -> ServerConnection:
# relative
from ..service.network.server_peer import route_to_connection

network_service = worker.get_service("NetworkService")
peer = network_service.stash.get_by_uid(worker.verify_key, peer_uid).unwrap()
peer = worker.network.stash.get_by_uid(worker.verify_key, peer_uid).unwrap()
peer_server_route = peer.pick_highest_priority_route()
connection = route_to_connection(route=peer_server_route)
return connection
Expand Down Expand Up @@ -168,9 +167,8 @@ def syft_new_api_call(

def handle_forgot_password(email: str, server: AbstractServer) -> Response:
try:
method = server.get_service_method(UserService.forgot_password)
context = UnauthedServiceContext(server=server)
result = method(context=context, email=email)
result = server.services.user.forgot_password(context=context, email=email)
except SyftException as e:
result = SyftError.from_public_exception(e)

Expand All @@ -186,9 +184,10 @@ def handle_reset_password(
token: str, new_password: str, server: AbstractServer
) -> Response:
try:
method = server.get_service_method(UserService.reset_password)
context = UnauthedServiceContext(server=server)
result = method(context=context, token=token, new_password=new_password)
result = server.services.user.reset_password(
context=context, token=token, new_password=new_password
)
except SyftException as e:
result = SyftError.from_public_exception(e)

Expand All @@ -206,12 +205,11 @@ def handle_login(email: str, password: str, server: AbstractServer) -> Response:
except ValidationError as e:
return {"Error": e.json()}

method = server.get_service_method(UserService.exchange_credentials)
context = UnauthedServiceContext(
server=server, login_credentials=login_credentials
)
try:
result = method(context=context).value
result = server.services.user.exchange_credentials(context=context).value
if not isinstance(result, UserPrivateKey):
response = SyftError(message=f"Incorrect return type: {type(result)}")
else:
Expand Down
77 changes: 27 additions & 50 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
from ..service.action.action_store import DictActionStore
from ..service.action.action_store import MongoActionStore
from ..service.action.action_store import SQLiteActionStore
from ..service.blob_storage.service import BlobStorageService
from ..service.code.user_code_service import UserCodeService
from ..service.code.user_code_stash import UserCodeStash
from ..service.context import AuthedServiceContext
from ..service.context import ServerServiceContext
Expand All @@ -54,7 +52,6 @@
from ..service.job.job_stash import JobStatus
from ..service.job.job_stash import JobType
from ..service.metadata.server_metadata import ServerMetadata
from ..service.network.network_service import NetworkService
from ..service.network.utils import PeerHealthCheckTask
from ..service.notifier.notifier_service import NotifierService
from ..service.queue.base_queue import AbstractMessageHandler
Expand Down Expand Up @@ -82,12 +79,10 @@
from ..service.user.user import UserCreate
from ..service.user.user import UserView
from ..service.user.user_roles import ServiceRole
from ..service.user.user_service import UserService
from ..service.user.user_stash import UserStash
from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG
from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME
from ..service.worker.utils import create_default_image
from ..service.worker.worker_image_service import SyftWorkerImageService
from ..service.worker.worker_pool import WorkerPool
from ..service.worker.worker_pool_service import SyftWorkerPoolService
from ..service.worker.worker_pool_stash import SyftWorkerPoolStash
Expand Down Expand Up @@ -535,8 +530,7 @@ def init_blob_storage(self, config: BlobStorageConfig | None = None) -> None:
from ..store.blob_storage.seaweedfs import SeaweedFSConfig

if isinstance(config, SeaweedFSConfig) and self.signing_key:
blob_storage_service = self.get_service(BlobStorageService)
remote_profiles = blob_storage_service.remote_profile_stash.get_all(
remote_profiles = self.services.blob_storage.remote_profile_stash.get_all(
credentials=self.signing_key.verify_key, has_permission=True
).unwrap()
for remote_profile in remote_profiles:
Expand Down Expand Up @@ -825,8 +819,9 @@ def find_and_migrate_data(
credentials=self.verify_key,
role=ServiceRole.ADMIN,
)
migration_service = self.get_service("migrationservice")
return migration_service.migrate_data(context, document_store_object_types)
return self.services.migration.migrate_data(
context, document_store_object_types
)

@property
def guest_client(self) -> SyftClient:
Expand Down Expand Up @@ -878,11 +873,10 @@ def post_init(self) -> None:
)

if "usercodeservice" in self.service_path_map:
user_code_service = self.get_service(UserCodeService)
user_code_service.load_user_code(context=context)
self.services.user_code.load_user_code(context=context)

def reload_user_code() -> None:
user_code_service.load_user_code(context=context)
self.services.user_code.load_user_code(context=context)

ti = thread_ident()
if ti is not None:
Expand Down Expand Up @@ -939,11 +933,11 @@ def init_stores(

@property
def job_stash(self) -> JobStash:
return self.get_service("jobservice").stash
return self.services.job.stash

@property
def worker_stash(self) -> WorkerStash:
return self.get_service("workerservice").stash
return self.services.worker.stash

@property
def service_path_map(self) -> dict[str, AbstractService]:
Expand Down Expand Up @@ -1126,9 +1120,9 @@ def forward_message(
)

client = None

network_service = self.get_service(NetworkService)
peer = network_service.stash.get_by_uid(self.verify_key, server_uid).unwrap()
peer = self.services.network.stash.get_by_uid(
self.verify_key, server_uid
).unwrap()

# Since we have several routes to a peer
# we need to cache the client for a given server_uid along with the route
Expand Down Expand Up @@ -1172,11 +1166,9 @@ def forward_message(
raise SyftException(public_message=(f"Server has no route to {server_uid}"))

def get_role_for_credentials(self, credentials: SyftVerifyKey) -> ServiceRole:
return (
self.get_service("userservice")
.get_role_for_credentials(credentials=credentials)
.unwrap()
)
return self.services.user.get_role_for_credentials(
credentials=credentials
).unwrap()

@instrument
def handle_api_call(
Expand Down Expand Up @@ -1423,10 +1415,7 @@ def add_action_to_queue(
has_execute_permissions=has_execute_permissions,
worker_pool=worker_pool_ref, # set worker pool reference as part of queue item
)

user_service = self.get_service("UserService")
user_service = cast(UserService, user_service)
user_id = user_service.get_user_id_for_credentials(credentials).unwrap()
user_id = self.services.user.get_user_id_for_credentials(credentials).unwrap()

return self.add_queueitem_to_queue(
queue_item=queue_item,
Expand Down Expand Up @@ -1454,9 +1443,6 @@ def add_queueitem_to_queue(
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(server=self, credentials=credentials, role=role)

action_service = self.get_service("actionservice")
log_service = self.get_service("logservice")

result_obj = ActionObject.empty()
if action is not None:
result_obj = ActionObject.obj_not_ready(
Expand All @@ -1469,10 +1455,8 @@ def add_queueitem_to_queue(
result_obj.syft_server_location = self.id
result_obj.syft_client_verify_key = credentials

action_service = self.get_service("actionservice")

if not action_service.store.exists(uid=action.result_id):
action_service.set_result_to_store(
if not self.services.action.store.exists(uid=action.result_id):
self.services.action.set_result_to_store(
result_action_object=result_obj,
context=context,
).unwrap()
Expand All @@ -1495,7 +1479,7 @@ def add_queueitem_to_queue(
self.job_stash.set(credentials, job).unwrap()
self.queue_stash.set_placeholder(credentials, queue_item).unwrap()

log_service.add(context, log_id, queue_item.job_id)
self.services.log.add(context, log_id, queue_item.job_id)

return job

Expand All @@ -1519,8 +1503,7 @@ def _sort_jobs(self, jobs: list[Job]) -> list[Job]:
def _get_existing_user_code_jobs(
self, context: AuthedServiceContext, user_code_id: UID
) -> list[Job]:
job_service = self.get_service("jobservice")
jobs = job_service.get_by_user_code_id(
jobs = self.services.job.get_by_user_code_id(
context=context, user_code_id=user_code_id
)
return self._sort_jobs(jobs)
Expand All @@ -1533,8 +1516,7 @@ def _is_usercode_call_on_owned_kwargs(
) -> bool:
if api_call.path != "code.call":
return False
user_code_service = self.get_service("usercodeservice")
return user_code_service.is_execution_on_owned_args(
return self.services.user_code.is_execution_on_owned_args(
context, user_code_id, api_call.kwargs
)

Expand Down Expand Up @@ -1562,7 +1544,7 @@ def add_api_call_to_queue(
action = Action.from_api_call(unsigned_call)
user_code_id = action.user_code_id

user = self.get_service(UserService).get_current_user(context)
user = self.services.user.get_current_user(context)
user = cast(UserView, user)

is_execution_on_owned_kwargs_allowed = (
Expand Down Expand Up @@ -1637,11 +1619,11 @@ def add_api_call_to_queue(

@property
def pool_stash(self) -> SyftWorkerPoolStash:
return self.get_service(SyftWorkerPoolService).stash
return self.services.syft_worker_pool.stash

@property
def user_code_stash(self) -> UserCodeStash:
return self.get_service(UserCodeService).stash
return self.services.user_code.stash

@as_result(NotFoundException)
def get_default_worker_pool(self) -> WorkerPool | None:
Expand Down Expand Up @@ -1813,7 +1795,7 @@ def get_default_worker_tag_by_env(dev_mode: bool = False) -> str | None:
def create_default_worker_pool(server: Server) -> None:
credentials = server.verify_key
pull_image = not server.dev_mode
image_stash = server.get_service(SyftWorkerImageService).stash
image_stash = server.services.syft_worker_image.stash
default_pool_name = server.settings.default_worker_pool

try:
Expand Down Expand Up @@ -1845,9 +1827,8 @@ def create_default_worker_pool(server: Server) -> None:

if not default_image.is_built:
logger.info(f"Building default worker image with tag={default_worker_tag}. ")
image_build_method = server.get_service_method(SyftWorkerImageService.build)
# Build the Image for given tag
result = image_build_method(
result = server.services.worker_image.build(
context,
image_uid=default_image.id,
tag=DEFAULT_WORKER_IMAGE_TAG,
Expand All @@ -1864,8 +1845,7 @@ def create_default_worker_pool(server: Server) -> None:
)
if default_worker_pool is None:
worker_to_add_ = worker_count
create_pool_method = server.get_service_method(SyftWorkerPoolService.launch)
result = create_pool_method(
result = server.services.syft_worker_pool.launch(
context,
pool_name=default_pool_name,
image_uid=default_image.id,
Expand All @@ -1879,10 +1859,7 @@ def create_default_worker_pool(server: Server) -> None:
default_worker_pool.worker_list
)
if worker_to_add_ > 0:
add_worker_method = server.get_service_method(
SyftWorkerPoolService.add_workers
)
result = add_worker_method(
result = server.services.syft_worker_pool.add_workers(
context=context,
number=worker_to_add_,
pool_name=default_pool_name,
Expand Down
14 changes: 7 additions & 7 deletions packages/syft/src/syft/service/action/action_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ def __call_function(
self, call_mode: EXECUTION_MODE, *args: Any, **kwargs: Any
) -> Any:
self.context = self.__check_context()
endpoint_service = self.context.server.get_service("apiservice")

if call_mode == EXECUTION_MODE.MOCK:
__endpoint_mode = endpoint_service.execute_server_side_endpoint_mock_by_id
elif call_mode == EXECUTION_MODE.PRIVATE:
__endpoint_mode = (
endpoint_service.execute_service_side_endpoint_private_by_id
self.context.server.services.api.execute_server_side_endpoint_mock_by_id
)
elif call_mode == EXECUTION_MODE.PRIVATE:
__endpoint_mode = self.context.server.services.api.execute_service_side_endpoint_private_by_id
else:
__endpoint_mode = endpoint_service.execute_server_side_endpoint_by_id
__endpoint_mode = (
self.context.server.services.api.execute_server_side_endpoint_by_id
)

return __endpoint_mode(
return __endpoint_mode( # type: ignore[misc]
*args,
context=self.context,
endpoint_uid=self.endpoint_id,
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,9 @@ def get_sync_dependencies(
# relative
from ..job.job_stash import Job

job_service = context.server.get_service("jobservice") # type: ignore
job: Job | None = job_service.get_by_result_id(context, self.id.id) # type: ignore
job: Job | None = context.server.services.job.get_by_result_id(
context, self.id.id
) # type: ignore
if job is not None:
return [job.id]
else:
Expand Down
Loading

0 comments on commit 3e189db

Please sign in to comment.