diff --git a/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb b/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb index 3a3431e2917..dd77e8d09f4 100644 --- a/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb +++ b/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb @@ -341,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a0ddd55-f403-44aa-ae82-da02458a3ef1", + "id": "833a3601-f4b9-4063-a7db-eaa438d668b9", "metadata": {}, "outputs": [], "source": [ @@ -691,13 +691,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Check result of execution on mock data\n", - "# TODO: Re-enable mock flow\n", - "# mock_result = compute_census_matches(\n", - "# canada_census_data=canada_census_data.mock,\n", - "# italy_census_data=italy_census_data.mock,\n", - "# )\n", - "# mock_result" + "# Mock Model Flow\n", + "mock_result = run_inference(\n", + " model=gpt2_model.mock,\n", + " evals=gpt2_gender_bias_evals_asset.mock,\n", + " syft_no_server=True,\n", + ")\n", + "mock_result" ] }, { @@ -941,6 +941,14 @@ "if canada_enclave.deployment_type.value == \"python\":\n", " canada_enclave.land()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4353bd6-0a69-4b3b-b686-b915a07b9027", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 94ebeaa7b4b..b5361968ad1 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -368,6 +368,21 @@ def get_mock( return result.ok() return SyftError(message=result.err()) + # TODO: fix this Tech Debt, currently , we do not have a way to add + # ActionPermission.ALL_READ to the permissions + # Like we have for stashes (document store) + # This is a temporary fix to allow the user to get the model code + @service_method( + path="action.get_model_code", name="get_model_code", roles=GUEST_ROLE_LEVEL + ) + def get_model_code( + self, context: AuthedServiceContext, uid: UID + ) -> Result[SyftError, SyftObject]: + result = self.store.get_model_code(uid=uid) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + @service_method( path="action.has_storage_permission", name="has_storage_permission", diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 250b3c5e9b5..951782d1779 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -113,6 +113,20 @@ def get_mock(self, uid: UID) -> Result[SyftObject, str]: except Exception as e: return Err(f"Could not find item with uid {uid}, {e}") + def get_model_code(self, uid: UID) -> Result[SyftObject, str]: + # relative + from ..model.model import SubmitModelCode + + uid = uid.id # We only need the UID from LineageID or UID + + try: + syft_object = self.data[uid] + if isinstance(syft_object, SubmitModelCode): + return Ok(syft_object) + return Err("No SubmitModelCode in Store") + except Exception as e: + return Err(f"Could not find item with uid {uid}, {e}") + def get_pointer( self, uid: UID, diff --git a/packages/syft/src/syft/service/model/model.py b/packages/syft/src/syft/service/model/model.py index c144b8bc3aa..b177052d813 100644 --- a/packages/syft/src/syft/service/model/model.py +++ b/packages/syft/src/syft/service/model/model.py @@ -132,6 +132,27 @@ def data(self) -> Any: display(warning) return None + @property + def mock(self) -> SyftError | Any: + # relative + from ...client.api import APIRegistry + + api = APIRegistry.api_for( + server_uid=self.syft_server_location, + user_verify_key=self.syft_client_verify_key, + ) + if api is None: + raise ValueError(f"api is None. You must login to {self.syft_server_uid}") + result = api.services.action.get_mock(self.action_id) + if isinstance(result, SyftError): + return result + try: + if isinstance(result, SyftObject): + return result.syft_action_data + return result + except Exception as e: + return SyftError(message=f"Failed to get mock. {e}") + # def __call__(self, *args, **kwargs) -> Any: # endpoint = self.endpoint # result = endpoint.__call__(*args, **kwargs) @@ -319,7 +340,7 @@ def model_code(self) -> SubmitModelCode | None: ) if api is None or api.services is None: return None - res = api.services.action.get(self.code_action_id) + res = api.services.action.get_model_code(self.code_action_id) if has_permission(res): return res else: @@ -329,6 +350,14 @@ def model_code(self) -> SubmitModelCode | None: display(warning) return None + @property + def mock(self) -> SyftModelClass: + model_code = self.model_code + if model_code is None: + raise ValueError("[Model.mock] Cannot access model code") + mock_assets = [asset.mock for asset in self.asset_list] + return model_code(assets=mock_assets) + def _coll_repr_(self) -> dict[str, Any]: return { "Name": self.name,