Skip to content

Commit

Permalink
v1; ++tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 19, 2024
1 parent a1e79d1 commit cbb9766
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 130 deletions.
262 changes: 147 additions & 115 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Callable, Dict, List, OrderedDict, Type, Union, Any
from typing import Any, Callable, Dict, List, OrderedDict, Type, Union

import numpy as np
import pytest
Expand All @@ -26,40 +25,45 @@
from .metricsrecord import MetricsRecord
from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
from .recordset_utils import (
_embed_status_into_recordset,
evaluate_ins_to_recordset,
evaluate_res_to_recordset,
fit_ins_to_recordset,
fit_res_to_recordset,
getparameters_ins_to_recordset,
getparameters_res_to_recordset,
getproperties_ins_to_recordset,
getproperties_res_to_recordset,
parameters_to_parametersrecord,
parametersrecord_to_parameters,
recordset_to_evaluate_ins,
recordset_to_evaluate_res,
recordset_to_fit_ins,
recordset_to_fit_res,
recordset_to_getparameters_ins,
recordset_to_getparameters_res,
recordset_to_getproperties_ins,
recordset_to_getproperties_res,
evaluate_res_to_recordset,
recordset_to_evaluate_res,
)
from .typing import (
Code,
Scalar,
ConfigsRecordValues,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
GetPropertiesIns,
EvaluateRes,
GetPropertiesRes,
MetricsRecordValues,
NDArray,
NDArrays,
Parameters,
Scalar,
Status,
)

from flwr.client.message_handler.message_handler_test import ClientWithProps, _get_client_fn


def get_ndarrays() -> NDArrays:
"""Return list of NumPy arrays."""
Expand Down Expand Up @@ -362,150 +366,178 @@ def test_set_configs_to_configsrecord_with_incorrect_types(
m_record.set_configs(my_metrics) # type: ignore


##################################################
# Testing conversion: *Ins --> RecordSet --> *Ins
# Testing conversion: *Res <-- RecordSet <-- *Res
##################################################

@pytest.mark.parametrize(
"context, config",
[
(nullcontext(), {'a': 1.0, 'b': 0}),
(pytest.raises(TypeError), {'a': 1.0, 'b': 3, 'c': True}), # fails due to unsupported type for configrecord value
],
)
def test_fitins_to_recordset_and_back(context: Any, config: Dict[str, Scalar]) -> None:

def test_fitins_to_recordset_and_back() -> None:
"""Test conversion FitIns --> RecordSet --> FitIns."""
arrays = get_ndarrays()
fitins = FitIns(parameters=ndarrays_to_parameters(arrays), config=config)
fitins = FitIns(
parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0}
)

fitins_copy = deepcopy(fitins)

with context:
recordset = fit_ins_to_recordset(fitins)

fitins_ = recordset_to_fit_ins(recordset, keep_input=False)
recordset = fit_ins_to_recordset(fitins, keep_input=False)

fitins_ = recordset_to_fit_ins(recordset, keep_input=False)

assert fitins_copy == fitins_


@pytest.mark.parametrize(
"context, metrics",
[
(nullcontext(), {"a": 1.0, "b": 0}),
(
pytest.raises(TypeError),
{"a": 1.0, "b": 3, "c": True},
), # fails due to unsupported type for metricsrecord value
],
)
def test_fitres_to_recordset_and_back(context: Any, metrics: Dict[str, Scalar]) -> None:
"""Test conversion FitRes --> RecordSet --> FitRes."""
arrays = get_ndarrays()
fitres = FitRes(
parameters=ndarrays_to_parameters(arrays),
num_examples=1,
status=Status(code=Code(0), message=""),
metrics=metrics,
)

###################### DELETE FROM BELOW #################################
###################### DELETE FROM BELOW #################################
###################### DELETE FROM BELOW #################################
###################### DELETE FROM BELOW #################################
fitres_copy = deepcopy(fitres)

def _get_recordset_compatible_with_legacy_ins(ins_str: str) -> RecordSet:
recordset = RecordSet()
with context:
recordset = fit_res_to_recordset(fitres, keep_input=False)
fitres_ = recordset_to_fit_res(recordset, keep_input=False)

# add a ParametersRecord
array_dict = OrderedDict(
{str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
)
recordset.set_parameters(
f"{ins_str}.parameters", record=ParametersRecord(array_dict)
)
# only check if we didn't test for an invalid setting. Only in valid settings
# makes sense to evaluate the below, since both functions above have succesfully
# being executed.
if isinstance(context, nullcontext):
assert fitres_copy == fitres_

# add a ConfigsRecord
recordset.set_configs(
f"{ins_str}.config",
record=ConfigsRecord({"a": 1, "b": 2.0, "c": np.eye(2).flatten().tobytes()}),

def test_evaluateins_to_recordset_and_back() -> None:
"""Test conversion EvaluateIns --> RecordSet --> EvaluateIns."""
arrays = get_ndarrays()
evaluateins = EvaluateIns(
parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0}
)

return recordset
evaluateins_copy = deepcopy(evaluateins)

recordset = evaluate_ins_to_recordset(evaluateins, keep_input=False)

evaluateins_ = recordset_to_evaluate_ins(recordset, keep_input=False)

assert evaluateins_copy == evaluateins_


@pytest.mark.parametrize(
"ins_str, do_func, undo_func",
"context, metrics",
[
(nullcontext(), {"a": 1.0, "b": 0}),
(
"fitins",
partial(recordset_to_fit_ins, keep_input=True),
fit_ins_to_recordset,
),
(
"evaluateins",
partial(recordset_to_evaluate_ins, keep_input=True),
evaluate_ins_to_recordset,
),
pytest.raises(TypeError),
{"a": 1.0, "b": 3, "c": True},
), # fails due to unsupported type for metricsrecord value
],
)
def test_recordset_to_fit_or_evaluate_ins_and_back(
ins_str: str,
do_func: Callable[[RecordSet], Union[FitIns, EvaluateIns]],
undo_func: Callable[[Union[FitIns, EvaluateIns]], RecordSet],
def test_evaluateres_to_recordset_and_back(
context: Any, metrics: Dict[str, Scalar]
) -> None:
"""."""
valid_record_set = _get_recordset_compatible_with_legacy_ins(ins_str)
"""Test conversion EvaluateRes --> RecordSet --> EvaluateRes."""
evaluateres = EvaluateRes(
num_examples=1,
loss=0.1,
status=Status(code=Code(0), message=""),
metrics=metrics,
)

ins = do_func(valid_record_set)
evaluateres_copy = deepcopy(evaluateres)

reverted_record_set = undo_func(ins)
with context:
recordset = evaluate_res_to_recordset(evaluateres)
evaluateres_ = recordset_to_evaluate_res(recordset)

assert valid_record_set.configs == reverted_record_set.configs
# TODO: how to check parameters consistency (given than Array->Parameters is
# a destructive process ? (i.e. different metadata encoded))
# only check if we didn't test for an invalid setting. Only in valid settings
# makes sense to evaluate the below, since both functions above have succesfully
# being executed.
if isinstance(context, nullcontext):
assert evaluateres_copy == evaluateres_


def test_get_properties_ins_to_recordset_and_back() -> None:
"""Test conversion GetPropertiesIns --> RecordSet --> GetPropertiesIns."""
config_dict: Dict[str, Scalar] = {
"a": 1.0,
"b": 3,
"c": True,
} # valid since both Ins/Res communicate over ConfigsRecord

@pytest.mark.parametrize(
"ins_str, do_func, undo_func",
[
(
"getevaluateres",
recordset_to_evaluate_res,
evaluate_res_to_recordset,
),
],
)
def test_recordset_to_evaluate_res_and_back(
ins_str: str,
do_func: Callable[[RecordSet], EvaluateRes],
undo_func: Callable[[EvaluateRes], RecordSet],
) -> None:
getproperties_ins = GetPropertiesIns(config_dict)

recordset = RecordSet()
getproperties_ins_copy = deepcopy(getproperties_ins)

recordset = getproperties_ins_to_recordset(getproperties_ins)
getproperties_ins_ = recordset_to_getproperties_ins(recordset)

assert getproperties_ins_copy == getproperties_ins_

def test_getproperties_res_to_recordset_and_back() -> None:
"""."""
client_fn = _get_client_fn(ClientWithProps())

def test_get_properties_res_to_recordset_and_back() -> None:
"""Test conversion GetPropertiesRes --> RecordSet --> GetPropertiesRes."""
config_dict: Dict[str, Scalar] = {
"a": 1.0,
"b": 3,
"c": True,
} # valid since both Ins/Res communicate over ConfigsRecord

@pytest.mark.parametrize(
"ins_str, do_func, undo_func",
[
(
"getpropertiesins",
recordset_to_getproperties_ins,
getproperties_ins_to_recordset,
),
(
"getpropertiesres",
recordset_to_getproperties_res,
getproperties_res_to_recordset,
),
],
)
def test_recordset_to_get_properties_ins_or_res_and_back(
ins_str: str,
do_func: Callable[[RecordSet], Union[GetPropertiesIns, GetPropertiesRes]],
undo_func: Callable[[Union[GetPropertiesIns, GetPropertiesRes]], RecordSet],
) -> None:
"""."""
recordset = RecordSet()
recordset.set_configs(
f"{ins_str}.{'properties' if 'res' in ins_str else 'config'}",
record=ConfigsRecord({"a": 1, "b": 2.0, "c": np.eye(2).flatten().tobytes()}),
getproperties_res = GetPropertiesRes(
status=Status(code=Code(0), message=""), properties=config_dict
)

# embed status if it's a response message only
if "res" in ins_str:
recordset = _embed_status_into_recordset(
ins_str, status=Status(code=Code(0), message="hello"), recordset=recordset
)
getproperties_res_copy = deepcopy(getproperties_res)

recordset = getproperties_res_to_recordset(getproperties_res)
getproperties_res_ = recordset_to_getproperties_res(recordset)

assert getproperties_res_copy == getproperties_res_


def test_get_parameters_ins_to_recordset_and_back() -> None:
"""Test conversion GetParametersIns --> RecordSet --> GetParametersIns."""
config_dict: Dict[str, Scalar] = {
"a": 1.0,
"b": 3,
"c": True,
} # valid since both Ins/Res communicate over ConfigsRecord

getparameters_ins = GetParametersIns(config_dict)

getparameters_ins_copy = deepcopy(getparameters_ins)

recordset = getparameters_ins_to_recordset(getparameters_ins)
getparameters_ins_ = recordset_to_getparameters_ins(recordset)

assert getparameters_ins_copy == getparameters_ins_


def test_get_parameters_res_to_recordset_and_back() -> None:
"""Test conversion GetParametersRes --> RecordSet --> GetParametersRes."""
arrays = get_ndarrays()
getparameteres_res = GetParametersRes(
status=Status(code=Code(0), message=""),
parameters=ndarrays_to_parameters(arrays),
)

recordset_copy = deepcopy(recordset)
getparameters_res_copy = deepcopy(getparameteres_res)

ins = do_func(recordset)
recordset = getparameters_res_to_recordset(getparameteres_res)
getparameteres_res_ = recordset_to_getparameters_res(recordset)

recordset_ = undo_func(ins)
assert recordset_copy == recordset_
assert getparameters_res_copy == getparameteres_res_
Loading

0 comments on commit cbb9766

Please sign in to comment.