Skip to content

Commit

Permalink
'arm.electric_1a'
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanwebb committed Apr 7, 2019
1 parent 0f49c7e commit 36fa995
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyro_models/arm/electric_1a.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def model(data, params):
with pyro.plate('n_grade_pair', n_grade_pair):
mu_a = pyro.sample("mu_a", dist.Normal(0., 1.))
sigma_a = pyro.sample("sigma_a", dist.Uniform(0., 100.))
sigma_a_hat = sigma_a[grade_pair]
mu_a_hat = 100 * mu_a[grade_pair]
sigma_a_hat = sigma_a[...,grade_pair]
mu_a_hat = 100 * mu_a[...,grade_pair]
with pyro.plate('n_pair', n_pair):
a = pyro.sample("a", dist.Normal(mu_a_hat, sigma_a_hat))
with pyro.plate('n_grade', n_grade):
b = pyro.sample("b", dist.Normal(0., 100.))
sigma_y = pyro.sample("sigma_y", dist.Uniform(0., 100.))
sigma_y_hat = sigma_y[grade]
sigma_y_hat = sigma_y[...,grade]
with pyro.plate("data", N):
y_hat = a[pair] + b[grade] * treatment
y_hat = a[...,pair] + b[...,grade] * treatment
y = pyro.sample("y", dist.Normal(y_hat, sigma_y_hat), obs=y)

0 comments on commit 36fa995

Please sign in to comment.