Skip to content

Commit

Permalink
Extend profiling data
Browse files Browse the repository at this point in the history
  • Loading branch information
bacox committed May 8, 2021
1 parent 4e1416c commit e458dfc
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 23 deletions.
15 changes: 7 additions & 8 deletions fltk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,25 +250,24 @@ def test(self):
return accuracy, loss, class_precision, class_recall

def run_epochs(self, num_epoch):
    """Run `num_epoch` local training epochs followed by one test pass.

    Timing uses time.time() so durations are plain float seconds, matching
    the float `duration_train` / `duration_test` fields of EpochData.

    :param num_epoch: number of epochs to train before testing.
    :return: tuple of (EpochData record, state-dict of CPU weights,
             total wall-clock duration of this call in seconds).
    """
    # NOTE(review): this block contained interleaved pre-/post-commit diff
    # lines (datetime-based and time-based timing plus two returns); only
    # the post-commit version is kept here.
    start_time_train = time.time()
    self.dataset.get_train_sampler().set_epoch_size(num_epoch)
    loss, weights = self.train(self.epoch_counter)
    self.epoch_counter += num_epoch
    elapsed_train_time = time.time() - start_time_train

    start_time_test = time.time()
    accuracy, test_loss, class_precision, class_recall = self.test()
    elapsed_test_time = time.time() - start_time_test

    data = EpochData(self.epoch_counter, num_epoch, elapsed_train_time, elapsed_test_time, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id)
    self.epoch_results.append(data)

    # Copy GPU tensors to CPU so the weights can be serialized over RPC.
    for k, v in weights.items():
        weights[k] = v.cpu()
    # Total time spent inside this function (train + test + CPU copy),
    # reported back so the caller can separate compute from communication.
    end_func_time = time.time() - start_time_train
    return data, weights, end_func_time

def save_model(self, epoch, suffix):
"""
Expand Down
37 changes: 24 additions & 13 deletions fltk/federator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def clients_ready(self):
time.sleep(2)
logging.info('All clients are ready')

def remote_run_epoch(self, epochs):
def remote_run_epoch(self, epochs_subset):
"""
Federated Learning steps:
1. Client selection
Expand All @@ -131,25 +131,33 @@ def remote_run_epoch(self, epochs):

for res in responses:
func_duration = res.future.wait()
res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration))
res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration()))
res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration, epochs_subset[0]))
res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_round_trip', res.duration(), epochs_subset[0]))
communication_duration_2way = res.duration() - func_duration
res.client.timing_data.append(
TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0]))
logging.info('Weights are updated')

# 3. Local training
responses: List[AsyncCall] = []
client_weights = []
for client in selected_clients:
response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs)
response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=len(epochs_subset))
responses.append(response)

self.epoch_counter += epochs
self.epoch_counter += len(epochs_subset)
for res in responses:
res.future.wait()
epoch_data, weights = res.future.wait()
epoch_data, weights, func_duration = res.future.wait()
self.client_data[epoch_data.client_id].append(epoch_data)
logging.info(f'{res.client.name} had a loss of {epoch_data.loss}')
logging.info(f'{res.client.name} had a epoch data of {epoch_data}')
res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'epoch_time_round_trip', res.duration()))
res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_inner', func_duration, epochs_subset[0]))
res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_train', epoch_data.duration_train, epochs_subset[0]))
res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_test', epoch_data.duration_test, epochs_subset[0]))
res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_round_trip', res.duration(), epochs_subset[0]))
communication_duration_2way = res.duration() - func_duration
res.client.timing_data.append(TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0]))

res.client.tb_writer.add_scalar('training loss',
epoch_data.loss_train, # for every 1000 minibatches
Expand Down Expand Up @@ -231,14 +239,17 @@ def run(self):
self.clients_ready()
self.update_client_data_sizes()

epoch_to_run = self.num_epoch
addition = 0


# Get total epoch to run
epoch_to_run = self.config.epochs
epoch_size = self.config.epochs_per_cycle
for epoch in range(epoch_to_run):
logging.info(f'Running epoch {epoch}')
self.remote_run_epoch(epoch_size)
addition += 1

epochs = list(range(1, epoch_to_run + 1))
epoch_chunks = [epochs[x:x + epoch_size] for x in range(0, len(epochs), epoch_size)]
for epoch_subset in epoch_chunks:
logging.info(f'Running epochs {epoch_subset}')
self.remote_run_epoch(epoch_subset)
logging.info('Available clients with data')
logging.info(self.client_data.keys())

Expand Down
5 changes: 3 additions & 2 deletions fltk/util/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
@dataclass
class EpochData:
epoch_id: int
duration_train: int
duration_test: int
num_epochs: int
duration_train: float
duration_test: float
loss_train: float
accuracy: float
loss: float
Expand Down

0 comments on commit e458dfc

Please sign in to comment.