Skip to content

Commit

Permalink
Style improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 12, 2024
1 parent da7342e commit 14decaf
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ end

# Tilde pipeline
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
if is_target_varname(context, vn)
return if is_target_varname(context, vn)
# Fall back to the default behavior.
return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi)
DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi)
elseif has_conditioned_gibbs(context, vn)
# Short-circuit the tilde assume if `vn` is present in `context`.
value = get_conditioned_gibbs(context, vn)
# TODO(mhauru) Is the call to logpdf correct if context.context is not
# DefaultContext?
return value, logpdf(right, value), vi
value, logpdf(right, value), vi
else
# If the varname has not been conditioned on, nor is it a target variable, its
# presumably a new variable that should be sampled from its prior. We need to add
Expand All @@ -123,7 +123,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
return value, lp, vi
value, lp, vi

Check warning on line 126 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L125-L126

Added lines #L125 - L126 were not covered by tests
end
end

Expand All @@ -132,14 +132,14 @@ function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi
)
# See comment in the above, rng-less version of this method for an explanation.
if is_target_varname(context, vn)
return DynamicPPL.tilde_assume(
return if is_target_varname(context, vn)
DynamicPPL.tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
)
elseif has_conditioned_gibbs(context, vn)
value = get_conditioned_gibbs(context, vn)
# TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext?
return value, logpdf(right, value), vi
value, logpdf(right, value), vi
else
value, lp, new_global_vi = DynamicPPL.tilde_assume(

Check warning on line 144 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L144

Added line #L144 was not covered by tests
rng,
Expand All @@ -150,7 +150,7 @@ function DynamicPPL.tilde_assume(
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
return value, lp, vi
value, lp, vi

Check warning on line 153 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L152-L153

Added lines #L152 - L153 were not covered by tests
end
end

Expand All @@ -175,14 +175,12 @@ end
# Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf.
# See comments there for more details.
function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi)
if is_target_varname(context, vns)
return DynamicPPL.dot_tilde_assume(
DynamicPPL.childcontext(context), right, left, vns, vi
)
return if is_target_varname(context, vns)
DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi)
elseif has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))

Check warning on line 181 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L177-L181

Added lines #L177 - L181 were not covered by tests
# TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext?
return value, broadcast_logpdf(right, value), vi
value, broadcast_logpdf(right, value), vi

Check warning on line 183 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L183

Added line #L183 was not covered by tests
else
prior_sampler = DynamicPPL.SampleFromPrior()
value, lp, new_global_vi = DynamicPPL.dot_tilde_assume(

Check warning on line 186 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L185-L186

Added lines #L185 - L186 were not covered by tests
Expand All @@ -194,22 +192,22 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
return value, lp, vi
value, lp, vi

Check warning on line 195 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L194-L195

Added lines #L194 - L195 were not covered by tests
end
end

# As above but with an RNG.
function DynamicPPL.dot_tilde_assume(

Check warning on line 200 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L200

Added line #L200 was not covered by tests
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi
)
if is_target_varname(context, vns)
return DynamicPPL.dot_tilde_assume(
return if is_target_varname(context, vns)
DynamicPPL.dot_tilde_assume(

Check warning on line 204 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L203-L204

Added lines #L203 - L204 were not covered by tests
rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi
)
elseif has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))

Check warning on line 208 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L207-L208

Added lines #L207 - L208 were not covered by tests
# TODO(mhauru) As above, is logpdf correct if context.context is not DefaultContext?
return value, broadcast_logpdf(right, value), vi
value, broadcast_logpdf(right, value), vi

Check warning on line 210 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L210

Added line #L210 was not covered by tests
else
prior_sampler = DynamicPPL.SampleFromPrior()
value, lp, new_global_vi = DynamicPPL.dot_tilde_assume(

Check warning on line 213 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
Expand All @@ -222,7 +220,7 @@ function DynamicPPL.dot_tilde_assume(
get_global_varinfo(context),
)
set_global_varinfo!(context, new_global_vi)
return value, lp, vi
value, lp, vi

Check warning on line 223 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L222-L223

Added lines #L222 - L223 were not covered by tests
end
end

Expand Down

0 comments on commit 14decaf

Please sign in to comment.