From 14a5d4166204e64d956be76356f2ba4525ab06e7 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 31 Jul 2024 17:49:58 -0400 Subject: [PATCH 1/2] avoid searching all datasets in asset property --- .../syft/src/syft/service/code/user_code.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 4b8c1c8e7ae..bd93aa1d04a 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -726,17 +726,6 @@ def assets(self) -> DictTuple[str, Asset] | SyftError: if isinstance(api, SyftError): return api - # get all assets on the server - datasets: list[Dataset] = api.services.dataset.get_all() - if isinstance(datasets, SyftError): - return datasets - - all_assets: dict[UID, Asset] = {} - for dataset in datasets: - for asset in dataset.asset_list: - asset._dataset_name = dataset.name - all_assets[asset.action_id] = asset - # get a flat dict of all inputs all_inputs = {} inputs = self.input_policy_init_kwargs or {} @@ -746,10 +735,14 @@ def assets(self) -> DictTuple[str, Asset] | SyftError: # map the action_id to the asset used_assets: list[Asset] = [] for kwarg_name, action_id in all_inputs.items(): - asset = all_assets.get(action_id, None) - if asset: - asset._kwarg_name = kwarg_name - used_assets.append(asset) + + assets = api.dataset.get_assets_by_action_id(uid=action_id) + if isinstance(assets, SyftError): + return assets + if assets: + asset = assets[0] + asset._kwarg_name = kwarg_name + used_assets.append(asset) asset_dict = {asset._kwarg_name: asset for asset in used_assets} return DictTuple(asset_dict) From 2a9d7c9c298d8504d01f1188851ff606d6634ed7 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 31 Jul 2024 17:57:00 -0400 Subject: [PATCH 2/2] formatting --- packages/syft/src/syft/service/code/user_code.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index bd93aa1d04a..1d4615e2e54 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -73,7 +73,6 @@ from ..action.action_object import ActionObject from ..context import AuthedServiceContext from ..dataset.dataset import Asset -from ..dataset.dataset import Dataset from ..job.job_stash import Job from ..output.output_service import ExecutionOutput from ..output.output_service import OutputService @@ -735,14 +734,13 @@ def assets(self) -> DictTuple[str, Asset] | SyftError: # map the action_id to the asset used_assets: list[Asset] = [] for kwarg_name, action_id in all_inputs.items(): - - assets = api.dataset.get_assets_by_action_id(uid=action_id) - if isinstance(assets, SyftError): - return assets - if assets: - asset = assets[0] - asset._kwarg_name = kwarg_name - used_assets.append(asset) + assets = api.dataset.get_assets_by_action_id(uid=action_id) + if isinstance(assets, SyftError): + return assets + if assets: + asset = assets[0] + asset._kwarg_name = kwarg_name + used_assets.append(asset) asset_dict = {asset._kwarg_name: asset for asset in used_assets} return DictTuple(asset_dict)