Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Correct numpy client fit - return data type - error message #4666

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
8 changes: 6 additions & 2 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# ==============================================================================
"""Flower client app."""


from abc import ABC
from typing import Callable

import numpy as np

from flwr.client.client import Client
from flwr.common import (
Config,
Expand Down Expand Up @@ -259,7 +260,10 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
results = self.numpy_client.fit(parameters, ins.config) # type: ignore
if not (
len(results) == 3
and isinstance(results[0], list)
and isinstance(results[0], list) # Check if it's a list
and all(
isinstance(p, np.ndarray) for p in results[0]
) # Check elements are np.ndarray
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
Expand Down
84 changes: 82 additions & 2 deletions src/py/flwr/client/numpy_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Flower NumPyClient tests."""

import numpy as np

from flwr.common import Config, NDArrays, Properties, Scalar

Expand All @@ -40,8 +41,15 @@ def get_parameters(self, config: Config) -> NDArrays:
def fit(
self, parameters: NDArrays, config: dict[str, Scalar]
) -> tuple[NDArrays, int, dict[str, Scalar]]:
"""Simulate training by returning empty weights, 0 samples, empty metrics."""
return [], 0, {}
"""Simulate training by returning updated weights, 0 samples, and metrics."""
# Simulate updated parameters as NumPy arrays
updated_parameters = [np.array([0.1, 0.2]), np.array([0.3, 0.4])]

# Simulate training metrics
metrics = {"accuracy": 0.95}

# Return updated parameters, number of examples, and metrics
return updated_parameters, 0, metrics

def evaluate(
self, parameters: NDArrays, config: dict[str, Scalar]
Expand Down Expand Up @@ -156,3 +164,75 @@ def test_has_evaluate_false() -> None:

# Assert
assert actual == expected


def test_fit_return_type() -> None:
"""Test that fit returns the correct type."""
# Prepare
client = OverridingClient()

# Execute
parameters, num_examples, metrics = client.fit(
parameters=[np.array([0.1, 0.2])], config={"epochs": 5}
)

# Assert
# Check if parameters is a list and all elements are np.ndarray
assert isinstance(parameters, list)
assert all(isinstance(p, np.ndarray) for p in parameters)

# Check other return types
assert isinstance(num_examples, int)
assert isinstance(metrics, dict)
assert all(
isinstance(k, str) and isinstance(v, (bool, bytes, float, int, str))
for k, v in metrics.items()
)


def test_evaluate_return_type() -> None:
"""Test that evaluate returns the correct type."""
# Prepare
client = OverridingClient()

# Execute
loss, num_examples, metrics = client.evaluate(
parameters=[np.array([0.1, 0.2])], config={"batch_size": 32}
)

# Assert
assert isinstance(loss, float)
assert isinstance(num_examples, int)
assert isinstance(metrics, dict)
assert all(isinstance(k, str) for k in metrics) # Fix: Removed `.keys()`
assert all(isinstance(v, (bool, bytes, float, int, str)) for v in metrics.values())


def test_get_parameters_return_type() -> None:
"""Test that get_parameters returns the correct type."""
# Prepare
client = OverridingClient()

# Execute
parameters = client.get_parameters(config={})

# Assert
# Check if parameters is a list and all elements are np.ndarray
assert isinstance(parameters, list)
assert all(isinstance(p, np.ndarray) for p in parameters)


def test_get_properties_return_type() -> None:
"""Test that get_properties returns the correct type."""
# Prepare
client = OverridingClient()

# Execute
properties = client.get_properties(config={})

# Assert
assert isinstance(properties, dict) # Properties is a dict[str, Scalar]
assert all(isinstance(k, str) for k in properties)
assert all(
isinstance(v, (bool, bytes, float, int, str)) for v in properties.values()
)
Loading