Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 19, 2024
1 parent 8f68938 commit a996409
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/py/flwr/common/metricsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def set_metrics(

def is_valid(value: MetricsScalar) -> None:
"""Check if value is of expected type."""
if not isinstance(value, get_args(MetricsScalar)):
if not isinstance(value, get_args(MetricsScalar)) or isinstance(
value, bool
):
raise TypeError(
"Not all values are of valid type."
f" Expected {MetricsRecordValues} but you passed {type(value)}."
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_set_metrics_to_metricsrecord_with_correct_types(
"key_type, value_fn",
[
(str, lambda x: str(x.flatten()[0])), # str: str (supported: unsupported)
(str, lambda x: bool(x.flatten()[0])), # str: bool (supported: unsupported)
(
str,
lambda x: x.flatten().astype("str").tolist(),
Expand All @@ -213,7 +214,7 @@ def test_set_metrics_to_metricsrecord_with_correct_types(
],
)
def test_set_metrics_to_metricsrecord_with_incorrect_types(
key_type: Type[Union[str, int, float]],
key_type: Type[Union[str, int, float, bool]],
value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
) -> None:
"""Test adding metrics of various unsupported types to a MetricsRecord."""
Expand Down Expand Up @@ -270,10 +271,12 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
(str, lambda x: str(x.flatten()[0])), # str: str
(str, lambda x: int(x.flatten()[0])), # str: int
(str, lambda x: float(x.flatten()[0])), # str: float
(str, lambda x: bool(x.flatten()[0])), # str: bool
(str, lambda x: x.flatten().tobytes()), # str: bytes
(str, lambda x: x.flatten().astype("str").tolist()), # str: List[str]
(str, lambda x: x.flatten().astype("int").tolist()), # str: List[int]
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
(str, lambda x: x.flatten().astype("bool").tolist()), # str: List[bool]
(str, lambda x: [x.flatten().tobytes()]), # str: List[bytes]
],
)
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
MetricsScalarList = Union[List[int], List[float]]
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
# Value types for common.ConfigsRecord
ConfigsScalar = Union[MetricsScalar, str, bytes]
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes]]
ConfigsScalar = Union[MetricsScalar, str, bytes, bool]
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]]
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]

Metrics = Dict[str, Scalar]
Expand Down

0 comments on commit a996409

Please sign in to comment.