Skip to content

Commit

Permalink
fix intermittently failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Oct 15, 2024
1 parent 462b929 commit a7ab39e
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class basic_model(torch.nn.Module):
def __init__(self):
super(basic_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn = BatchNorm2d(3)
self.relu = ReLU()

Expand All @@ -50,10 +50,10 @@ def forward(self, inp):
class advanced_model(torch.nn.Module):
def __init__(self):
super(advanced_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.dense = Linear(8, 7)
Expand All @@ -72,10 +72,10 @@ def forward(self, inp):
class multiple_outputs_model(torch.nn.Module):
def __init__(self):
super(multiple_outputs_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu1 = ReLU()
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn2 = BatchNorm2d(3)
self.relu2 = ReLU()
self.hswish = Hardswish()
Expand All @@ -96,8 +96,8 @@ def forward(self, inp):
class multiple_inputs_model(torch.nn.Module):
def __init__(self):
super(multiple_inputs_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.conv2 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)

def forward(self, inp1, inp2):
x1 = self.conv1(inp1)
Expand All @@ -108,7 +108,7 @@ def forward(self, inp1, inp2):
class reused_model(torch.nn.Module):
def __init__(self):
super(reused_model, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv1 = Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
self.bn1 = BatchNorm2d(3)
self.relu = ReLU()

Expand Down Expand Up @@ -222,6 +222,7 @@ class WeightsHessianTraceBasicModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=basic_model)
self.val_batch_size = 1
self.n_iters = 10

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
Expand All @@ -243,6 +244,7 @@ class WeightsHessianTraceAdvanceModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=advanced_model, n_iters=3)
self.val_batch_size = 2
self.n_iters = 10

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
Expand All @@ -267,6 +269,7 @@ class WeightsHessianTraceMultipleOutputsModelTest(BaseHessianTraceBasicModelTest
def __init__(self, unit_test):
super().__init__(unit_test, model=multiple_outputs_model, n_iters=3)
self.val_batch_size = 1
self.n_iters = 10

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
Expand All @@ -291,6 +294,7 @@ class WeightsHessianTraceReuseModelTest(BaseHessianTraceBasicModelTest):
def __init__(self, unit_test):
super().__init__(unit_test, model=reused_model, n_iters=3)
self.val_batch_size = 1
self.n_iters = 10

def run_test(self, seed=0):
graph, pytorch_impl = self._setup()
Expand Down

0 comments on commit a7ab39e

Please sign in to comment.