Reintroduced old validation loss calculation, let's see if this fixes something
RandomDefaultUser committed Nov 29, 2024
1 parent dddccd4 commit 8358e03
Showing 1 changed file with 231 additions and 36 deletions.
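
The reintroduced calculation averages a running sum of per-batch losses over the number of batches (the validation_loss_sum / batchid division in the diff below). A minimal standalone sketch of that pattern, using a hypothetical model, loader, and loss_fn rather than MALA's actual classes:

import torch

def mean_validation_loss(model, loader, loss_fn, device="cuda"):
    """Average the per-batch loss over an entire validation loader."""
    model.eval()
    loss_sum = torch.zeros(1, device=device)  # accumulate on the device
    batches = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss_sum += loss_fn(model(x), y)
            batches += 1
    # A single .item() at the end avoids one GPU synchronization per batch.
    return loss_sum.item() / batches

Keeping the accumulator on the GPU and converting to a Python float only once mirrors the design choice in the diff, where .item() is likewise deferred until after the loop on the GPU path.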
267 changes: 231 additions & 36 deletions mala/network/trainer.py
@@ -675,46 +675,241 @@ def _validate_network(self, data_set_fractions, metrics):
)
loader_id += 1
else:
                # If only the LDOS is in the validation metrics (as is the
                # case for, e.g., distributed network trainings), we can
                # use a faster (or at least better-parallelizing) code path.
if (
len(self.parameters.validation_metrics) == 1
and self.parameters.validation_metrics[0] == "ldos"
):
validation_loss_sum = torch.zeros(
1, device=self.parameters._configuration["device"]
)
with torch.no_grad():
if self.parameters._configuration["gpu"]:
report_freq = self.parameters.training_log_interval
torch.cuda.synchronize(
self.parameters._configuration["device"]
)
tsample = time.time()
batchid = 0
for loader in data_loaders:
for x, y in loader:
x = x.to(
self.parameters._configuration[
"device"
],
non_blocking=True,
)
y = y.to(
self.parameters._configuration[
"device"
],
non_blocking=True,
)

if (
self.parameters.use_graphs
and self.validation_graph is None
):
printout(
"Capturing CUDA graph for validation.",
min_verbosity=2,
)
s = torch.cuda.Stream(
self.parameters._configuration[
"device"
]
)
s.wait_stream(
torch.cuda.current_stream(
self.parameters._configuration[
"device"
]
)
)
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
prediction = self.network(
x
)
if (
self.parameters_full.use_ddp
):
loss = self.network.module.calculate_loss(
prediction, y
)
else:
loss = self.network.calculate_loss(
prediction, y
)
torch.cuda.current_stream(
self.parameters._configuration[
"device"
]
).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_validation = (
torch.empty_like(x)
)
self.static_target_validation = (
torch.empty_like(y)
)

# Capture graph
self.validation_graph = (
torch.cuda.CUDAGraph()
)
with torch.cuda.graph(
self.validation_graph
):
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
self.static_prediction_validation = self.network(
self.static_input_validation
)
if (
self.parameters_full.use_ddp
):
self.static_loss_validation = self.network.module.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)
else:
self.static_loss_validation = self.network.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)

if self.validation_graph:
self.static_input_validation.copy_(x)
self.static_target_validation.copy_(y)
self.validation_graph.replay()
validation_loss_sum += (
self.static_loss_validation
)
else:
with torch.cuda.amp.autocast(
enabled=self.parameters.use_mixed_precision
):
prediction = self.network(x)
if self.parameters_full.use_ddp:
loss = self.network.module.calculate_loss(
prediction, y
)
else:
loss = self.network.calculate_loss(
prediction, y
)
validation_loss_sum += loss
if (
batchid != 0
and (batchid + 1) % report_freq == 0
):
torch.cuda.synchronize(
self.parameters._configuration[
"device"
]
)
sample_time = time.time() - tsample
avg_sample_time = (
sample_time / report_freq
)
avg_sample_tput = (
report_freq
* x.shape[0]
/ sample_time
)
printout(
f"batch {batchid + 1}, " # /{total_samples}, "
f"validation avg time: {avg_sample_time} "
f"validation avg throughput: {avg_sample_tput}",
min_verbosity=2,
)
tsample = time.time()
batchid += 1
torch.cuda.synchronize(
self.parameters._configuration["device"]
)
else:
batchid = 0
for loader in data_loaders:
for x, y in loader:
x = x.to(
self.parameters._configuration[
"device"
]
)
y = y.to(
self.parameters._configuration[
"device"
]
)
prediction = self.network(x)
if self.parameters_full.use_ddp:
validation_loss_sum += (
self.network.module.calculate_loss(
prediction, y
).item()
)
else:
validation_loss_sum += (
self.network.calculate_loss(
prediction, y
).item()
)
batchid += 1

validation_loss = validation_loss_sum.item() / batchid
errors[data_set_type]["ldos"] = validation_loss

else:
with torch.no_grad():
for snapshot_number in trange(
offset_snapshots,
number_of_snapshots + offset_snapshots,
desc="Validation",
disable=self.parameters_full.verbosity < 2,
):
                            # Get optimal batch size and number of batches per snapshot
grid_size = (
self.data.parameters.snapshot_directories_list[
snapshot_number
].grid_size
)

                            optimal_batch_size = self._correct_batch_size(
                                grid_size, self.parameters.mini_batch_size
                            )
                            number_of_batches_per_snapshot = int(
                                grid_size / optimal_batch_size
                            )

actual_outputs, predicted_outputs = (
self._forward_entire_snapshot(
snapshot_number,
data_sets[0],
data_set_type[0:2],
number_of_batches_per_snapshot,
optimal_batch_size,
)
)
calculated_errors = self._calculate_errors(
actual_outputs,
predicted_outputs,
metrics,
snapshot_number,
)
for metric in metrics:
errors[data_set_type][metric].append(
calculated_errors[metric]
)
return errors

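For readers unfamiliar with the capture/replay machinery in the GPU branch above: CUDA graphs record a fixed sequence of kernels once and then replay it with minimal launch overhead, which is why the code routes each new batch through static input tensors instead of calling the network directly. Below is a condensed, self-contained sketch of the same pattern, with a hypothetical two-layer network and made-up tensor shapes; it assumes a CUDA-capable device and follows the standard torch.cuda.CUDAGraph recipe rather than reproducing MALA's exact code.

import torch

device = torch.device("cuda")
net = torch.nn.Sequential(
    torch.nn.Linear(91, 400), torch.nn.ReLU(), torch.nn.Linear(400, 11)
).to(device)
loss_fn = torch.nn.MSELoss()

# Static tensors are the graph's fixed entry points; fresh data is
# copied into them before every replay.
static_x = torch.randn(1024, 91, device=device)
static_y = torch.randn(1024, 11, device=device)

with torch.no_grad():
    # Warm up on a side stream before capturing, as CUDA graphs require.
    s = torch.cuda.Stream(device)
    s.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(s):
        for _ in range(3):
            loss_fn(net(static_x), static_y)
    torch.cuda.current_stream(device).wait_stream(s)

    # Capture a single validation step into a replayable graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_loss = loss_fn(net(static_x), static_y)

    # Replay: refill the static inputs, rerun the captured kernels.
    loss_sum = torch.zeros(1, device=device)
    n_batches = 10  # stands in for iterating a real data loader
    for _ in range(n_batches):
        static_x.copy_(torch.randn(1024, 91, device=device))
        static_y.copy_(torch.randn(1024, 11, device=device))
        graph.replay()
        loss_sum += static_loss
    torch.cuda.synchronize(device)
    print("mean validation loss:", loss_sum.item() / n_batches)

The 20-iteration warmup in the diff is more conservative than the three iterations here; a handful is typically enough, and the exact count is a tuning choice.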

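One detail worth noting in the non-LDOS branch: number_of_batches_per_snapshot = int(grid_size / optimal_batch_size) is only exact if _correct_batch_size returns a batch size that divides grid_size evenly. That helper's body is not part of this diff; the divisor search below is a hypothetical stand-in, shown only to illustrate the invariant the caller relies on, not MALA's actual implementation.

def correct_batch_size(grid_size: int, desired_batch_size: int) -> int:
    """Hypothetical: largest batch size <= desired that divides grid_size."""
    batch_size = min(desired_batch_size, grid_size)
    while grid_size % batch_size != 0:
        batch_size -= 1  # terminates at the latest when batch_size == 1
    return batch_size

print(correct_batch_size(200, 48))  # -> 40, i.e. 200 // 40 == 5 full batches

Any scheme with the same divisibility guarantee would do; without it, int() would silently truncate and the last partial batch of each snapshot would be dropped.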