-
Notifications
You must be signed in to change notification settings - Fork 0
/
template_fpn.py
42 lines (34 loc) · 1.32 KB
/
template_fpn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from ..registry import NECKS
from .fpn import FPN
@NECKS.register_module
class TemplateFPN(FPN):
    """FPN neck that fuses two feature pyramids before the usual FPN pass.

    ``forward`` takes two multi-level feature lists, one per image
    (``template_feat`` and ``fact_feat``). It merges them level by level with
    ``merge_feats`` and feeds the merged list to :class:`FPN`.

    Args:
        in_channels (Sequence[int]): Channels of each input level of ONE
            branch (i.e. of ``template_feat`` or of ``fact_feat``). With
            ``merge_method='concat'`` every entry is doubled before being
            handed to :class:`FPN`, because a concatenated level carries the
            channels of both branches.
        out_channels: Forwarded unchanged to :class:`FPN`.
        num_outs: Forwarded unchanged to :class:`FPN`.
        merge_method (str): How to fuse the two branches. Only ``'concat'``
            is implemented. Any other value is accepted here, but then
            ``merge_feats`` (and therefore ``forward``) raises
            ``NotImplementedError``.
        **kwargs: Extra keyword arguments forwarded to :class:`FPN`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 merge_method='concat',
                 **kwargs):
        # Stored before super().__init__ because it determines the effective
        # in_channels handed to FPN just below.
        self.merge_method = merge_method
        if self.merge_method == 'concat':
            # Concatenation stacks both branches, doubling each level's width.
            in_channels = [2 * t for t in in_channels]
        super(TemplateFPN, self).__init__(in_channels, out_channels, num_outs,
                                          **kwargs)

    def merge_feats(self, template_feat, fact_feat):
        """Fuse two multi-level feature lists level by level.

        Args:
            template_feat (Sequence[torch.Tensor]): Features of the first
                image, one tensor per pyramid level.
            fact_feat (Sequence[torch.Tensor]): Features of the second image,
                aligned level-by-level with ``template_feat``.

        Returns:
            list[torch.Tensor]: One merged tensor per level. For ``'concat'``
            each entry is ``torch.cat((template_level, fact_level), dim=1)``.

        Raises:
            ValueError: If the two inputs have a different number of levels
                (``zip`` would otherwise silently drop the extra levels).
            NotImplementedError: If ``self.merge_method`` is not ``'concat'``.
        """
        if self.merge_method == 'concat':
            if len(template_feat) != len(fact_feat):
                raise ValueError(
                    'template_feat and fact_feat must have the same number '
                    'of levels, got {} and {}'.format(
                        len(template_feat), len(fact_feat)))
            # Simple concatenation along dim 1, the axis whose size __init__
            # doubles. Other fusion schemes may give better performance.
            return [
                torch.cat((x1, x2), dim=1)
                for x1, x2 in zip(template_feat, fact_feat)
            ]
        raise NotImplementedError(
            "Unsupported merge_method {!r}; only 'concat' is "
            'implemented'.format(self.merge_method))

    def forward(self, template_feat, fact_feat):
        """Merge the two feature pyramids and run the FPN on the result.

        Args:
            template_feat (Sequence[torch.Tensor]): First image's features,
                one per FPN input level.
            fact_feat (Sequence[torch.Tensor]): Second image's features,
                one per FPN input level.

        Returns:
            Whatever ``FPN.forward`` returns for the merged inputs.
        """
        # Both branches must provide exactly one feature map per FPN input
        # level. self.in_channels is presumably the (possibly doubled) list
        # stored by FPN.__init__.
        assert len(template_feat) == len(fact_feat) == len(self.in_channels), (
            'expected {} levels per branch, got {} (template) and {} (fact)'
            .format(len(self.in_channels), len(template_feat),
                    len(fact_feat)))
        # Merge the features from the two images, then forward as a plain FPN.
        inputs = self.merge_feats(template_feat, fact_feat)
        return super(TemplateFPN, self).forward(inputs)