-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- implement a new ModeBasedAssigner class extending RandomGroupedAssi…
…gner - update defaults/assigner.yaml to use ModeBasedAssigner with default mode to train_and_validate - update fedeval sample workspace to use default assigner, tasks and aggregator - use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1 - removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks - add tests for ModeBasedAssigner Signed-off-by: Shailesh Pant <[email protected]>
- Loading branch information
1 parent
bc5ba2b
commit 6a04db5
Showing
8 changed files
with
150 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 0 additions & 7 deletions
7
openfl-workspace/workspace/plan/defaults/federated-evaluation/assigner.yaml
This file was deleted.
Oops, something went wrong.
7 changes: 0 additions & 7 deletions
7
openfl-workspace/workspace/plan/defaults/federated-evaluation/tasks_torch.yaml
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
"""Mode based assigner module.""" | ||
|
||
import numpy as np | ||
|
||
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner | ||
|
||
|
||
class ModeBasedAssigner(RandomGroupedAssigner): | ||
r"""The task assigner maintains a list of tasks. | ||
Also it decides the policy for which collaborator should run those tasks | ||
There may be many types of policies implemented, but a natural place to | ||
start is with a: | ||
- ModeBasedAssigner : | ||
Given a set of task groups and mode it filters the task groups based | ||
on mode. It futher enforces checks for specific modes. | ||
Post filtering it deletgates the task assignment to RandomGroupedAssigner. | ||
Attributes: | ||
task_groups* (list of object): Task groups to assign. | ||
mode* (str): Mode to determine task assignments. | ||
.. note:: | ||
\* - Plan setting. | ||
""" | ||
|
||
def __init__(self, task_groups, mode, **kwargs): | ||
"""Initializes the ModeBasedAssigner. | ||
Args: | ||
task_groups (list of object): Task groups to assign. | ||
mode (str): Mode to determine task assignments. | ||
**kwargs: Additional keyword arguments. | ||
""" | ||
self.task_groups = task_groups | ||
self.mode = mode | ||
super().__init__(task_groups=task_groups, **kwargs) | ||
|
||
def define_task_assignments(self): | ||
"""Define task assignments for each round and collaborator. | ||
This method filters tasks for the | ||
collaborators for each round based on the mode. | ||
Args: | ||
None | ||
Returns: | ||
None | ||
""" | ||
self.task_groups = [ | ||
group for group in self.task_groups | ||
if group["name"] == self.mode | ||
] | ||
if self.mode == "evaluate" : | ||
assert ( | ||
self.rounds == 1 | ||
), "Number of rounds should be 1 for evaluate mode" | ||
super().define_task_assignments() |
73 changes: 73 additions & 0 deletions
73
tests/openfl/component/assigner/test_mode_based_assigner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
from openfl.component.assigner.mode_based_assigner import ModeBasedAssigner | ||
|
||
@pytest.fixture | ||
def sample_task_groups(): | ||
return [ | ||
{"name": "train", "tasks": ["task1", "task2"]}, | ||
{"name": "evaluate", "tasks": ["task3"]}, | ||
{"name": "validate", "tasks": ["task4"]} | ||
] | ||
|
||
def test_init_with_valid_mode(): | ||
"""Test initialization with valid mode and task groups.""" | ||
task_groups = [{"name": "train", "tasks": ["task1"]}] | ||
assigner = ModeBasedAssigner(task_groups=task_groups, mode="train") | ||
assert assigner.mode == "train" | ||
assert assigner.task_groups == task_groups | ||
|
||
def test_define_task_assignments_train_mode(sample_task_groups): | ||
"""Test task assignments filtering for train mode.""" | ||
assigner = ModeBasedAssigner( | ||
task_groups=sample_task_groups, | ||
mode="train", | ||
rounds=3 | ||
) | ||
assigner.define_task_assignments() | ||
assert len(assigner.task_groups) == 1 | ||
assert assigner.task_groups[0]["name"] == "train" | ||
|
||
def test_define_task_assignments_evaluate_mode(sample_task_groups): | ||
"""Test task assignments filtering for evaluate mode with rounds=1.""" | ||
assigner = ModeBasedAssigner( | ||
task_groups=sample_task_groups, | ||
mode="evaluate", | ||
rounds=1 | ||
) | ||
assigner.define_task_assignments() | ||
assert len(assigner.task_groups) == 1 | ||
assert assigner.task_groups[0]["name"] == "evaluate" | ||
|
||
def test_evaluate_mode_with_invalid_rounds(sample_task_groups): | ||
"""Test that evaluate mode raises error when rounds > 1.""" | ||
assigner = ModeBasedAssigner( | ||
task_groups=sample_task_groups, | ||
mode="evaluate", | ||
rounds=2 | ||
) | ||
with pytest.raises(AssertionError): | ||
assigner.define_task_assignments() | ||
|
||
def test_empty_task_groups_after_filtering(): | ||
"""Test behavior when no task groups match the specified mode.""" | ||
task_groups = [{"name": "train", "tasks": ["task1"]}] | ||
assigner = ModeBasedAssigner( | ||
task_groups=task_groups, | ||
mode="non_existent_mode" | ||
) | ||
assigner.define_task_assignments() | ||
assert len(assigner.task_groups) == 0 | ||
|
||
def test_multiple_matching_task_groups(): | ||
"""Test behavior when multiple task groups match the mode.""" | ||
task_groups = [ | ||
{"name": "train", "tasks": ["task1"]}, | ||
{"name": "train", "tasks": ["task2"]} | ||
] | ||
assigner = ModeBasedAssigner( | ||
task_groups=task_groups, | ||
mode="train" | ||
) | ||
assigner.define_task_assignments() | ||
assert len(assigner.task_groups) == 2 | ||
assert all(group["name"] == "train" for group in assigner.task_groups) |