Skip to content

Commit

Permalink
Merge pull request #14 from ducha-aiki/patch-2
Browse files Browse the repository at this point in the history
Fix to map5
  • Loading branch information
radekosmulski authored Feb 15, 2019
2 parents 2323482 + bfbe8e2 commit 48c336a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def map5kfast(preds, targs, k=10):
scores[:,kk] = (top_5[:,kk] == targs).float() / float((kk+1))
return scores.max(dim=1)[0].mean()

def map5(preds, targs):
return map5kfast(preds, targs, 5)
def map5(preds,targs):
if type(preds) is list:
return torch.cat([map5fast(p, targs, 5).view(1) for p in preds ]).mean()
return map5fast(preds,targs, 5)

def top_5_preds(preds): return np.argsort(preds.numpy())[:, ::-1][:, :5]

Expand Down

0 comments on commit 48c336a

Please sign in to comment.