
Commit

Made phinet and xinet more general for solo, fixed #85
Sebo-the-tramp committed Feb 28, 2024
1 parent b4b45c0 commit 062ed5a
Showing 2 changed files with 27 additions and 1 deletion.
17 changes: 16 additions & 1 deletion micromind/networks/phinet.py
@@ -449,7 +449,11 @@ def __init__(
                 kernel_size=k_size,
                 stride=stride,
                 bias=False,
-                padding=k_size // 2 if stride == 1 else (padding[1], padding[3]),
+                padding=k_size // 2
+                if isinstance(k_size, int) and stride == 1
+                else [x // 2 for x in k_size]
+                if stride == 1
+                else (padding[1], padding[3]),
             )

         bn_dw1 = nn.BatchNorm2d(
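
The padding argument now also handles tuple kernel sizes instead of only square ones. Below is a standalone sketch of how the new conditional resolves; same_padding is a hypothetical wrapper around the expression above, and its padding default stands in for the asymmetric 4-tuple computed earlier in the block:

    def same_padding(k_size, stride, padding=(0, 1, 0, 1)):
        # Mirrors the conditional added in this hunk.
        return (
            k_size // 2
            if isinstance(k_size, int) and stride == 1
            else [x // 2 for x in k_size]
            if stride == 1
            else (padding[1], padding[3])
        )

    print(same_padding(3, 1))       # 1      -> square "same" padding
    print(same_padding((3, 5), 1))  # [1, 2] -> per-dimension "same" padding
    print(same_padding(3, 2))       # (1, 1) -> falls back to the precomputed asymmetric padding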
@@ -629,6 +633,7 @@ def __init__(
         squeeze_excite: bool = True,  # S1
         divisor: int = 1,
         return_layers=None,
+        flattened_embeddings=False,
     ) -> None:
         super(PhiNet, self).__init__()
         self.alpha = alpha
@@ -637,6 +642,8 @@ def __init__(
         self.num_layers = num_layers
         self.num_classes = num_classes
         self.return_layers = return_layers
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0

         if compatibility:  # disables operations hard for some platforms
             h_swish = False
@@ -802,6 +809,14 @@ def __init__(
             )
             block_id += 1

+        if self.flattened_embeddings:
+
+            flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+
+            self._layers.append(flatten)
+
+        self.num_features = _make_divisible(int(block_filters * alpha), divisor=divisor)
+
         if include_top:
             # Includes classification head if required
             self.classifier = nn.Sequential(
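With flattened_embeddings=True the backbone appends an AdaptiveAvgPool2d((1, 1)) + Flatten head, so it emits a flat (N, C) embedding instead of a spatial feature map. A minimal usage sketch; the argument values are hypothetical, and input_shape/num_layers are assumed from PhiNet's constructor rather than from this diff:

    import torch
    from micromind.networks.phinet import PhiNet

    backbone = PhiNet(
        input_shape=(3, 224, 224),   # assumed constructor argument
        alpha=0.67,
        num_layers=7,
        include_top=False,
        flattened_embeddings=True,   # flag added in this commit
    )

    x = torch.randn(2, 3, 224, 224)
    feats = backbone(x)
    # The pooling + flatten head collapses the spatial dimensions, so the
    # embedding width should match backbone.num_features.
    print(feats.shape)
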
11 changes: 11 additions & 0 deletions micromind/networks/xinet.py
@@ -250,6 +250,7 @@ def __init__(
         include_top=False,
         base_filters: int = 16,
         return_layers: Optional[List] = None,
+        flattened_embeddings=False,
     ):
         super().__init__()

@@ -258,6 +259,8 @@ def __init__(
         self.include_top = include_top
         self.return_layers = return_layers
         count_downsample = 0
+        self.flattened_embeddings = flattened_embeddings
+        self.features_dim = 0

         self.conv1 = nn.Sequential(
             nn.Conv2d(
@@ -340,6 +343,9 @@ def __init__(
             for i in self.return_layers:
                 print(f"Layer {i} - {self._layers[i].__class__}")

+        if self.flattened_embeddings:
+            self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
+
         self.input_shape = input_shape
         if self.include_top:
             self.classifier = nn.Sequential(
@@ -348,6 +354,8 @@ def __init__(
                 nn.Linear(int(num_filters[-1] * alpha), num_classes),
             )

+        self.num_features = int(num_filters[-1] * alpha)
+
     def forward(self, x):
         """Computes the forward step of the XiNet.
         Arguments
@@ -374,6 +382,9 @@ def forward(self, x):
                 if layer_id in self.return_layers:
                     ret.append(x)

+        if self.flattened_embeddings:
+            x = self.flatten(x)
+
         if self.include_top:
             x = self.classifier(x)

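XiNet gets the same flag, and num_features now exposes the embedding width, which makes it easy to size a downstream head (presumably what the "solo" framework named in the commit title looks for in a backbone). A sketch under the same assumptions; constructor arguments other than include_top and flattened_embeddings are taken from the surrounding code or assumed:

    import torch
    import torch.nn as nn
    from micromind.networks.xinet import XiNet

    backbone = XiNet(
        input_shape=(3, 224, 224),
        alpha=1.0,                   # assumed value
        include_top=False,
        flattened_embeddings=True,   # flag added in this commit
    )

    # Size a projection head from the exposed embedding width.
    projector = nn.Linear(backbone.num_features, 128)

    emb = backbone(torch.randn(2, 3, 224, 224))  # flat (2, num_features) embedding
    z = projector(emb)
    print(z.shape)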
