Skip to content

Commit

Permalink
fix a bug when use torch.jit.trace convert model (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou authored Jun 10, 2020
1 parent 31444ed commit 79cb627
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions resnest/torch/splat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def forward(self, x):

batch, rchannel = x.shape[:2]
if self.radix > 1:
splited = torch.split(x, rchannel//self.radix, dim=1)
if torch.__version__ < '1.5':
splited = torch.split(x, int(rchannel//self.radix), dim=1)
else:
splited = torch.split(x, rchannel//self.radix, dim=1)
gap = sum(splited)
else:
gap = x
Expand All @@ -69,7 +72,10 @@ def forward(self, x):
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

if self.radix > 1:
attens = torch.split(atten, rchannel//self.radix, dim=1)
if torch.__version__ < '1.5':
attens = torch.split(atten, int(rchannel//self.radix), dim=1)
else:
attens = torch.split(atten, rchannel//self.radix, dim=1)
out = sum([att*split for (att, split) in zip(attens, splited)])
else:
out = atten * x
Expand Down

0 comments on commit 79cb627

Please sign in to comment.