diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 6a656cb661d2..a2f57f9b7258 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -14,10 +14,11 @@ # ============================================================================== """Flower client app.""" - from abc import ABC from typing import Callable +import numpy as np + from flwr.client.client import Client from flwr.common import ( Config, @@ -259,7 +260,10 @@ def _fit(self: Client, ins: FitIns) -> FitRes: results = self.numpy_client.fit(parameters, ins.config) # type: ignore if not ( len(results) == 3 - and isinstance(results[0], list) + and isinstance(results[0], list) # Check if it's a list + and all( + isinstance(p, np.ndarray) for p in results[0] + ) # Check elements are np.ndarray and isinstance(results[1], int) and isinstance(results[2], dict) ): diff --git a/src/py/flwr/client/numpy_client_test.py b/src/py/flwr/client/numpy_client_test.py index c5d520a73ce1..f64d2e4ec25a 100644 --- a/src/py/flwr/client/numpy_client_test.py +++ b/src/py/flwr/client/numpy_client_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower NumPyClient tests.""" +import numpy as np from flwr.common import Config, NDArrays, Properties, Scalar @@ -40,8 +41,15 @@ def get_parameters(self, config: Config) -> NDArrays: def fit( self, parameters: NDArrays, config: dict[str, Scalar] ) -> tuple[NDArrays, int, dict[str, Scalar]]: - """Simulate training by returning empty weights, 0 samples, empty metrics.""" - return [], 0, {} + """Simulate training by returning updated weights, 0 samples, and metrics.""" + # Simulate updated parameters as NumPy arrays + updated_parameters = [np.array([0.1, 0.2]), np.array([0.3, 0.4])] + + # Simulate training metrics + metrics = {"accuracy": 0.95} + + # Return updated parameters, number of examples, and metrics + return updated_parameters, 0, metrics def evaluate( self, parameters: NDArrays, config: dict[str, Scalar] @@ -156,3 +164,75 @@ def test_has_evaluate_false() -> None: # Assert assert actual == expected + + +def test_fit_return_type() -> None: + """Test that fit returns the correct type.""" + # Prepare + client = OverridingClient() + + # Execute + parameters, num_examples, metrics = client.fit( + parameters=[np.array([0.1, 0.2])], config={"epochs": 5} + ) + + # Assert + # Check if parameters is a list and all elements are np.ndarray + assert isinstance(parameters, list) + assert all(isinstance(p, np.ndarray) for p in parameters) + + # Check other return types + assert isinstance(num_examples, int) + assert isinstance(metrics, dict) + assert all( + isinstance(k, str) and isinstance(v, (bool, bytes, float, int, str)) + for k, v in metrics.items() + ) + + +def test_evaluate_return_type() -> None: + """Test that evaluate returns the correct type.""" + # Prepare + client = OverridingClient() + + # Execute + loss, num_examples, metrics = client.evaluate( + parameters=[np.array([0.1, 0.2])], config={"batch_size": 32} + ) + + # Assert + assert isinstance(loss, float) + assert isinstance(num_examples, int) + assert isinstance(metrics, dict) + assert all(isinstance(k, str) for k in metrics) # Fix: Removed `.keys()` + assert all(isinstance(v, (bool, bytes, float, int, str)) for v in metrics.values()) + + +def test_get_parameters_return_type() -> None: + """Test that get_parameters returns the correct type.""" + # Prepare + client = OverridingClient() + + # Execute + parameters = client.get_parameters(config={}) + + # Assert + # Check if parameters is a list and all elements are np.ndarray + assert isinstance(parameters, list) + assert all(isinstance(p, np.ndarray) for p in parameters) + + +def test_get_properties_return_type() -> None: + """Test that get_properties returns the correct type.""" + # Prepare + client = OverridingClient() + + # Execute + properties = client.get_properties(config={}) + + # Assert + assert isinstance(properties, dict) # Properties is a dict[str, Scalar] + assert all(isinstance(k, str) for k in properties) + assert all( + isinstance(v, (bool, bytes, float, int, str)) for v in properties.values() + )