From 36fa995aeab208e5849e603d2e5cba699f6a8c4c Mon Sep 17 00:00:00 2001 From: Stefan Webb Date: Sat, 6 Apr 2019 19:52:39 -0700 Subject: [PATCH] '`arm.electric_1a`' --- pyro_models/arm/electric_1a.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyro_models/arm/electric_1a.py b/pyro_models/arm/electric_1a.py index e856267..84c2983 100644 --- a/pyro_models/arm/electric_1a.py +++ b/pyro_models/arm/electric_1a.py @@ -49,14 +49,14 @@ def model(data, params): with pyro.plate('n_grade_pair', n_grade_pair): mu_a = pyro.sample("mu_a", dist.Normal(0., 1.)) sigma_a = pyro.sample("sigma_a", dist.Uniform(0., 100.)) - sigma_a_hat = sigma_a[grade_pair] - mu_a_hat = 100 * mu_a[grade_pair] + sigma_a_hat = sigma_a[...,grade_pair] + mu_a_hat = 100 * mu_a[...,grade_pair] with pyro.plate('n_pair', n_pair): a = pyro.sample("a", dist.Normal(mu_a_hat, sigma_a_hat)) with pyro.plate('n_grade', n_grade): b = pyro.sample("b", dist.Normal(0., 100.)) sigma_y = pyro.sample("sigma_y", dist.Uniform(0., 100.)) - sigma_y_hat = sigma_y[grade] + sigma_y_hat = sigma_y[...,grade] with pyro.plate("data", N): - y_hat = a[pair] + b[grade] * treatment + y_hat = a[...,pair] + b[...,grade] * treatment y = pyro.sample("y", dist.Normal(y_hat, sigma_y_hat), obs=y)