Skip to content

Commit

Permalink
Add Mock Model Execution
Browse files Browse the repository at this point in the history
  • Loading branch information
rasswanth-s committed Aug 3, 2024
1 parent e9596c5 commit 51fc9d6
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0a0ddd55-f403-44aa-ae82-da02458a3ef1",
"id": "833a3601-f4b9-4063-a7db-eaa438d668b9",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -691,13 +691,13 @@
"metadata": {},
"outputs": [],
"source": [
"# Check result of execution on mock data\n",
"# TODO: Re-enable mock flow\n",
"# mock_result = compute_census_matches(\n",
"# canada_census_data=canada_census_data.mock,\n",
"# italy_census_data=italy_census_data.mock,\n",
"# )\n",
"# mock_result"
"# Mock Model Flow\n",
"mock_result = run_inference(\n",
" model=gpt2_model.mock,\n",
" evals=gpt2_gender_bias_evals_asset.mock,\n",
" syft_no_server=True,\n",
")\n",
"mock_result"
]
},
{
Expand Down Expand Up @@ -941,6 +941,14 @@
"if canada_enclave.deployment_type.value == \"python\":\n",
" canada_enclave.land()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4353bd6-0a69-4b3b-b686-b915a07b9027",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
15 changes: 15 additions & 0 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,21 @@ def get_mock(
return result.ok()
return SyftError(message=result.err())

# TODO: fix this Tech Debt, currently , we do not have a way to add
# ActionPermission.ALL_READ to the permissions
# Like we have for stashes (document store)
# This is a temporary fix to allow the user to get the model code
@service_method(
path="action.get_model_code", name="get_model_code", roles=GUEST_ROLE_LEVEL
)
def get_model_code(
self, context: AuthedServiceContext, uid: UID
) -> Result[SyftError, SyftObject]:
result = self.store.get_model_code(uid=uid)
if result.is_ok():
return result.ok()
return SyftError(message=result.err())

@service_method(
path="action.has_storage_permission",
name="has_storage_permission",
Expand Down
14 changes: 14 additions & 0 deletions packages/syft/src/syft/service/action/action_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ def get_mock(self, uid: UID) -> Result[SyftObject, str]:
except Exception as e:
return Err(f"Could not find item with uid {uid}, {e}")

def get_model_code(self, uid: UID) -> Result[SyftObject, str]:
# relative
from ..model.model import SubmitModelCode

uid = uid.id # We only need the UID from LineageID or UID

try:
syft_object = self.data[uid]
if isinstance(syft_object, SubmitModelCode):
return Ok(syft_object)
return Err("No SubmitModelCode in Store")
except Exception as e:
return Err(f"Could not find item with uid {uid}, {e}")

def get_pointer(
self,
uid: UID,
Expand Down
31 changes: 30 additions & 1 deletion packages/syft/src/syft/service/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ def data(self) -> Any:
display(warning)
return None

@property
def mock(self) -> SyftError | Any:
# relative
from ...client.api import APIRegistry

api = APIRegistry.api_for(
server_uid=self.syft_server_location,
user_verify_key=self.syft_client_verify_key,
)
if api is None:
raise ValueError(f"api is None. You must login to {self.syft_server_uid}")
result = api.services.action.get_mock(self.action_id)
if isinstance(result, SyftError):
return result
try:
if isinstance(result, SyftObject):
return result.syft_action_data
return result
except Exception as e:
return SyftError(message=f"Failed to get mock. {e}")

# def __call__(self, *args, **kwargs) -> Any:
# endpoint = self.endpoint
# result = endpoint.__call__(*args, **kwargs)
Expand Down Expand Up @@ -319,7 +340,7 @@ def model_code(self) -> SubmitModelCode | None:
)
if api is None or api.services is None:
return None
res = api.services.action.get(self.code_action_id)
res = api.services.action.get_model_code(self.code_action_id)
if has_permission(res):
return res
else:
Expand All @@ -329,6 +350,14 @@ def model_code(self) -> SubmitModelCode | None:
display(warning)
return None

@property
def mock(self) -> SyftModelClass:
model_code = self.model_code
if model_code is None:
raise ValueError("[Model.mock] Cannot access model code")
mock_assets = [asset.mock for asset in self.asset_list]
return model_code(assets=mock_assets)

def _coll_repr_(self) -> dict[str, Any]:
return {
"Name": self.name,
Expand Down

0 comments on commit 51fc9d6

Please sign in to comment.