Skip to content

Commit

Permalink
Fix calculate_kl
Browse files Browse the repository at this point in the history
  • Loading branch information
Piyush-555 committed Aug 8, 2020
1 parent 81e6a35 commit 16e2462
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def acc(outputs, targets):
return np.mean(outputs.cpu().numpy().argmax(axis=1) == targets.data.cpu().numpy())


def calculate_kl(mu_p, sig_p, mu_q, sig_q):
def calculate_kl(mu_q, sig_q, mu_p, sig_p):
kl = 0.5 * (2 * torch.log(sig_p / sig_q) - 1 + (sig_q / sig_p).pow(2) + ((mu_p - mu_q) / sig_p).pow(2)).sum()
return kl

Expand Down

0 comments on commit 16e2462

Please sign in to comment.