diff --git a/models/transformer.py b/models/transformer.py index 2853a43..6a30300 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -111,12 +111,9 @@ def forward(self, tgt, memory, pos=pos, query_pos=query_pos) if self.return_intermediate: intermediate.append(self.norm(output)) - + if self.norm is not None: output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) if self.return_intermediate: return torch.stack(intermediate)