-
Notifications
You must be signed in to change notification settings - Fork 13
/
wiw.py
executable file
·97 lines (80 loc) · 4.54 KB
/
wiw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Copyright (c) 2019-present NAVER Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch.nn as nn
# for debug
# from modules.transformation import TPS_SpatialTransformerNetwork
# from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
# from modules.sequence_modeling import BidirectionalLSTM
# from modules.prediction import Attention
from .modules.transformation import TPS_SpatialTransformerNetwork
from .modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
from .modules.sequence_modeling import BidirectionalLSTM
from .modules.prediction import Attention
class WIW(nn.Module):
    """Configurable four-stage model:
    Transformation -> FeatureExtraction -> SequenceModeling -> Prediction.

    Each stage is selected by name through keyword arguments:

    - ``Transformation``: ``'TPS'`` builds a ``TPS_SpatialTransformerNetwork``;
      any other value disables the stage (a message is printed).
    - ``FeatureExtraction``: ``'VGG'``, ``'RCNN'`` or ``'ResNet'`` (required).
    - ``SequenceModeling``: ``'BiLSTM'`` stacks two ``BidirectionalLSTM``
      layers; any other value disables the stage (a message is printed).
    - ``Prediction``: ``'CTC'`` (a single ``nn.Linear``) or ``'Attn'``
      (an ``Attention`` module) (required).

    Other keys read from ``kwargs``:

    - ``input_channel`` and ``output_channel``: used by feature extraction;
      ``input_channel`` is also used by TPS.
    - ``num_class``: used by prediction.
    - ``hidden_size``: used by BiLSTM and Attn.
    - ``num_fiducial``, ``imgH``, ``imgW``: used by TPS only.
    - ``batch_max_length``: used by Attn and read in ``forward``.

    Raises:
        KeyError: if a key required by the chosen configuration is missing.
        ValueError: if ``FeatureExtraction`` or ``Prediction`` names an
            unsupported module.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs
        self.stages = {'Trans': kwargs['Transformation'], 'Feat': kwargs['FeatureExtraction'],
                       'Seq': kwargs['SequenceModeling'], 'Pred': kwargs['Prediction']}

        # Transformation (optional): only 'TPS' creates a module.
        if kwargs['Transformation'] == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=kwargs['num_fiducial'],
                I_size=(kwargs['imgH'], kwargs['imgW']),
                I_r_size=(kwargs['imgH'], kwargs['imgW']),
                I_channel_num=kwargs['input_channel'])
        else:
            print('No Transformation module specified')

        # FeatureExtraction (required). Kept as an if/elif chain so the
        # extractor classes are only looked up for the selected backbone.
        if kwargs['FeatureExtraction'] == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(kwargs['input_channel'], kwargs['output_channel'])
        elif kwargs['FeatureExtraction'] == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(kwargs['input_channel'], kwargs['output_channel'])
        elif kwargs['FeatureExtraction'] == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(kwargs['input_channel'], kwargs['output_channel'])
        else:
            raise ValueError('No FeatureExtraction module specified')
        # Per-step feature size after the height axis is pooled away in forward().
        # Assumes the extractor emits `output_channel` channels -- TODO confirm.
        self.FeatureExtraction_output = kwargs['output_channel']
        # Applied to a [b, w, c, h] tensor: keeps c, reduces h to 1.
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # SequenceModeling (optional): two stacked BiLSTMs.
        if kwargs['SequenceModeling'] == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, kwargs['hidden_size'], kwargs['hidden_size']),
                BidirectionalLSTM(kwargs['hidden_size'], kwargs['hidden_size'], kwargs['hidden_size']))
            self.SequenceModeling_output = kwargs['hidden_size']
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        # Prediction (required).
        if kwargs['Prediction'] == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, kwargs['num_class'])
        elif kwargs['Prediction'] == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, kwargs['hidden_size'], kwargs['num_class'])
        else:
            raise ValueError('Prediction is neither CTC or Attn')

    def forward(self, input, text, is_train=True):
        """Run the configured stages on a batch.

        Args:
            input: batch passed to the Transformation stage (when enabled) and
                then to FeatureExtraction. The extractor output is expected to
                be a 4-D ``[b, c, h, w]`` feature map.
            text: forwarded to the Attention decoder; ignored for CTC.
            is_train: forwarded to the Attention decoder; ignored for CTC.

        Returns:
            The Prediction module's output. For CTC (a Linear over the last
            dim) this is ``[b, w, num_class]``. For Attn it is whatever
            ``Attention`` returns.
        """
        # Transformation stage. Use the same predicate as __init__: the TPS
        # module only exists when Transformation == 'TPS'. Testing != "None"
        # here raised AttributeError for any other "disabled" value.
        if self.stages['Trans'] == 'TPS':
            input = self.Transformation(input)

        # Feature extraction stage.
        visual_feature = self.FeatureExtraction(input)
        # [b, c, h, w] -> [b, w, c, h]. Pool h to 1 and drop it, so the result
        # is a width-major sequence [b, w, c].
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)

        # Sequence modeling stage.
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            # Not contextually modeled; pooled visual features are used directly.
            contextual_feature = visual_feature

        # Prediction stage.
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train,
                                         batch_max_length=self.kwargs['batch_max_length'])
        return prediction