Skip to content

Commit

Permalink
Merge pull request #8921 from khoaguin/more-datatypes-for-assets
Browse files Browse the repository at this point in the history
Support more data types for data uploading
  • Loading branch information
koenvanderveen authored Jun 20, 2024
2 parents df2859b + f9ec721 commit e043645
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
4 changes: 2 additions & 2 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
try:
contains_empty = asset.contains_empty()
twin = TwinObject(
private_obj=asset.data,
mock_obj=asset.mock,
private_obj=ActionObject.from_obj(asset.data),
mock_obj=ActionObject.from_obj(asset.mock),
syft_node_location=self.id,
syft_client_verify_key=self.verify_key,
)
Expand Down
48 changes: 48 additions & 0 deletions packages/syft/tests/syft/service/dataset/dataset_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

# third party
import numpy as np
import pandas as pd
from pydantic import ValidationError
import pytest
import torch

# syft absolute
import syft as sy
Expand Down Expand Up @@ -253,3 +255,49 @@ def test_adding_contributors_with_duplicate_email():
assert isinstance(res3, SyftSuccess)
assert isinstance(res4, SyftError)
assert len(asset.contributors) == 1


# One sample value per payload type exercised by the upload test:
# builtin scalars and containers plus numpy, pandas and torch objects.
_ASSET_PAYLOADS = (
    1,
    "hello",
    {"key": "value"},
    {1, 2, 3},
    np.array([1, 2, 3]),
    pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}),
    torch.Tensor([1, 2, 3]),
)


@pytest.fixture(params=_ASSET_PAYLOADS)
def different_data_types(
    request,
) -> int | str | dict | set | np.ndarray | pd.DataFrame | torch.Tensor:
    """Yield one sample payload per parametrized run.

    Each test requesting this fixture is executed once for every entry in
    the parameter list, receiving that entry unchanged.
    """
    payload = request.param
    return payload


def test_upload_dataset_with_assets_of_different_data_types(
    worker: Worker,
    different_data_types: int
    | str
    | dict
    | set
    | np.ndarray
    | pd.DataFrame
    | torch.Tensor,
) -> None:
    """Upload a dataset whose single asset holds the parametrized payload.

    The same value is used for both the asset's private data and its mock.
    The test checks that the upload succeeds, that exactly one dataset is
    then listed, and that the data and mock read back from the uploaded
    asset have the same type as the value that was supplied.
    """
    expected_type = type(different_data_types)

    asset = sy.Asset(
        name=random_hash(),
        data=different_data_types,
        mock=different_data_types,
    )
    dataset = Dataset(name=random_hash())
    dataset.add_asset(asset)

    root_domain_client = worker.root_client
    res = root_domain_client.upload_dataset(dataset)
    assert isinstance(res, SyftSuccess)
    assert len(root_domain_client.api.services.dataset.get_all()) == 1

    # Look the uploaded asset up once and reuse it for both checks.
    uploaded_asset = root_domain_client.datasets[0].assets[0]

    # Exact type match is intended, so compare type objects by identity
    # rather than with `==`.
    assert type(uploaded_asset.data) is expected_type
    assert type(uploaded_asset.mock) is expected_type

0 comments on commit e043645

Please sign in to comment.