Skip to content

Commit

Permalink
improved docs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobeltrami committed Dec 12, 2023
1 parent 11f66f0 commit f03cbdc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
50 changes: 32 additions & 18 deletions micromind/networks/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,15 @@ class Yolov8Neck(nn.Module):
Arguments
---------
w : float
Width multiple of the Darknet.
r : float
Ratio multiple of the Darknet.
d : float
Depth multiple of the Darknet.
filters : list, optional
List of filter sizes for different layers. Default: [256, 512, 768].
up : list, optional
List of upsampling factors. Default: [2, 2].
heads : list, optional
List indicating whether each detection head is active.
Default: [True, True, True].
d : float, optional
Depth multiple of the Darknet. Default: 1.
"""

def __init__(
Expand All @@ -373,9 +376,9 @@ def __init__(
shortcut=False,
)
"""
Only if we decide to use teh 2nd and 3rd detection head we define
the needed blocks. Otherwise the not needed blcoks would be initialied
(and thus would occupy space) but never used.
Only if we decide to use the 2nd and 3rd detection head we define
the needed blocks. Otherwise the not needed blocks would be initialized
(and thus would occupy space) but will never be used.
"""
if self.heads[1] or self.heads[2]:
self.n3 = Conv(
Expand Down Expand Up @@ -411,12 +414,17 @@ def forward(self, p3, p4, p5):
Arguments
---------
x : tuple
Input to the neck.
p3 : torch.Tensor
First feature map coming from the backbone.
p4 : torch.Tensor
Second feature map coming from the backbone.
p5 : torch.Tensor
Third feature map coming from the backbone.
Returns
-------
Three intermediate representations with different resolutions : list
list[torch.Tensor]
Three intermediate representations with different resolutions.
"""
x = self.up1(p5)
x = torch.cat((x, p4), dim=1)
Expand All @@ -425,21 +433,21 @@ def forward(self, p3, p4, p5):
h1 = torch.cat((h1, p3), dim=1)
head_1 = self.n2(h1)
return_heads = []
"""
Only if we decide to use teh 2nd and 3rd detection head we execute the
needed blocks. Otherwise the not needed blcoks would be initialied
(and thus would occupy space) but never used.
"""

# here we check if the 1st head should be returned
if self.heads[0]:
return_heads.append(head_1)

# here we check if the 2nd head should be executed
if self.heads[1] or self.heads[2]:
h2 = self.n3(head_1)
h2 = torch.cat((h2, x), dim=1)
head_2 = self.n4(h2)
# here we check if the 2nd head should be returned
if self.heads[1]:
return_heads.append(head_2)

# here we check if the 3rd head should beexecuted and returned
if self.heads[2]:
h3 = self.n5(head_2)
h3 = torch.cat((h3, p5), dim=1)
Expand All @@ -457,6 +465,9 @@ class DetectionHead(nn.Module):
Number of classes to predict.
filters : tuple
Number of channels of the three inputs of the detection head.
heads : list, optional
List indicating whether each detection head is active.
Default: [True, True, True].
"""

def __init__(self, nc=80, filters=(), heads=[True, True, True]):
Expand Down Expand Up @@ -492,8 +503,11 @@ def forward(self, x):
Arguments
---------
x : list
x : list[torch.Tensor]
Input to the detection head.
In the YOLOv8 standard implementation it contains the three outputs of
the neck. In a more general case it contains as many tensors as the number
of active heads in the initialization.
Returns
-------
Expand Down
15 changes: 15 additions & 0 deletions recipes/object_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def get_parameters(self, heads=[True, True, True]):
"""
Gets the parameters with which to initialize the network detection part
(SPPF block, Yolov8Neck, DetectionHead).
Arguments
---------
heads : list, optional
List indicating whether each detection head is active.
Default: [True, True, True].
Returns
-------
tuple
Tuple containing the parameters for initializing the network detection part:
- Tuple (c1, c2): Tuple of input channel sizes for the SPPF block.
- List neck_filters: List of filter sizes for Yolov8Neck.
- List up: List of upsampling factors for Yolov8Neck.
- List head_filters: List of filter sizes for DetectionHead.
"""
in_shape = self.modules["backbone"].input_shape
x = torch.randn(1, *in_shape)
Expand Down

0 comments on commit f03cbdc

Please sign in to comment.