Skip to content

Commit

Permalink
Enable hpu graph tests. (#141)
Browse files (browse the repository at this point in the history)
* Fix hpu graph tests

* Update test_hpu_graphs.py

---------

Co-authored-by: Jerome Anand <[email protected]>
  • Loading branch information
ankitgola005 and jerome-habana authored Feb 2, 2024
1 parent c6c3a31 commit 05452d0
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions tests/test_pytorch/test_hpu_graphs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -109,8 +109,8 @@ def train_with_capture_and_replay(self, batch, batch_idx):
# Then we capture
optimizer.zero_grad(set_to_none=True)
with htcore.hpu.graph(self.g):
static_y_pred = self(self.static_input)
self.static_loss = f.cross_entropy(static_y_pred, self.static_target)
self.static_y_pred = self(self.static_input)
self.static_loss = f.cross_entropy(self.static_y_pred, self.static_target)
self.static_loss.backward()
optimizer.step()
return self.static_loss
Expand Down Expand Up @@ -149,25 +149,14 @@ def validation_step_automatic(self, batch, batch_idx):

def test_step(self, batch, batch_idx):
"""Test step."""
x, y = batch
if self.graph_mode == HPUGraphMode.INFERENCE_CAPTURE_AND_REPLAY:
if batch_idx == 0:
with htcore.hpu.graph(self.g):
static_y_pred = self.forward(self.static_input)
self.static_loss = f.cross_entropy(static_y_pred, self.static_target)
else:
self.static_input.copy_(x)
self.static_target.copy_(y)
self.g.replay()
self.log("test_acc", self.accuracy(None, y, self.static_y_pred))
self.validation_step_capture_replay(batch, batch_idx)
else:
self.log("test_acc", self.accuracy(self.forward(x), y))
self.validation_step_automatic(batch, batch_idx)

@staticmethod
def accuracy(logits, y, pred=None):
def accuracy(logits, y):
"""Calculate accuracy."""
if pred is not None:
return torch.sum(torch.eq(pred, y).to(torch.float32)) / len(y)
return torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)

def configure_optimizers(self):
Expand Down Expand Up @@ -242,12 +231,12 @@ def test_hpu_graphs(tmpdir, graph_mode, mode):
@pytest.mark.parametrize(
"train_modes",
[
# [(HPUGraphMode.TRAIN_NONE), (HPUGraphMode.TRAIN_CAPTURE_AND_REPLAY)], #Fix me: - accuracy issue
[(HPUGraphMode.TRAIN_NONE), (HPUGraphMode.TRAIN_CAPTURE_AND_REPLAY)],
[(HPUGraphMode.TRAIN_NONE), (HPUGraphMode.TRAIN_MAKE_GRAPHED_CALLABLES)],
[(HPUGraphMode.TRAIN_NONE), (HPUGraphMode.TRAIN_MODULECACHER)],
],
ids=[
# "baseline_vs_capture_and_replay",
"baseline_vs_capture_and_replay",
"baseline_vs_make_graphed_callables",
"baseline_vs_modulecacher",
],
Expand All @@ -260,10 +249,13 @@ def test_hpu_graph_accuracy_train(tmpdir, train_modes):
data_module = MNISTDataModule(batch_size=200)
loss_metrics.append(train_model(tmpdir, 1, model=hpu_graph_model, data_module=data_module, profiler=None))
assert torch.allclose(
loss_metrics[0]["train_loss"], loss_metrics[1]["train_loss"], rtol=0.05
loss_metrics[0]["train_loss"],
loss_metrics[1]["train_loss"],
rtol=0.05,
atol=0.05,
), loss_metrics # Compare train loss
assert torch.allclose(
loss_metrics[0]["val_acc"], loss_metrics[1]["val_acc"], rtol=0.05
loss_metrics[0]["val_acc"], loss_metrics[1]["val_acc"], rtol=0.05, atol=0.05
), loss_metrics # Compare val acc


Expand All @@ -288,8 +280,8 @@ def test_hpu_graph_accuracy_inference(tmpdir, train_modes):
train_model(tmpdir, 1, model=hpu_graph_model, data_module=data_module, mode="test", profiler=None)
)
assert torch.allclose(
loss_metrics[0]["test_acc"], loss_metrics[1]["test_acc"], rtol=0.05
), loss_metrics # Compare test acc
loss_metrics[0]["val_acc"], loss_metrics[1]["val_acc"], rtol=0.05, atol=0.05
), loss_metrics # Compare val acc


def test_automatic_optimization_graph_capture(tmpdir):
Expand Down

0 comments on commit 05452d0

Please sign in to comment.