Skip to content

Commit

Permalink
Remove tests that were added back to DPPL
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 29, 2024
1 parent 3180c64 commit f812c69
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 165 deletions.
63 changes: 0 additions & 63 deletions test/dynamicppl/model.jl

This file was deleted.

104 changes: 2 additions & 102 deletions test/dynamicppl/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,58 +37,6 @@ using Turing
end
end

@testset "link!" begin
@model gdemo(x, y) = begin
s ~ InverseGamma(2, 3)
m ~ Uniform(0, 2)
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
end
model = gdemo(1.0, 2.0)

vi = DynamicPPL.VarInfo()
meta = vi.metadata
model(vi, DynamicPPL.SampleFromUniform())
@test all(x -> !DynamicPPL.istrans(vi, x), meta.vns)

alg = HMC(0.1, 5)
spl = DynamicPPL.Sampler(alg, model)
v = copy(meta.vals)
DynamicPPL.link!!(vi, spl, model)
@test all(x -> DynamicPPL.istrans(vi, x), meta.vns)
DynamicPPL.invlink!!(vi, spl, model)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.vns)
@test meta.vals v atol = 1e-10

vi = DynamicPPL.TypedVarInfo(vi)
meta = vi.metadata
alg = HMC(0.1, 5)
spl = DynamicPPL.Sampler(alg, model)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns)
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
DynamicPPL.link!!(vi, spl, model)
@test all(x -> DynamicPPL.istrans(vi, x), meta.s.vns)
@test all(x -> DynamicPPL.istrans(vi, x), meta.m.vns)
DynamicPPL.invlink!!(vi, spl, model)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns)
@test meta.s.vals v_s atol = 1e-10
@test meta.m.vals v_m atol = 1e-10

# Transforming only a subset of the variables
spl = DynamicPPL.Sampler(HMC(0.1, 5, :m), model)
DynamicPPL.link!!(vi, spl, model)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns)
@test all(x -> DynamicPPL.istrans(vi, x), meta.m.vns)
DynamicPPL.invlink!!(vi, spl, model)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns)
@test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns)
@test meta.s.vals v_s atol = 1e-10
@test meta.m.vals v_m atol = 1e-10
end

@testset "orders" begin
csym = gensym() # unique per model
vn_z1 = @varname z[1]
Expand Down Expand Up @@ -176,6 +124,7 @@ using Turing
@test vi.metadata.b.orders == [2]
@test DynamicPPL.get_num_produce(vi) == 3
end

@testset "replay" begin
# Generate synthesised data
xs = rand(Normal(0.5, 1), 100)
Expand All @@ -196,6 +145,7 @@ using Turing
# Sampling
chain = sample(priorsinarray(xs), HMC(0.01, 10), 10)
end

@testset "varname" begin
@model function mat_name_test()
p = Array{Any}(undef, 2, 2)
Expand Down Expand Up @@ -288,41 +238,6 @@ using Turing

# Test the update of group IDs
g_demo_f = igtest()

# This test section no longer seems as applicable, considering the
# user will never end up using an UntypedVarInfo. The `VarInfo`
# Varible is also not passed around in the same way as it used to be.

# TODO: Has to be fixed

#= g = DynamicPPL.Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f)
vi = VarInfo()
g_demo_f(vi, SampleFromPrior())
_, state = @inferred AbstractMCMC.step(Random.default_rng(), g_demo_f, g)
pg, hmc = state.states
@test pg isa TypedVarInfo
@test hmc isa Turing.Inference.HMCState
vi1 = state.vi
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()]
@inferred g_demo_f(vi1, hmc)
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])]
g = DynamicPPL.Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f)
pg, hmc = g.state.samplers
vi = empty!!(TypedVarInfo(vi))
@inferred g_demo_f(vi, SampleFromPrior())
pg.state.vi = vi
step!(Random.default_rng(), g_demo_f, pg, 1)
vi = pg.state.vi
@inferred g_demo_f(vi, hmc)
@test vi.metadata.x.gids[1] == Set([pg.selector])
@test vi.metadata.y.gids[1] == Set([pg.selector])
@test vi.metadata.z.gids[1] == Set([pg.selector])
@test vi.metadata.w.gids[1] == Set([hmc.selector])
@test vi.metadata.u.gids[1] == Set([hmc.selector]) =#
end

@testset "Turing#2151: eltype(vi, spl)" begin
Expand Down Expand Up @@ -355,21 +270,6 @@ using Turing
model = state_space(y, length(t))
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(; compile=true)), n), 1) == n
end

if Threads.nthreads() > 1
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
@model function f(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
return x ~ Normal(m, 1)
end
model = f(1)
chain = sample(model, NUTS(), MCMCThreads(), 10, 2)
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end
end
end

end

0 comments on commit f812c69

Please sign in to comment.