From 8358e0344a02e4e6b309b4daf4aea5de1021d254 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 29 Nov 2024 11:29:06 +0100 Subject: [PATCH] Reintroduced old validation loss calculation, let's see if this fixes something --- mala/network/trainer.py | 267 ++++++++++++++++++++++++++++++++++------ 1 file changed, 231 insertions(+), 36 deletions(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index b5eb0892..76cf5b55 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -675,46 +675,241 @@ def _validate_network(self, data_set_fractions, metrics): ) loader_id += 1 else: - with torch.no_grad(): - for snapshot_number in trange( - offset_snapshots, - number_of_snapshots + offset_snapshots, - desc="Validation", - disable=self.parameters_full.verbosity < 2, - ): - # Get optimal batch size and number of batches per snapshotss - grid_size = ( - self.data.parameters.snapshot_directories_list[ - snapshot_number - ].grid_size - ) + # If only the LDOS is in the validation metrics (as is the + # case for, e.g., distributed network trainings), we can + # use a faster (or at least better parallelizing) code + if ( + len(self.parameters.validation_metrics) == 1 + and self.parameters.validation_metrics[0] == "ldos" + ): + validation_loss_sum = torch.zeros( + 1, device=self.parameters._configuration["device"] + ) + with torch.no_grad(): + if self.parameters._configuration["gpu"]: + report_freq = self.parameters.training_log_interval + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + tsample = time.time() + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to( + self.parameters._configuration[ + "device" + ], + non_blocking=True, + ) + y = y.to( + self.parameters._configuration[ + "device" + ], + non_blocking=True, + ) + + if ( + self.parameters.use_graphs + and self.validation_graph is None + ): + printout( + "Capturing CUDA graph for validation.", + min_verbosity=2, + ) + s = torch.cuda.Stream( + self.parameters._configuration[ + "device" + ] + ) + s.wait_stream( + torch.cuda.current_stream( + self.parameters._configuration[ + "device" + ] + ) + ) + # Warmup for graphs + with torch.cuda.stream(s): + for _ in range(20): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network( + x + ) + if ( + self.parameters_full.use_ddp + ): + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + torch.cuda.current_stream( + self.parameters._configuration[ + "device" + ] + ).wait_stream(s) + + # Create static entry point tensors to graph + self.static_input_validation = ( + torch.empty_like(x) + ) + self.static_target_validation = ( + torch.empty_like(y) + ) + + # Capture graph + self.validation_graph = ( + torch.cuda.CUDAGraph() + ) + with torch.cuda.graph( + self.validation_graph + ): + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + self.static_prediction_validation = self.network( + self.static_input_validation + ) + if ( + self.parameters_full.use_ddp + ): + self.static_loss_validation = self.network.module.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + else: + self.static_loss_validation = self.network.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + + if self.validation_graph: + self.static_input_validation.copy_(x) + self.static_target_validation.copy_(y) + self.validation_graph.replay() + validation_loss_sum += ( + self.static_loss_validation + ) + else: + with torch.cuda.amp.autocast( + enabled=self.parameters.use_mixed_precision + ): + prediction = self.network(x) + if self.parameters_full.use_ddp: + loss = self.network.module.calculate_loss( + prediction, y + ) + else: + loss = self.network.calculate_loss( + prediction, y + ) + validation_loss_sum += loss + if ( + batchid != 0 + and (batchid + 1) % report_freq == 0 + ): + torch.cuda.synchronize( + self.parameters._configuration[ + "device" + ] + ) + sample_time = time.time() - tsample + avg_sample_time = ( + sample_time / report_freq + ) + avg_sample_tput = ( + report_freq + * x.shape[0] + / sample_time + ) + printout( + f"batch {batchid + 1}, " # /{total_samples}, " + f"validation avg time: {avg_sample_time} " + f"validation avg throughput: {avg_sample_tput}", + min_verbosity=2, + ) + tsample = time.time() + batchid += 1 + torch.cuda.synchronize( + self.parameters._configuration["device"] + ) + else: + batchid = 0 + for loader in data_loaders: + for x, y in loader: + x = x.to( + self.parameters._configuration[ + "device" + ] + ) + y = y.to( + self.parameters._configuration[ + "device" + ] + ) + prediction = self.network(x) + if self.parameters_full.use_ddp: + validation_loss_sum += ( + self.network.module.calculate_loss( + prediction, y + ).item() + ) + else: + validation_loss_sum += ( + self.network.calculate_loss( + prediction, y + ).item() + ) + batchid += 1 + + validation_loss = validation_loss_sum.item() / batchid + errors[data_set_type]["ldos"] = validation_loss - optimal_batch_size = self._correct_batch_size( - grid_size, self.parameters.mini_batch_size - ) - number_of_batches_per_snapshot = int( - grid_size / optimal_batch_size - ) + else: + with torch.no_grad(): + for snapshot_number in trange( + offset_snapshots, + number_of_snapshots + offset_snapshots, + desc="Validation", + disable=self.parameters_full.verbosity < 2, + ): + # Get optimal batch size and number of batches per snapshotss + grid_size = ( + self.data.parameters.snapshot_directories_list[ + snapshot_number + ].grid_size + ) - actual_outputs, predicted_outputs = ( - self._forward_entire_snapshot( - snapshot_number, - data_sets[0], - data_set_type[0:2], - number_of_batches_per_snapshot, - optimal_batch_size, + optimal_batch_size = self._correct_batch_size( + grid_size, self.parameters.mini_batch_size ) - ) - calculated_errors = self._calculate_errors( - actual_outputs, - predicted_outputs, - metrics, - snapshot_number, - ) - for metric in metrics: - errors[data_set_type][metric].append( - calculated_errors[metric] + number_of_batches_per_snapshot = int( + grid_size / optimal_batch_size + ) + + actual_outputs, predicted_outputs = ( + self._forward_entire_snapshot( + snapshot_number, + data_sets[0], + data_set_type[0:2], + number_of_batches_per_snapshot, + optimal_batch_size, + ) ) + calculated_errors = self._calculate_errors( + actual_outputs, + predicted_outputs, + metrics, + snapshot_number, + ) + for metric in metrics: + errors[data_set_type][metric].append( + calculated_errors[metric] + ) return errors def __prepare_to_train(self, optimizer_dict):