From e1deb8e87071cd98e167c8ad3d2f15925f56e85b Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Tue, 6 Feb 2024 14:57:17 +0100 Subject: [PATCH] chore: add BatchNorm operator in compile_torch tests --- src/concrete/ml/pytest/torch_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index 08e0d834f..f741ca073 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -194,6 +194,7 @@ def __init__(self, input_output, activation_function): self.conv1 = nn.Conv2d(input_output, 3, 3, stride=1, padding=1) self.pool = nn.AvgPool2d(2, 2) self.conv2 = nn.Conv2d(3, 3, 1) + self.bn1 = nn.BatchNorm2d(3) self.fc1 = nn.Linear(3 * 3 * 3, 5) self.fc2 = nn.Linear(5, 3) self.fc3 = nn.Linear(3, 2) @@ -210,6 +211,7 @@ def forward(self, x): """ x = self.pool(self.activation_function(self.conv1(x))) x = self.activation_function(self.conv2(x)) + x = self.bn1(x) x = x.flatten(1) x = self.activation_function(self.fc1(x)) x = self.activation_function(self.fc2(x)) @@ -1555,6 +1557,7 @@ def __init__(self, input_output, activation_function) -> None: super().__init__() self.conv1 = nn.Conv1d(input_output, 2, 2, stride=1, padding=0) + self.bn1 = nn.BatchNorm1d(2) self.act = activation_function() self.fc1 = nn.Linear(input_output, 3) @@ -1569,5 +1572,6 @@ def forward(self, x): """ x = self.act(self.conv1(x)) + x = self.bn1(x) x = self.fc1(x) return x