diff --git a/micromind/networks/yolo.py b/micromind/networks/yolo.py
index 804bedb..470b3dc 100644
--- a/micromind/networks/yolo.py
+++ b/micromind/networks/yolo.py
@@ -345,12 +345,15 @@ class Yolov8Neck(nn.Module):
 
     Arguments
    ---------
-    w : float
-        Width multiple of the Darknet.
-    r : float
-        Ratio multiple of the Darknet.
-    d : float
-        Depth multiple of the Darknet.
+    filters : list, optional
+        List of filter sizes for different layers. Default: [256, 512, 768].
+    up : list, optional
+        List of upsampling factors. Default: [2, 2].
+    heads : list, optional
+        List indicating whether each detection head is active.
+        Default: [True, True, True].
+    d : float, optional
+        Depth multiple of the Darknet. Default: 1.
     """
 
     def __init__(
@@ -373,9 +376,9 @@ def __init__(
             shortcut=False,
         )
         """
-        Only if we decide to use teh 2nd and 3rd detection head we define
-        the needed blocks. Otherwise the not needed blcoks would be initialied
-        (and thus would occupy space) but never used.
+        The blocks needed by the 2nd and 3rd detection heads are defined
+        only if those heads are active. Otherwise the unused blocks would
+        be initialized (and thus occupy space) but never be used.
         """
         if self.heads[1] or self.heads[2]:
             self.n3 = Conv(
@@ -411,12 +414,17 @@ def forward(self, p3, p4, p5):
 
         Arguments
         ---------
-        x : tuple
-            Input to the neck.
+        p3 : torch.Tensor
+            First feature map coming from the backbone.
+        p4 : torch.Tensor
+            Second feature map coming from the backbone.
+        p5 : torch.Tensor
+            Third feature map coming from the backbone.
 
         Returns
         -------
-        Three intermediate representations with different resolutions : list
+        list[torch.Tensor]
+            Three intermediate representations with different resolutions.
         """
         x = self.up1(p5)
         x = torch.cat((x, p4), dim=1)
@@ -425,21 +433,21 @@ def forward(self, p3, p4, p5):
         h1 = torch.cat((h1, p3), dim=1)
         head_1 = self.n2(h1)
         return_heads = []
-        """
-        Only if we decide to use teh 2nd and 3rd detection head we execute the
-        needed blocks. Otherwise the not needed blcoks would be initialied
-        (and thus would occupy space) but never used.
-        """
+
+        # here we check if the 1st head should be returned
         if self.heads[0]:
             return_heads.append(head_1)
 
+        # here we check if the 2nd head should be executed
         if self.heads[1] or self.heads[2]:
             h2 = self.n3(head_1)
             h2 = torch.cat((h2, x), dim=1)
             head_2 = self.n4(h2)
+        # here we check if the 2nd head should be returned
         if self.heads[1]:
             return_heads.append(head_2)
 
+        # here we check if the 3rd head should be executed and returned
         if self.heads[2]:
             h3 = self.n5(head_2)
             h3 = torch.cat((h3, p5), dim=1)
@@ -457,6 +465,9 @@ class DetectionHead(nn.Module):
         Number of classes to predict.
     filters : tuple
         Number of channels of the three inputs of the detection head.
+    heads : list, optional
+        List indicating whether each detection head is active.
+        Default: [True, True, True].
     """
 
     def __init__(self, nc=80, filters=(), heads=[True, True, True]):
@@ -492,8 +503,11 @@ def forward(self, x):
 
         Arguments
         ---------
-        x : list
+        x : list[torch.Tensor]
             Input to the detection head.
+            In the standard YOLOv8 implementation it contains the three
+            outputs of the neck. In the general case it contains as many
+            tensors as the number of heads activated at initialization.
 
         Returns
         -------
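To make the new heads mechanism concrete, here is a minimal usage sketch that is not part of the patch. Only the constructor signatures come from the diff above; the filter values passed to DetectionHead are placeholders, since in practice they are computed by get_parameters in the recipe patched below.

from micromind.networks.yolo import Yolov8Neck, DetectionHead

# Keep only the 1st and 3rd detection heads. Blocks n3/n4 are still
# built, because head_2 is an intermediate input of the 3rd head, but
# head_2 itself is never returned.
heads = [True, False, True]
neck = Yolov8Neck(filters=[256, 512, 768], up=[2, 2], heads=heads)

# filters must match the channels of the tensors the neck actually
# returns; (256, 768) is a placeholder value.
head = DetectionHead(nc=80, filters=(256, 768), heads=heads)

With this configuration the neck's forward returns a list of two tensors (for the 1st and 3rd heads), which matches the input that DetectionHead.forward now documents.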
diff --git a/recipes/object_detection/train.py b/recipes/object_detection/train.py
index e207537..9cd2114 100644
--- a/recipes/object_detection/train.py
+++ b/recipes/object_detection/train.py
@@ -64,6 +64,21 @@ def get_parameters(self, heads=[True, True, True]):
         """
         Gets the parameters with which to initialize the network detection
         part (SPPF block, Yolov8Neck, DetectionHead).
+
+        Arguments
+        ---------
+        heads : list, optional
+            List indicating whether each detection head is active.
+            Default: [True, True, True].
+
+        Returns
+        -------
+        tuple
+            Parameters for initializing the network detection part:
+            - tuple (c1, c2): input channel sizes for the SPPF block.
+            - list neck_filters: filter sizes for Yolov8Neck.
+            - list up: upsampling factors for Yolov8Neck.
+            - list head_filters: filter sizes for DetectionHead.
         """
         in_shape = self.modules["backbone"].input_shape
         x = torch.randn(1, *in_shape)
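For context, the return value documented above is meant to be unpacked straight into the constructors patched in yolo.py. A hypothetical wiring sketch, assuming an SPPF(c1, c2) constructor and code running inside the recipe class (neither detail is shown in this diff):

heads = [True, False, True]
(c1, c2), neck_filters, up, head_filters = self.get_parameters(heads=heads)

# Every shape now matches end to end, whatever subset of heads is active.
sppf = SPPF(c1, c2)  # assumed signature
neck = Yolov8Neck(filters=neck_filters, up=up, heads=heads)
head = DetectionHead(nc=80, filters=head_filters, heads=heads)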