Skip to content

Commit

Permalink
chore: add metric functions to strategy
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy committed Feb 22, 2024
1 parent 7bb9bef commit 57fcdbf
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions substrafl/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,24 @@ def perform_evaluation(
raise NotImplementedError

@remote_data
<<<<<<< HEAD
def evaluate(self, datasamples: Any, shared_state: Any = None) -> Dict[str, float]:
=======
def evaluate(self, datasamples: Any, shared_state: Any = None):
>>>>>>> 5b87f8f (chore: add metric functions to strategy)
"""Is executed for each TestDataOrganizations.
Args:
datasamples (typing.Any): The output of the ``get_data`` method of the opener.
shared_state (typing.Any): None for the first round of the computation graph
then the returned object from the previous organization of the computation graph.
<<<<<<< HEAD
Returns:
Dict[str, float]: keys of the dict are the metric name, and values are the computed
performances.
=======
>>>>>>> 5b87f8f (chore: add metric functions to strategy)
"""
predictions = self.algo.predict(datasamples, shared_state)
return {
Expand Down Expand Up @@ -286,7 +293,11 @@ def _check_metric_function(metric_function: Callable) -> None:
Raises:
exceptions.MetricFunctionTypeError: metric_function must be of type "function"
exceptions.MetricFunctionSignatureError: metric_function must ONLY contains
<<<<<<< HEAD
datasamples and predictions as parameters
=======
datasamples and predictions_path as parameters
>>>>>>> 5b87f8f (chore: add metric functions to strategy)
"""

if not inspect.isfunction(metric_function):
Expand All @@ -301,11 +312,19 @@ def _check_metric_function(metric_function: Callable) -> None:
)
elif "predictions" not in parameters:
raise exceptions.MetricFunctionSignatureError(
<<<<<<< HEAD
"The metric_function: {metric_function.__name__} must contain predictions as parameter."
)
elif len(parameters) != 2:
raise exceptions.MetricFunctionSignatureError(
"""The metric_function: {metric_function.__name__} must ONLY contains datasamples and predictions as
=======
"The metric_function: {metric_function.__name__} must contain predictions_path as parameter."
)
elif len(parameters) != 2:
raise exceptions.MetricFunctionSignatureError(
"""The metric_function: {metric_function.__name__} must ONLY contains datasamples and predictions_path as
>>>>>>> 5b87f8f (chore: add metric functions to strategy)
parameters."""
)
Expand Down

0 comments on commit 57fcdbf

Please sign in to comment.