diff --git a/src/ProbabilisticCircuits.jl b/src/ProbabilisticCircuits.jl
index a48164c4..5d357ace 100644
--- a/src/ProbabilisticCircuits.jl
+++ b/src/ProbabilisticCircuits.jl
@@ -12,6 +12,7 @@ module ProbabilisticCircuits
     include("nodes/indicator_dist.jl")
     include("nodes/categorical_dist.jl")
     include("nodes/binomial_dist.jl")
+    include("nodes/gaussian_dist.jl")
 
     include("nodes/plain_nodes.jl")
 
diff --git a/src/io/jpc_io.jl b/src/io/jpc_io.jl
index bf8400ed..daad1210 100644
--- a/src/io/jpc_io.jl
+++ b/src/io/jpc_io.jl
@@ -12,6 +12,7 @@ const jpc_grammar = raw"""
     node : "L" _WS INT _WS INT _WS SIGNED_INT -> literal_node
          | "I" _WS INT _WS INT _WS INT _WS INT -> indicator_node
          | "B" _WS INT _WS INT _WS INT _WS INT _WS LOGPROB -> binomial_node
+         | "G" _WS INT _WS INT _WS INT _WS LOGPROB _WS LOGPROB -> gaussian_node
          | "C" _WS INT _WS INT _WS INT (_WS LOGPROB)+ -> categorical_node
          | "P" _WS INT _WS INT _WS INT child_nodes -> prod_node
          | "S" _WS INT _WS INT _WS INT weighted_child_nodes -> sum_node
@@ -66,6 +67,13 @@ end
     t.nodes[x[1]] = PlainInputNode(var, Indicator(value))
 end
 
+@rule gaussian_node(t::PlainJpcParse, x) = begin
+    var = Base.parse(Int,x[3])
+    mu = Base.parse(Float64, x[4])
+    sigma = Base.parse(Float64, x[5])
+    t.nodes[x[1]] = PlainInputNode(var, Gaussian(mu, sigma))
+end
+
 @rule binomial_node(t::PlainJpcParse, x) = begin
     var = Base.parse(Int,x[3])
     N = Base.parse(UInt32, x[4])
@@ -134,6 +142,11 @@ function read_fast(input, ::Type{<:ProbCircuit} = PlainProbCircuit, ::JpcFormat
             N = Base.parse(UInt32, tokens[5])
             logp = Base.parse(Float64, tokens[6])
             nodes[id] = PlainInputNode(var, Binomial(N, exp(logp)))
+        elseif startswith(line, "G")
+            var = Base.parse(Int,tokens[4])
+            mu = Base.parse(Float64, tokens[5])
+            sigma = Base.parse(Float64, tokens[6])
+            nodes[id] = PlainInputNode(var, Gaussian(mu, sigma))
         elseif startswith(line, "P")
             child_ids = Base.parse.(Int, tokens[5:end]) .+ 1
             children = nodes[child_ids]
@@ -166,6 +179,7 @@ c L id-of-jpc-node id-of-vtree literal
 c I id-of-jpc-node id-of-vtree variable indicator-value
 c C id-of-jpc-node id-of-vtree variable {log-probability}+
 c B id-of-jpc-node id-of-vtree variable binomial-N binomial-P
+c G id-of-jpc-node id-of-vtree variable gaussian-mu gaussian-sigma
 c P id-of-sum-jpc-node id-of-vtree number-of-children {child-id}+
 c S id-of-product-jpc-node id-of-vtree number-of-children {child-id log-probability}+
 c"""
@@ -193,6 +207,9 @@ function Base.write(io::IO, circuit::ProbCircuit, ::JpcFormat, vtreeid::Function
         elseif d isa Binomial
             print(io, "B $(labeling[n]) $(vtreeid(n)) $var $(d.N) $(log(d.p))")
             println(io)
+        elseif d isa Gaussian
+            print(io, "G $(labeling[n]) $(vtreeid(n)) $var $(d.mu) $(d.sigma)")
+            println(io)
         else
             error("Input distribution type $(typeof(d)) is unknown to the JPC file format")
         end
diff --git a/src/io/plot.jl b/src/io/plot.jl
index 16b2917b..7075dbea 100644
--- a/src/io/plot.jl
+++ b/src/io/plot.jl
@@ -48,4 +48,9 @@ end
 function latex(d::Binomial)
     p = round(d.p, digits=3)
     "Binomial($(d.N), $(p))"
+end
+
+function latex(d::Gaussian)
+    mu = round(d.mu, digits=3)
+    "Gaussian($(mu), $(d.sigma))"
 end
\ No newline at end of file
diff --git a/src/nodes/gaussian_dist.jl b/src/nodes/gaussian_dist.jl
new file mode 100644
index 00000000..f1687d6b
--- /dev/null
+++ b/src/nodes/gaussian_dist.jl
@@ -0,0 +1,115 @@
+using CUDA
+import Random: default_rng
+
+export Gaussian
+
+struct Gaussian <: InputDist
+    mu::Float32
+    sigma::Float32
+end
+
+struct BitsGaussian <: InputDist
+    # mu::Float32
+    # sigma::Float32
+    sigma::Float32
+    heap_start::UInt32
+end
+
+# heap_start index offsets
+const GAUSS_HEAP_MU = UInt32(1)
+const GAUSS_HEAP_FLOWVALUE = UInt32(2) # flow*value
+const GAUSS_HEAP_FLOW = UInt32(3) # flow
+const GAUSS_HEAP_MISSINGFLOW = UInt32(4) # missing_flow
+
+Gaussian(mu::Float64, sigma::Float64) =
+    Gaussian(Float32(mu), Float32(sigma))
+
+num_parameters(dist::Gaussian, independent) = 1
+
+params(dist::Gaussian) = dist.mu
+
+init_params(dist::Gaussian, perturbation) =
+    Gaussian(0.0, dist.sigma)
+
+function bits(dist::Gaussian, heap)
+    heap_start = length(heap) + 1
+
+    # Append mu, sigma, flow*value, flow, missing_flow
+    append!(heap, dist.mu, zeros(Float32, 3))
+    BitsGaussian(dist.sigma, heap_start)
+end
+
+mu(dist::Gaussian, _ = nothing) = dist.mu
+mu(dist::BitsGaussian, heap) = heap[dist.heap_start]
+
+sigma(dist::Gaussian, _ = nothing) = dist.sigma
+
+function unbits(dist::Gaussian, heap)
+    Gaussian(mu(dist, heap), dist.sigma)
+end
+
+function loglikelihood(dist::Gaussian, value, _ = nothing)
+    # normlogpdf((value - mu(dist))/sigma(dist))
+    log_gauss(value, mu(dist), dist.sigma)
+end
+
+function loglikelihood(dist::BitsGaussian, value, heap)
+    # normlogpdf((value - mu(dist, heap))/sigma(dist, heap))
+    log_gauss(value, mu(dist, heap), dist.sigma)
+end
+
+log_gauss(x, mu, sigma) = -log(sigma) - 0.5*log(2π) - 0.5*((x - mu)/sigma)^2
+
+function flow(dist::BitsGaussian, value, node_flow, heap)
+    heap_start = dist.heap_start
+
+    if ismissing(value)
+        CUDA.@atomic heap[heap_start + GAUSS_HEAP_MISSINGFLOW] += node_flow
+    else
+        CUDA.@atomic heap[heap_start + GAUSS_HEAP_FLOWVALUE] += node_flow * value
+        CUDA.@atomic heap[heap_start + GAUSS_HEAP_FLOW] += node_flow
+    end
+    nothing
+end
+
+function update_params(dist::BitsGaussian, heap, pseudocount, inertia)
+    heap_start = dist.heap_start
+
+    missing_flow = heap[heap_start + GAUSS_HEAP_MISSINGFLOW]
+    node_flow = heap[heap_start + GAUSS_HEAP_FLOW] + missing_flow + pseudocount
+
+    old_mu = heap[heap_start]
+
+    new = (heap[heap_start + GAUSS_HEAP_FLOWVALUE] + (missing_flow + pseudocount) * old_mu) / (node_flow)
+    new_mu = old_mu * inertia + new * (one(Float32) - inertia)
+
+    # update mu on heap
+    heap[heap_start] = new_mu
+    nothing
+end
+
+function clear_memory(dist::BitsGaussian, heap, rate)
+    heap_start = dist.heap_start
+    for i = 1 : 3
+        heap[heap_start + i] *= rate
+    end
+    nothing
+end
+
+function sample_state(dist::Union{BitsGaussian, Gaussian}, threshold, heap)
+    # Sample from standard normal
+    z = randn()
+
+    # Reparameterize
+    return dist.sigma * z + dist.mu
+end
+
+### MAP
+init_heap_map_state!(dist::Gaussian, heap) = nothing
+
+init_heap_map_loglikelihood!(dist::Gaussian, heap) = nothing
+
+map_state(dist::Gaussian, heap) = dist.mu
+
+map_loglikelihood(dist::Gaussian, heap) = loglikelihood(dist, dist.mu, heap)
+
diff --git a/test/helper/plain_dummy_circuits.jl b/test/helper/plain_dummy_circuits.jl
index 62de8bc6..e202486c 100644
--- a/test/helper/plain_dummy_circuits.jl
+++ b/test/helper/plain_dummy_circuits.jl
@@ -44,6 +44,25 @@ function little_3var_binomial(firstvar=1; n = 10)
 
     summate(multiply(n1, n2, n3))
 end
+function little_gmm(firstvar=1; sigma = 1)
+    n1 = PlainInputNode(firstvar, Gaussian(-1.0, sigma))
+    n2 = PlainInputNode(firstvar, Gaussian(1.0, sigma))
+
+    0.5 * n1 + 0.5 * n2
+end
+
+function little_2var_gmm(firstvar=1; sigma = 1)
+    n1_x = PlainInputNode(firstvar, Gaussian(-2.0, sigma))
+    n1_y = PlainInputNode(firstvar+1, Gaussian(-2.0, sigma))
+
+    n2_x = PlainInputNode(firstvar, Gaussian(0.0, sigma))
+    n2_y = PlainInputNode(firstvar+1, Gaussian(0.0, sigma))
+
+    n1 = multiply(n1_x, n1_y)
+    n2 = multiply(n2_x, n2_y)
+
+    0.2 * n1 + 0.8 * n2
+end
 
 function little_4var()
     circuit = IOBuffer(b"""psdd 19
diff --git a/test/input_distributions_tests.jl b/test/input_distributions_tests.jl
index 31982ea7..c87e9b54 100644
--- a/test/input_distributions_tests.jl
+++ b/test/input_distributions_tests.jl
@@ -13,6 +13,11 @@ using ProbabilisticCircuits: bits, PlainInputNode
     n = PlainInputNode(1, Categorical(4))
     @test issetequal(randvars(n), [1])
     @test all(n.dist.logps .≈ [log(0.25), log(0.25), log(0.25), log(0.25)])
+
+    n = PlainInputNode(1, Gaussian(0.0, 1.0))
+    @test issetequal(randvars(n), [1])
+    @test n.dist.mu == 0.0
+    @test n.dist.sigma == 1.0
 end
 
diff --git a/test/io/jpc_tests.jl b/test/io/jpc_tests.jl
index 5bc4feec..fb5a5511 100644
--- a/test/io/jpc_tests.jl
+++ b/test/io/jpc_tests.jl
@@ -87,6 +87,58 @@ end
     end
 end
 
+@testset "JPC IO tests Gaussian" begin
+    pc = little_gmm()
+
+    mktempdir() do tmp
+        file = "$tmp/example_gaussian.jpc"
+        write(file, pc)
+
+        pc2 = read(file, ProbCircuit)
+        test_pc_equals(pc, pc2)
+
+        pc2 = read(file, ProbCircuit, JpcFormat(), true)
+        test_pc_equals(pc, pc2)
+
+        pc2 = read(file, ProbCircuit, JpcFormat(), false)
+        test_pc_equals(pc, pc2)
+
+        # Compressed
+        file = "$tmp/example_gaussian.jpc.gz"
+        write(file, pc)
+
+        pc2 = read(file, ProbCircuit)
+        test_pc_equals(pc, pc2)
+
+    end
+end
+
+@testset "JPC IO tests 2D Gaussian" begin
+    pc = little_2var_gmm()
+
+    mktempdir() do tmp
+        file = "$tmp/example_gaussian2.jpc"
+        write(file, pc)
+
+        pc2 = read(file, ProbCircuit)
+        test_pc_equals(pc, pc2)
+
+        pc2 = read(file, ProbCircuit, JpcFormat(), true)
+        test_pc_equals(pc, pc2)
+
+        pc2 = read(file, ProbCircuit, JpcFormat(), false)
+        test_pc_equals(pc, pc2)
+
+        # Compressed
+        file = "$tmp/example_gaussian2.jpc.gz"
+        write(file, pc)
+
+        pc2 = read(file, ProbCircuit)
+        test_pc_equals(pc, pc2)
+
+    end
+end
+
 @testset "Jpc IO tests hybrid" begin
diff --git a/test/queries/likelihood_tests.jl b/test/queries/likelihood_tests.jl
index dd4a1f15..27f608e5 100644
--- a/test/queries/likelihood_tests.jl
+++ b/test/queries/likelihood_tests.jl
@@ -1,5 +1,6 @@
 using Test, DirectedAcyclicGraphs, ProbabilisticCircuits, CUDA
 using ProbabilisticCircuits: CuBitsProbCircuit
+using StatsFuns
 
 include("../helper/plain_dummy_circuits.jl")
 include("../helper/data.jl")
@@ -74,4 +75,122 @@ include("../helper/data.jl")
 
     end
 
+end
+
+
+@testset "gaussian 1-var-gmm likelihood" begin
+    EPS = 1e-6
+
+    pc = little_gmm();
+    @test pc isa ProbCircuit
+
+    data = Vector{Float64}([-1.0; 0.0; 1.0])
+
+    gmm_mu = [-1.0, 1.0]
+    gmm_sigma = [1.0, 1.0]
+    gmm_w = [0.5, 0.5]
+
+    n = size(data, 1) # Num data
+    m = size(gmm_w, 1) # Num mixture components
+
+    # Repeat data for each comp. and standardize
+    z = ((repeat(reshape(data, 1, n), m, 1)) .- repeat(gmm_mu, 1, n)) ./ repeat(gmm_sigma, 1, n)
+
+    # Compute true probs with StatsFuns
+    p_m = normpdf.(z)
+
+    # Weighted sum of probs from each dist for each datapoint
+    true_probs = transpose(gmm_w) * p_m
+    true_probs = reshape(true_probs, n, 1)
+
+    # Bigger Batch size
+    probs = exp.(loglikelihoods(pc, reshape(data, n, 1); batch_size = 32))
+    probs = reshape(probs, n, 1)
+
+    @test true_probs ≈ probs atol=EPS
+
+    # Smaller Batch size
+    lls = exp.(loglikelihoods(pc, reshape(data, n, 1); batch_size = 2))
+    @test true_probs ≈ reshape(lls, n, 1) atol=EPS
+
+    @test num_randvars(pc) == 1
+
+    # GPU Tests
+    # TODO: test on GPU
+    if CUDA.functional()
+        pc = little_gmm()
+        bpc = CuBitsProbCircuit(pc)
+
+        data = cu(reshape(data, n, 1))
+
+        probs = exp.(Array(loglikelihoods(bpc, data; batch_size = 32)))
+
+        @test true_probs ≈ reshape(probs, n, 1) atol=EPS
+    end
+
+end
+
+
+@testset "gaussian 2-var-gmm likelihood" begin
+    EPS = 1e-6
+
+    pc = little_2var_gmm();
+    @test pc isa ProbCircuit
+
+    data = Matrix{Float64}([-2.0 -2.0; 0.0 0.0; 1.0 1.0; -2.0 0.0])
+
+    gmm_mu = [-2.0 -2.0; 0.0 0.0]
+
+    gmm_sigma = 1.0
+    gmm_w = [0.2; 0.8]
+
+    n = size(data, 1) # Num data
+    m = size(gmm_w, 1) # Num mixture components
+    d = 2
+
+    # Compute GMM probs for each data-point
+    # z = ((repeat(reshape(data, n, d), 1, 1, m)) .- repeat(reshape(gmm_mu, 1, d, m), n, 1, 1)) ./ repeat([gmm_sigma], n, d, m)
+    true_probs = zeros(n)
+    for i in 1:n
+        for k in 1:m
+            x = data[i, :]
+
+            m_mu = gmm_mu[k, :]
+
+            # Standardization
+            z = (x .- m_mu) ./ gmm_sigma
+
+            # Iterative weighted sum of each comp. probs
+            m_w = gmm_w[k]
+            true_probs[i] += m_w * prod(normpdf.(z))
+        end
+    end
+
+    true_probs = reshape(true_probs, n, 1)
+
+    # Bigger Batch size
+    probs = exp.(loglikelihoods(pc, reshape(data, n, d); batch_size = 32))
+    probs = reshape(probs, n, 1)
+
+    @test true_probs ≈ probs atol=EPS
+
+    # Smaller Batch size
+    lls = exp.(loglikelihoods(pc, reshape(data, n, d); batch_size = 2))
+    @test true_probs ≈ reshape(lls, n, 1) atol=EPS
+
+    @test num_randvars(pc) == 2
+
+    # GPU Tests
+    # TODO: test on GPU
+    if CUDA.functional()
+        pc = little_2var_gmm()
+        bpc = CuBitsProbCircuit(pc)
+
+        data = cu(reshape(data, n, d))
+
+        probs = exp.(Array(loglikelihoods(bpc, data; batch_size = 32)))
+
+        @test true_probs ≈ reshape(probs, n, 1) atol=EPS
+    end
+
 end
\ No newline at end of file
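
A minimal usage sketch of the new Gaussian leaves, assembled from the helpers and tests in this patch; it is not part of the patch itself. The file name is illustrative only, and PlainInputNode is imported explicitly here, mirroring the test files:

    using ProbabilisticCircuits
    using ProbabilisticCircuits: PlainInputNode

    # Two-component 1D GMM over variable 1: 0.5*N(-1,1) + 0.5*N(1,1),
    # built the same way as little_gmm() in the test helpers
    n1 = PlainInputNode(1, Gaussian(-1.0, 1.0))
    n2 = PlainInputNode(1, Gaussian(1.0, 1.0))
    pc = 0.5 * n1 + 0.5 * n2

    # Per-example log-likelihoods of a column of real-valued observations
    data = reshape([-1.0, 0.0, 1.0], 3, 1)
    lls = loglikelihoods(pc, data; batch_size = 32)

    # Round-trip through the JPC format; each Gaussian leaf is serialized as a
    # "G id-of-jpc-node id-of-vtree variable gaussian-mu gaussian-sigma" line
    write("example_gaussian.jpc", pc)
    pc2 = read("example_gaussian.jpc", ProbCircuit)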