Resolve issues in starting and running experiments (federated/batch)
JMGaljaard committed Sep 18, 2022
1 parent a8192b0 commit 9ebc444
Showing 3 changed files with 2 additions and 16 deletions.
fltk/core/client.py (2 changes: 1 addition & 1 deletion)
@@ -77,6 +77,7 @@ def train(self, num_epochs: int, round_id: int):
start_time = time.time()
running_loss = 0.0
final_running_loss = 0.0
self.logger.info(f"[RD-{round_id}] kicking of local training for {num_epochs} local epochs")
for local_epoch in range(num_epochs):
effective_epoch = round_id * num_epochs + local_epoch
progress = f'[RD-{round_id}][LE-{local_epoch}][EE-{effective_epoch}]'
@@ -87,7 +88,6 @@ def train(self, num_epochs: int, round_id: int):

training_cardinality = len(self.dataset.get_train_loader())
self.logger.info(f'{progress}{self.id}: Number of training samples: {training_cardinality}')

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

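For reference, the added log line announces the start of each round's local training, and the effective epoch folds the round id and local epoch into a single increasing counter. A minimal sketch of that bookkeeping, with the function name, arguments, and loader as illustrative assumptions rather than the project's actual Client class:

    # Sketch: per-round local training bookkeeping (assumed names, not fltk's API).
    def train_round(model, train_loader, round_id: int, num_epochs: int) -> None:
        for local_epoch in range(num_epochs):
            # Unique across rounds: round 2 with 5 local epochs gives 10..14.
            effective_epoch = round_id * num_epochs + local_epoch
            progress = f'[RD-{round_id}][LE-{local_epoch}][EE-{effective_epoch}]'
            print(f'{progress} starting pass over {len(train_loader)} batches')
            for inputs, labels in train_loader:
                pass  # forward/backward pass and optimizer step would go here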
fltk/core/distributed/client.py (11 changes: 0 additions & 11 deletions)
@@ -108,17 +108,6 @@ def _init_device(self, default_device: torch.device = torch.device('cpu')): # p
torch.cuda.is_available = lambda: False
return default_device

def load_default_model(self):
"""
@deprecated Load a model from default model file. This function could be used to ensure consistent default model
behavior. When using PyTorch's DistributedDataParallel, however, the first step will always synchronize the
model.
"""

model_file = Path(f'{self.model.__name__}.model')
default_model_path = Path(self.config.get_default_model_folder_path()).joinpath(model_file)
load_model_from_file(self.model, default_model_path)

def train(self, epoch, log_interval: int = 50):
"""
Function to start training, regardless of DistributedDataParallel (DPP) or local training. DDP will account for
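The deleted load_default_model helper was already marked @deprecated: once the model is wrapped in PyTorch's DistributedDataParallel, rank 0's parameters and buffers are broadcast to the other ranks, so loading a shared default model file on every worker is redundant. A minimal sketch of that behavior, where the process-group setup is an assumption that depends on the deployment:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_model(model: torch.nn.Module) -> torch.nn.Module:
        # Assumes dist.init_process_group(...) has already been called for this worker.
        # Construction broadcasts rank 0's parameters and buffers to every rank, so
        # all workers start from identical weights without reading a default model file.
        return DDP(model)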
fltk/core/distributed/orchestrator.py (5 changes: 1 addition & 4 deletions)
@@ -330,10 +330,7 @@ def run(self, clear: bool = False, experiment_replication: int = 1) -> None:
self.deployed_tasks.add(curr_task)
if not self._config.cluster_config.orchestrator.parallel_execution:
self.wait_for_jobs_to_complete()

self.stop()

self._logger.debug("Still alive...")
time.sleep(self.SLEEP_TIME)
self.wait_for_jobs_to_complete()
logging.info('Experiment completed, currently does not support waiting.')
logging.info('Experiment completed.')
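With this change the run finishes with a blocking wait_for_jobs_to_complete() and reports 'Experiment completed.' only once every deployed task has finished. A minimal sketch of such a polling wait, with the task set and completion check as assumed names rather than the orchestrator's actual API:

    import logging
    import time

    SLEEP_TIME = 5  # seconds between polls; assumed value

    def wait_for_jobs_to_complete(deployed_tasks: set, is_finished) -> None:
        # Poll the deployed tasks until every one reports completion.
        while deployed_tasks:
            deployed_tasks -= {task for task in deployed_tasks if is_finished(task)}
            if deployed_tasks:
                time.sleep(SLEEP_TIME)
        logging.info('Experiment completed.')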
