From 79cb62707fe8b86d8044635023bdb7f45dc0457d Mon Sep 17 00:00:00 2001 From: zhoujun Date: Wed, 10 Jun 2020 10:22:12 +0800 Subject: [PATCH] fix a bug when use torch.jit.trace convert model (#79) --- resnest/torch/splat.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/resnest/torch/splat.py b/resnest/torch/splat.py index 18ac60d..c3f21b1 100644 --- a/resnest/torch/splat.py +++ b/resnest/torch/splat.py @@ -54,7 +54,10 @@ def forward(self, x): batch, rchannel = x.shape[:2] if self.radix > 1: - splited = torch.split(x, rchannel//self.radix, dim=1) + if torch.__version__ < '1.5': + splited = torch.split(x, int(rchannel//self.radix), dim=1) + else: + splited = torch.split(x, rchannel//self.radix, dim=1) gap = sum(splited) else: gap = x @@ -69,7 +72,10 @@ def forward(self, x): atten = self.rsoftmax(atten).view(batch, -1, 1, 1) if self.radix > 1: - attens = torch.split(atten, rchannel//self.radix, dim=1) + if torch.__version__ < '1.5': + attens = torch.split(atten, int(rchannel//self.radix), dim=1) + else: + attens = torch.split(atten, rchannel//self.radix, dim=1) out = sum([att*split for (att, split) in zip(attens, splited)]) else: out = atten * x