Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Aggregator / Assigner leaky abstraction #1301

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db

20 changes: 6 additions & 14 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
task_group: str = "learning",
):
"""Initializes the Aggregator.

Expand All @@ -110,9 +109,7 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
task_group (str, optional): Selected task_group for assignment.
"""
self.task_group = task_group
self.round_number = 0
self.next_model_round_number = 0

Expand All @@ -129,11 +126,10 @@ def __init__(
self.straggler_handling_policy = straggler_handling_policy()

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
self.assigner = assigner
if self.assigner.is_task_group_evaluation():
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
)
logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}")

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []
Expand All @@ -142,11 +138,7 @@ def __init__(
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
# # override the assigner selected_task_group
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
# # and doesn't have selected_task_group as attribute
# assigner.selected_task_group = task_group
self.assigner = assigner

self.quit_job_sent_to = []

self.tensor_db = TensorDB()
Expand Down Expand Up @@ -308,8 +300,8 @@ def _load_initial_tensors(self):
)

# Check selected task_group before updating round number
if self.task_group == "evaluation":
logger.info(f"Skipping round_number check for {self.task_group} task_group")
if self.assigner.is_task_group_evaluation():
logger.info("Skipping round_number check for evaluation run")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number
Expand Down
10 changes: 10 additions & 0 deletions openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number):
"""Abstract method."""
raise NotImplementedError

def is_task_group_evaluation(self):
"""Check if the selected task group is for 'evaluation' run.

Returns:
bool: True if the selected task group is 'evaluation', False otherwise.
"""
if hasattr(self, "selected_task_group"):
return self.selected_task_group == "evaluation"
return False

def get_all_tasks_for_round(self, round_number):
"""Return tasks for the current round.

Expand Down
7 changes: 3 additions & 4 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,10 @@ def start_(plan, authorized_cols, task_group):
cols_config_path=Path(authorized_cols).absolute(),
)

# Set task_group in aggregator and assigner settings if provided
# Set task_group in assigner settings if provided
if task_group:
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
if "settings" not in parsed_plan.config["assigner"]:
parsed_plan.config["assigner"]["settings"] = {}
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

Expand Down
5 changes: 1 addition & 4 deletions tests/openfl/component/aggregator/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,13 @@ def agg(mocker, model, assigner):
'some_uuid',
'federation_uuid',
['col1', 'col2'],

'init_state_path',
'best_state_path',
'last_state_path',

assigner,
)
)
return agg


@pytest.mark.parametrize(
'cert_common_name,collaborator_common_name,authorized_cols,single_cccn,expected_is_valid', [
('col1', 'col1', ['col1', 'col2'], '', True),
Expand Down
15 changes: 0 additions & 15 deletions tests/openfl/interface/test_aggregator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ def test_aggregator_start(mock_parse):
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
Expand Down Expand Up @@ -55,11 +50,6 @@ def test_aggregator_start_illegal_plan(mock_parse, mock_is_directory_traversal):
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
Expand Down Expand Up @@ -89,11 +79,6 @@ def test_aggregator_start_illegal_cols(mock_parse, mock_is_directory_traversal):
mock_plan.get = {'task_group': 'learning'}.get
# Add the config attribute with proper nesting
mock_plan.config = {
'aggregator': {
'settings': {
'task_group': 'learning'
}
},
'assigner': {
'settings': {
'selected_task_group': 'learning'
Expand Down