From 8cabee2d91fff56c00a6d229dd506afd6c4e5749 Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Wed, 6 Nov 2024 15:06:06 +0100 Subject: [PATCH] fix(cat): correct temperature scaling --- simple_einet/layers/distributions/abstract_leaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simple_einet/layers/distributions/abstract_leaf.py b/simple_einet/layers/distributions/abstract_leaf.py index bc1e18f..3844a3f 100644 --- a/simple_einet/layers/distributions/abstract_leaf.py +++ b/simple_einet/layers/distributions/abstract_leaf.py @@ -124,7 +124,7 @@ def dist_sample(distribution: dist.Distribution, ctx: SamplingContext = None) -> elif type(distribution) == CustomNormal: distribution = CustomNormal(mu=distribution.mu, sigma=distribution.sigma * np.sqrt(ctx.temperature_leaves)) elif type(distribution) == dist.Categorical: - distribution = dist.Categorical(logits=F.softmax(distribution.logits / ctx.temperature_leaves)) + distribution = dist.Categorical(logits=distribution.logits / ctx.temperature_leaves) samples = distribution.sample(sample_shape=(ctx.num_samples,)).float()