Skip to content

Commit

Permalink
chore: add BatchNorm operator in compile_torch tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Feb 7, 2024
1 parent 09ad7a6 commit 503e049
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self, input_output, activation_function):
self.conv1 = nn.Conv2d(input_output, 3, 3, stride=1, padding=1)
self.pool = nn.AvgPool2d(2, 2)
self.conv2 = nn.Conv2d(3, 3, 1)
self.bn1 = nn.BatchNorm2d(3)
self.fc1 = nn.Linear(3 * 3 * 3, 5)
self.fc2 = nn.Linear(5, 3)
self.fc3 = nn.Linear(3, 2)
Expand All @@ -210,6 +211,7 @@ def forward(self, x):
"""
x = self.pool(self.activation_function(self.conv1(x)))
x = self.activation_function(self.conv2(x))
x = self.bn1(x)
x = x.flatten(1)
x = self.activation_function(self.fc1(x))
x = self.activation_function(self.fc2(x))
Expand Down Expand Up @@ -1555,6 +1557,7 @@ def __init__(self, input_output, activation_function) -> None:
super().__init__()

self.conv1 = nn.Conv1d(input_output, 2, 2, stride=1, padding=0)
self.bn1 = nn.BatchNorm1d(2)
self.act = activation_function()
self.fc1 = nn.Linear(input_output, 3)

Expand All @@ -1569,5 +1572,6 @@ def forward(self, x):
"""
x = self.act(self.conv1(x))
x = self.bn1(x)
x = self.fc1(x)
return x

0 comments on commit 503e049

Please sign in to comment.