Skip to content

Commit

Permalink
test: add tests for export model feature (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurélien Gasser authored Apr 13, 2021
1 parent 9af72bd commit 7839815
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 25 deletions.
1 change: 1 addition & 0 deletions charts/substra-tests/templates/configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ data:
values.yaml: |
options:
enable_intermediate_model_removal: False
enable_model_download: True
nodes:
- name: 'node-1'
msp_id: 'MyOrg1MSP'
Expand Down
1 change: 1 addition & 0 deletions local-backend-values.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
options:
enable_intermediate_model_removal: False
enable_model_download: True
nodes:
- name: 'local-backend'
msp_id: 'local-backend'
Expand Down
15 changes: 15 additions & 0 deletions substratest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ def download_opener(self, key):
with open(path, 'rb') as f:
return f.read()

def download_model(self, key):
with tempfile.TemporaryDirectory() as tmp:
self._client.download_model(key, tmp)
path = os.path.join(tmp, f'model_{key}')
with open(path, 'rb') as f:
return f.read()

def download_trunk_model_from_composite_traintuple(self, composite_traintuple_key):
with tempfile.TemporaryDirectory() as tmp:
self._client.download_trunk_model_from_composite_traintuple(composite_traintuple_key, tmp)
tuple = self.get_composite_traintuple(composite_traintuple_key)
path = os.path.join(tmp, f'model_{tuple.out_trunk_model.out_model.key}')
with open(path, 'rb') as f:
return f.read()

def describe_dataset(self, key):
return self._client.describe_dataset(key)

Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class NodeCfg:
@dataclasses.dataclass(frozen=True)
class Options:
enable_intermediate_model_removal: bool
enable_model_download: bool
minikube: bool = False


Expand Down
59 changes: 34 additions & 25 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.mark.slow
def test_tuples_execution_on_same_node(factory, client, default_dataset, default_objective):
def test_tuples_execution_on_same_node(factory, network, client, default_dataset, default_objective):
"""Execution of a traintuple, a following testtuple and a following traintuple."""

spec = factory.create_algo()
Expand All @@ -29,6 +29,9 @@ def test_tuples_execution_on_same_node(factory, client, default_dataset, default
assert traintuple.metadata == {"foo": "bar"}
assert traintuple.out_model is not None

if network.options.enable_model_download:
assert client.download_model(traintuple.out_model.key) == b'{"value": 2.2}'

# check we can add twice the same traintuple
client.add_traintuple(spec)

Expand Down Expand Up @@ -382,28 +385,34 @@ def test_aggregate_composite_traintuples(factory, network, clients, default_data
testtuple = clients[0].wait(testtuple)
assert testtuple.dataset.perf == 32

if not network.options.enable_intermediate_model_removal:
return

# Optional (if "enable_intermediate_model_removal" is True): ensure the aggregatetuple of round 1 has been deleted.
#
# We do this by creating a new traintuple that depends on the deleted aggregatatuple, and ensuring that starting
# the traintuple fails.
#
# Ideally it would be better to try to do a request "as a backend" to get the deleted model. This would be closer
# to what we want to test and would also check that this request is correctly handled when the model has been
# deleted. Here, we cannot know for sure the failure reason. Unfortunately this cannot be done now as the
# username/password are not available in the settings files.

client = clients[0]
dataset = default_datasets[0]
algo = client.add_algo(spec)
if network.options.enable_model_download:
# Optional (if "enable_model_download" is True): ensure we can export out-models.
#
# - One out-model download is not proxified (direct download)
# - One out-model download is proxified (as it belongs to another org)
for tuple in previous_composite_traintuples:
assert clients[0].download_trunk_model_from_composite_traintuple(tuple.key) == b'{"value": 2.8}'

if network.options.enable_intermediate_model_removal:
# Optional (if "enable_intermediate_model_removal" is True): ensure the aggregatetuple of round 1 has been deleted.
#
# We do this by creating a new traintuple that depends on the deleted aggregatatuple, and ensuring that starting
# the traintuple fails.
#
# Ideally it would be better to try to do a request "as a backend" to get the deleted model. This would be closer
# to what we want to test and would also check that this request is correctly handled when the model has been
# deleted. Here, we cannot know for sure the failure reason. Unfortunately this cannot be done now as the
# username/password are not available in the settings files.

client = clients[0]
dataset = default_datasets[0]
algo = client.add_algo(spec)

spec = factory.create_traintuple(
algo=algo,
dataset=dataset,
data_samples=dataset.train_data_sample_keys,
)
traintuple = client.add_traintuple(spec)
traintuple = client.wait(traintuple)
assert traintuple.status == Status.failed
spec = factory.create_traintuple(
algo=algo,
dataset=dataset,
data_samples=dataset.train_data_sample_keys,
)
traintuple = client.add_traintuple(spec)
traintuple = client.wait(traintuple)
assert traintuple.status == Status.failed
1 change: 1 addition & 0 deletions values.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
options:
enable_intermediate_model_removal: False
enable_model_download: True
# minikube: True
nodes:
- name: 'node-1'
Expand Down

0 comments on commit 7839815

Please sign in to comment.