Skip to content

Commit

Permalink
Implementation of director-aggregator gRPC communication
Browse files Browse the repository at this point in the history
Signed-off-by: Aleksandr Mokrov <[email protected]>

Update openfl/federated/plan/plan.py

Co-authored-by: Igor Davidyuk <[email protected]>
  • Loading branch information
aleksandr-mokrov and igor-davidyuk committed Feb 16, 2023
1 parent 20519fa commit d5ec206
Show file tree
Hide file tree
Showing 11 changed files with 498 additions and 211 deletions.
50 changes: 50 additions & 0 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import queue
from logging import getLogger
from typing import List

from openfl.interface.aggregation_functions import WeightedAverage
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
Expand Down Expand Up @@ -573,6 +574,55 @@ def send_local_task_results(self, collaborator_name, round_number, task_name,

self._end_of_task_check(task_name)

def get_experiment_description(self) -> dict:
    """Return a snapshot of the running experiment's state.

    The snapshot covers round progress, download readiness for models
    and logs, per-collaborator entries and the assigned tasks. Several
    fields are placeholder mocks pending real per-collaborator tracking.

    Returns:
        dict: experiment description consumed by the director's
            ``GetExperimentDescription`` RPC.
    """
    # NOTE(review): assumes rounds_to_train > 0 — confirm upstream.
    completed_fraction = self.round_number / self.rounds_to_train
    task_descriptions = [
        {'name': task['function'], 'description': 'Task description Mock'}
        for task in self.assigner.tasks.values()
    ]
    collaborator_descriptions = [
        {
            'name': collaborator_name,
            'status': 'pending_mock',
            'progress': 0.0,
            'round': 0,
            'current_task': 'Current Task Mock',
            'next_task': 'Next Task Mock',
        }
        for collaborator_name in self.authorized_cols
    ]
    return {
        'current_round': self.round_number,
        'total_rounds': self.rounds_to_train,
        'download_statuses': {
            'models': self.model_download_statuses,
            'logs': [{'name': 'aggregator', 'status': 'ready'}],
        },
        'collaborators': collaborator_descriptions,
        'tasks': task_descriptions,
        'progress': completed_fraction,
    }

@property
def model_download_statuses(self) -> List[dict]:
    """Report download readiness of the best / last / init models.

    A model is 'ready' once its tensor dict has been populated; the
    initial model is always available.
    """
    return [
        {'name': 'best',
         'status': 'ready' if self.best_tensor_dict else 'pending'},
        {'name': 'last',
         'status': 'ready' if self.last_tensor_dict else 'pending'},
        {'name': 'init', 'status': 'ready'},
    ]

def _process_named_tensor(self, named_tensor, collaborator_name):
"""
Extract the named tensor fields.
Expand Down
159 changes: 58 additions & 101 deletions openfl/component/director/director.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from pathlib import Path
from typing import Callable
from typing import Iterable
from typing import List
from typing import Union
from typing import ValuesView

from openfl.protocols import base_pb2
from openfl.transport import AsyncAggregatorGRPCClient
from openfl.transport.grpc.exceptions import ShardNotFoundError

from .experiment import Experiment
Expand Down Expand Up @@ -89,6 +91,26 @@ async def set_new_experiment(
self.experiments_registry.add(experiment)
return True

async def get_aggregator_client(self, experiment_name):
    """Build an async gRPC client for the experiment's aggregator.

    Waits (cooperatively) until the experiment reaches IN_PROGRESS so
    that the plan's aggregator address and port are resolved, then
    constructs a client using the director's TLS material.
    """
    experiment = self.experiments_registry[experiment_name]
    # NOTE(review): this polls forever if the experiment never reaches
    # IN_PROGRESS (e.g. fails first) — confirm a failure path exists.
    while experiment.status != Status.IN_PROGRESS:
        await asyncio.sleep(1)

    agg_port = experiment.plan.agg_port
    agg_addr = experiment.plan.agg_addr
    logger.info(f'Aggregator uri: {agg_addr}:{agg_port}')
    return AsyncAggregatorGRPCClient(
        agg_addr,
        agg_port,
        tls=self.tls,
        disable_client_auth=not self.tls,
        root_certificate=self.root_certificate,
        certificate=self.certificate,
        private_key=self.private_key
    )

async def get_experiment_status(
self,
experiment_name: str,
Expand All @@ -100,26 +122,23 @@ async def get_experiment_status(
return None
return self.experiments_registry[experiment_name].status

def get_trained_model(self, experiment_name: str, caller: str, model_type: str):
async def get_trained_model(self, experiment_name: str, caller: str, model_type: str):
"""Get trained model."""
if (experiment_name not in self.experiments_registry
or caller not in self.experiments_registry[experiment_name].users):
logger.error('No experiment data in the stash')
return None

aggregator = self.experiments_registry[experiment_name].aggregator

if aggregator.last_tensor_dict is None:
logger.error('Aggregator have no aggregated model to return')
exp = self.experiments_registry[experiment_name]
if exp.status != Status.IN_PROGRESS:
return None

if model_type == 'best':
return aggregator.best_tensor_dict
elif model_type == 'last':
return aggregator.last_tensor_dict
else:
logger.error('Unknown model type required.')
return None
aggregator_client = await self.get_aggregator_client(experiment_name)
trained_model = await aggregator_client.get_trained_model(
experiment_name,
model_type
)

return trained_model

def get_experiment_data(self, experiment_name: str) -> Path:
"""Get experiment data."""
Expand Down Expand Up @@ -176,19 +195,9 @@ async def stream_metrics(self, experiment_name: str, caller: str):
f' does not have access to this experiment'
)

while not self.experiments_registry[experiment_name].aggregator:
await asyncio.sleep(1)
aggregator = self.experiments_registry[experiment_name].aggregator

while True:
if not aggregator.metric_queue.empty():
yield aggregator.metric_queue.get()
continue

if aggregator.all_quit_jobs_sent() and aggregator.metric_queue.empty():
return

yield None
aggregator_client = await self.get_aggregator_client(experiment_name)
async for metric_dict in aggregator_client.get_metric_stream():
yield metric_dict

def remove_experiment_data(self, experiment_name: str, caller: str):
"""Remove experiment data from stash."""
Expand Down Expand Up @@ -241,7 +250,7 @@ def update_envoy_status(

return self.envoy_health_check_period

def get_envoys(self) -> list:
def get_envoys(self) -> ValuesView:
"""Get a status information about envoys."""
logger.info(f'Shard registry: {self._shard_registry}')
for envoy_info in self._shard_registry.values():
Expand All @@ -250,11 +259,11 @@ def get_envoys(self) -> list:
+ envoy_info.get('valid_duration', 0)
)
envoy_name = envoy_info['shard_info']['node_info']['name']
envoy_info['experiment_name'] = self.col_exp[envoy_name]
envoy_info['experiment_name'] = self.col_exp.get(envoy_name)

return self._shard_registry.values()

def get_experiments_list(self, caller: str) -> list:
async def get_experiments_list(self, caller: str) -> list:
"""Get experiments list for specific user."""
experiments = self.experiments_registry.get_user_experiments(caller)
result = []
Expand All @@ -264,45 +273,32 @@ def get_experiments_list(self, caller: str) -> list:
'status': exp.status,
'collaborators_amount': len(exp.collaborators),
}
progress = _get_experiment_progress(exp)
if progress is not None:
exp_data['progress'] = progress
if exp.aggregator:
tasks_amount = len({
task['function']
for task in exp.aggregator.assigner.tasks.values()
})
exp_data['tasks_amount'] = tasks_amount
if exp.status == Status.IN_PROGRESS:
aggregator_client = await self.get_aggregator_client(exp.name)
experiment_pb2 = await aggregator_client.get_experiment_description()
exp_data['progress'] = experiment_pb2.progress
exp_data['tasks_amount'] = len(experiment_pb2.tasks)
result.append(exp_data)

return result

def get_experiment_description(self, caller: str, name: str) -> dict:
async def get_experiment_description(self, caller: str, experiment_name: str) -> dict:
"""Get a experiment information by name for specific user."""
exp = self.experiments_registry.get(name)
exp = self.experiments_registry.get(experiment_name)
if not exp or caller not in exp.users:
logger.info(f'Experiment {experiment_name} for user {caller} does not exist.')
return {}
progress = _get_experiment_progress(exp)
model_statuses = _get_model_download_statuses(exp)
tasks = _get_experiment_tasks(exp)
collaborators = _get_experiment_collaborators(exp)
result = {
'name': name,
'status': exp.status,
'current_round': exp.aggregator.round_number,
'total_rounds': exp.aggregator.rounds_to_train,
'download_statuses': {
'models': model_statuses,
'logs': [{
'name': 'aggregator',
'status': 'ready'
}],
},
'collaborators': collaborators,
'tasks': tasks,
'progress': progress
}
return result
if exp.status != Status.IN_PROGRESS:
return base_pb2.ExperimentDescription(
name=exp.name,
status=exp.status,
)
aggregator_client = await self.get_aggregator_client(experiment_name)
experiment_pb2 = await aggregator_client.get_experiment_description()
experiment_pb2.name = experiment_name
experiment_pb2.status = exp.status

return experiment_pb2

async def start_experiment_execution_loop(self):
"""Run task to monitor and run experiments."""
Expand Down Expand Up @@ -331,42 +327,3 @@ async def start_experiment_execution_loop(self):
queue = self.col_exp_queues[col_name]
await queue.put(experiment.name)
await run_aggregator_future


def _get_model_download_statuses(experiment) -> List[dict]:
    """Describe download readiness of the experiment's models.

    'best' and 'last' are 'ready' once their tensor dicts are populated;
    the initial model is always downloadable.
    """
    aggregator = experiment.aggregator
    return [
        {'name': 'best',
         'status': 'ready' if aggregator.best_tensor_dict else 'pending'},
        {'name': 'last',
         'status': 'ready' if aggregator.last_tensor_dict else 'pending'},
        {'name': 'init', 'status': 'ready'},
    ]


def _get_experiment_progress(experiment) -> Union[float, None]:
    """Fraction of training rounds completed, or None when not running."""
    if experiment.status != Status.IN_PROGRESS:
        return None
    aggregator = experiment.aggregator
    # NOTE(review): assumes rounds_to_train > 0 while IN_PROGRESS.
    return aggregator.round_number / aggregator.rounds_to_train


def _get_experiment_tasks(experiment) -> List[dict]:
    """List the experiment's assigned tasks (descriptions are mocked)."""
    tasks = []
    for task in experiment.aggregator.assigner.tasks.values():
        tasks.append({
            'name': task['function'],
            'description': 'Task description Mock',
        })
    return tasks


def _get_experiment_collaborators(experiment) -> List[dict]:
    """Per-collaborator status entries (placeholder mocks for now)."""
    collaborators = []
    for collaborator_name in experiment.aggregator.authorized_cols:
        collaborators.append({
            'name': collaborator_name,
            'status': 'pending_mock',
            'progress': 0.0,
            'round': 0,
            'current_task': 'Current Task Mock',
            'next_task': 'Next Task Mock',
        })
    return collaborators
7 changes: 4 additions & 3 deletions openfl/component/director/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
self.sender = sender
self.init_tensor_dict = init_tensor_dict
self.plan_path = Path(plan_path)
self.plan = None
self.users = set() if users is None else set(users)
self.status = Status.PENDING
self.aggregator = None
Expand Down Expand Up @@ -125,11 +126,11 @@ def _create_aggregator_grpc_server(
private_key: Union[Path, str] = None,
certificate: Union[Path, str] = None,
) -> AggregatorGRPCServer:
plan = Plan.parse(plan_config_path=self.plan_path)
plan.authorized_cols = list(self.collaborators)
self.plan = Plan.parse(plan_config_path=self.plan_path)
self.plan.authorized_cols = list(self.collaborators)

logger.info(f'🧿 Created an Aggregator Server for {self.name} experiment.')
aggregator_grpc_server = plan.interactive_api_get_server(
aggregator_grpc_server = self.plan.interactive_api_get_server(
tensor_dict=self.init_tensor_dict,
root_certificate=root_certificate,
certificate=certificate,
Expand Down
14 changes: 11 additions & 3 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def __init__(self):
self.hash_ = None
self.name_ = None
self.serializer_ = None
self.agg_addr = None
self.agg_port = None

@property
def hash(self): # NOQA
Expand All @@ -237,6 +239,12 @@ def hash(self): # NOQA

return self.hash_.hexdigest()

def generate_agg_port(self):
    """Derive a deterministic aggregator port from the plan hash.

    The first 8 hex digits of the plan hash are mapped into
    [49152, 60999), the IANA dynamic/ephemeral port range, so the
    same plan always resolves to the same port.
    """
    hash_prefix = int(self.hash[:8], 16)
    port_span = 60999 - 49152
    return 49152 + hash_prefix % port_span

def resolve(self):
"""Resolve the federation settings."""
self.federation_uuid = f'{self.name}_{self.hash[:8]}'
Expand All @@ -247,11 +255,11 @@ def resolve(self):

if self.config['network'][SETTINGS]['agg_addr'] == AUTO:
self.config['network'][SETTINGS]['agg_addr'] = getfqdn_env()
self.agg_addr = self.config['network'][SETTINGS]['agg_addr']

if self.config['network'][SETTINGS]['agg_port'] == AUTO:
self.config['network'][SETTINGS]['agg_port'] = int(
self.hash[:8], 16
) % (60999 - 49152) + 49152
self.config['network'][SETTINGS]['agg_port'] = self.generate_agg_port()
self.agg_port = self.config['network'][SETTINGS]['agg_port']

def get_assigner(self):
"""Get the plan task assigner."""
Expand Down
6 changes: 1 addition & 5 deletions openfl/protocols/aggregator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ message SendLocalTaskResultsResponse {
MessageHeader header = 1;
}

// The same with director.proto
message GetMetricStreamRequest {
string experiment_name = 1;
}

// The same with director.proto
Expand Down Expand Up @@ -104,9 +102,7 @@ message TrainedModelResponse {
}

// The same with director.proto
message GetExperimentDescriptionRequest {
string name = 1;
}
message GetExperimentDescriptionRequest {}

// The same with director.proto
message GetExperimentDescriptionResponse {
Expand Down
Empty file.
2 changes: 2 additions & 0 deletions openfl/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from .grpc import AggregatorGRPCClient
from .grpc import AggregatorGRPCServer
from .grpc import AsyncAggregatorGRPCClient
from .grpc import DirectorGRPCServer

__all__ = [
'AggregatorGRPCServer',
'AggregatorGRPCClient',
'AsyncAggregatorGRPCClient',
'DirectorGRPCServer',
]
Loading

0 comments on commit d5ec206

Please sign in to comment.