Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1D gaussian input node implementation (WIP) #129

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ProbabilisticCircuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module ProbabilisticCircuits
include("nodes/indicator_dist.jl")
include("nodes/categorical_dist.jl")
include("nodes/binomial_dist.jl")
include("nodes/gaussian_dist.jl")

include("nodes/plain_nodes.jl")

Expand Down
17 changes: 17 additions & 0 deletions src/io/jpc_io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const jpc_grammar = raw"""
node : "L" _WS INT _WS INT _WS SIGNED_INT -> literal_node
| "I" _WS INT _WS INT _WS INT _WS INT -> indicator_node
| "B" _WS INT _WS INT _WS INT _WS INT _WS LOGPROB -> binomial_node
| "G" _WS INT _WS INT _WS INT _WS LOGPROB _WS LOGPROB -> gaussian_node
| "C" _WS INT _WS INT _WS INT (_WS LOGPROB)+ -> categorical_node
| "P" _WS INT _WS INT _WS INT child_nodes -> prod_node
| "S" _WS INT _WS INT _WS INT weighted_child_nodes -> sum_node
Expand Down Expand Up @@ -66,6 +67,13 @@ end
t.nodes[x[1]] = PlainInputNode(var, Indicator(value))
end

# Parse rule for a "G" line: builds a Gaussian input node from the variable (x[3]),
# mean (x[4]) and sigma (x[5]) tokens and registers it under the node id x[1].
@rule gaussian_node(t::PlainJpcParse, x) = begin
    variable = Base.parse(Int, x[3])
    dist = Gaussian(Base.parse(Float64, x[4]), Base.parse(Float64, x[5]))
    t.nodes[x[1]] = PlainInputNode(variable, dist)
end

@rule binomial_node(t::PlainJpcParse, x) = begin
var = Base.parse(Int,x[3])
N = Base.parse(UInt32, x[4])
Expand Down Expand Up @@ -134,6 +142,11 @@ function read_fast(input, ::Type{<:ProbCircuit} = PlainProbCircuit, ::JpcFormat
N = Base.parse(UInt32, tokens[5])
logp = Base.parse(Float64, tokens[6])
nodes[id] = PlainInputNode(var, Binomial(N, exp(logp)))
elseif startswith(line, "G")
var = Base.parse(Int,tokens[4])
mu = Base.parse(Float64, tokens[5])
sigma = Base.parse(Float64, tokens[6])
nodes[id] = PlainInputNode(var, Gaussian(mu, sigma))
elseif startswith(line, "P")
child_ids = Base.parse.(Int, tokens[5:end]) .+ 1
children = nodes[child_ids]
Expand Down Expand Up @@ -166,6 +179,7 @@ c L id-of-jpc-node id-of-vtree literal
c I id-of-jpc-node id-of-vtree variable indicator-value
c C id-of-jpc-node id-of-vtree variable {log-probability}+
c B id-of-jpc-node id-of-vtree variable binomial-N binomial-P
c G id-of-jpc-node id-of-vtree variable gaussian-mu gaussian-sigma
c P id-of-sum-jpc-node id-of-vtree number-of-children {child-id}+
c S id-of-product-jpc-node id-of-vtree number-of-children {child-id log-probability}+
c"""
Expand Down Expand Up @@ -193,6 +207,9 @@ function Base.write(io::IO, circuit::ProbCircuit, ::JpcFormat, vtreeid::Function
elseif d isa Binomial
print(io, "B $(labeling[n]) $(vtreeid(n)) $var $(d.N) $(log(d.p))")
println(io)
elseif d isa Gaussian
print(io, "G $(labeling[n]) $(vtreeid(n)) $var $(d.mu) $(d.sigma)")
println(io)
else
error("Input distribution type $(typeof(d)) is unknown to the JPC file format")
end
Expand Down
5 changes: 5 additions & 0 deletions src/io/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,9 @@ end
# Plot label for a Binomial input distribution; `p` is shown with three decimal digits.
function latex(d::Binomial)
    rounded_p = round(d.p, digits=3)
    return "Binomial($(d.N), $(rounded_p))"
end

# Plot label for a Gaussian input distribution.
# Both the mean and sigma are rounded to three decimal digits so the label stays
# compact, in the same way `latex(::Binomial)` rounds `p`.
function latex(d::Gaussian)
    mu = round(d.mu, digits=3)
    sigma = round(d.sigma, digits=3)
    return "Gaussian($(mu), $(sigma))"
end
119 changes: 119 additions & 0 deletions src/nodes/gaussian_dist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using CUDA
import Random: default_rng

export Gaussian

# Univariate Gaussian input distribution.
# `mu` is the mean; `sigma` is the scale that `log_gauss` divides by and takes the
# log of (i.e. it is used there as a standard deviation, not a variance).
struct Gaussian <: InputDist
mu::Float32
sigma::Float32
end

# Counterpart of `Gaussian` produced by `bits`.
# The mean is not stored inline: it lives in the heap at index `heap_start`
# (see `mu(::BitsGaussian, heap)`), followed by the accumulator slots that `bits`
# appends. `sigma` is kept inline.
struct BitsGaussian <: InputDist
sigma::Float32
heap_start::UInt32
end

# Offsets that are ADDED to `heap_start` to address a Gaussian node's heap slots.
# `bits` appends exactly four values starting at `heap_start`
# (mu, flow*value, flow, missing_flow), and the mean is read at `heap[heap_start]`,
# so the offsets must be zero-based. With one-based offsets, `heap_start +
# GAUSS_HEAP_MISSINGFLOW` would point one slot past this node's region.
const GAUSS_HEAP_MU = UInt32(0)
const GAUSS_HEAP_FLOWVALUE = UInt32(1) # flow*value
const GAUSS_HEAP_FLOW = UInt32(2) # flow
const GAUSS_HEAP_MISSINGFLOW = UInt32(3) # missing_flow

# Convenience constructor: narrows Float64 arguments to the Float32 fields.
function Gaussian(mu::Float64, sigma::Float64)
    return Gaussian(Float32(mu), Float32(sigma))
end

# Number of parameters reported for a Gaussian input: always 1, regardless of
# `independent`.
function num_parameters(dist::Gaussian, independent)
    return 1
end

# Parameter value of a Gaussian input: its mean.
function params(dist::Gaussian)
    return dist.mu
end

# Re-initialise a Gaussian input: the mean is reset to zero and sigma is kept.
# `perturbation` is accepted but not used here.
function init_params(dist::Gaussian, perturbation)
    return Gaussian(0.0, dist.sigma)
end

# Move a Gaussian's mutable state onto `heap` and return its `BitsGaussian` handle.
# Four slots are appended, starting at the returned `heap_start`:
# mu, then three zero-initialised accumulators (flow*value, flow, missing_flow).
# sigma is NOT placed on the heap; it is carried inline in the `BitsGaussian`.
function bits(dist::Gaussian, heap)
    first_slot = length(heap) + 1
    push!(heap, dist.mu)
    append!(heap, zeros(Float32, 3))
    return BitsGaussian(dist.sigma, first_slot)
end

# Mean of a plain Gaussian; the optional second argument is ignored.
function mu(dist::Gaussian, _ = nothing)
    return dist.mu
end

# Mean of a bits Gaussian: read from the heap at the node's first slot.
function mu(dist::BitsGaussian, heap)
    return heap[dist.heap_start]
end

# Sigma of a plain Gaussian; the optional second argument is ignored.
function sigma(dist::Gaussian, _ = nothing)
    return dist.sigma
end

# Rebuild a plain `Gaussian` from either representation.
# `bits` stores the mean on the heap and returns a `BitsGaussian`, so the inverse
# must accept that type: for a `BitsGaussian` the mean is read back from `heap`
# via `mu(dist, heap)`; for a plain `Gaussian` the heap is ignored and an
# equivalent copy is returned (the previous behaviour).
function unbits(dist::Union{Gaussian, BitsGaussian}, heap)
    return Gaussian(mu(dist, heap), dist.sigma)
end

# Log-density of `value` under a plain Gaussian; the optional heap argument is unused.
function loglikelihood(dist::Gaussian, value, _ = nothing)
    mean = mu(dist)
    return log_gauss(value, mean, dist.sigma)
end

# Log-density of `value` under a bits Gaussian; the mean is fetched from `heap`.
function loglikelihood(dist::BitsGaussian, value, heap)
    mean = mu(dist, heap)
    return log_gauss(value, mean, dist.sigma)
end

# Log-density of a univariate normal with mean `mu` and standard deviation `sigma`
# evaluated at `x`: -log(sigma) - 0.5*log(2π) - 0.5*z^2 with z = (x - mu)/sigma.
function log_gauss(x, mu, sigma)
    z = (x - mu) / sigma
    return -log(sigma) - 0.5*log(2π) - 0.5*z^2
end
khosravipasha marked this conversation as resolved.
Show resolved Hide resolved

# Accumulate one example's flow into this node's heap slots.
# Observed values add `node_flow * value` and `node_flow` to their respective
# accumulators; a missing value only adds `node_flow` to the missing-flow slot.
# Updates go through `CUDA.@atomic`. Always returns `nothing`.
function flow(dist::BitsGaussian, value, node_flow, heap)
    base = dist.heap_start

    if !ismissing(value)
        CUDA.@atomic heap[base + GAUSS_HEAP_FLOWVALUE] += node_flow * value
        CUDA.@atomic heap[base + GAUSS_HEAP_FLOW] += node_flow
    else
        CUDA.@atomic heap[base + GAUSS_HEAP_MISSINGFLOW] += node_flow
    end
    return nothing
end

function update_params(dist::BitsGaussian, heap, pseudocount, inertia)
error("Not implemented error: `update_params`, $(typeof(dist))")
#heap_start = dist.heap_start

#missing_flow = heap[heap_start + GAUSS_HEAP_MISSINGFLOW]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I think we should be very similar to binomial.

  • For computing "node_flow" we just add missing_flow to it directly.
  • For computing "node_flow * value", if we have missing flow then we add oldp * missing_flow to it (old_mu * missing_flow in this case); the reason is that if we don't observe that Gaussian, the expected value is oldp (old_mu here).

pseudocount was basically used to avoid getting 0 probabilities, so it's pretending we observed x=i for every i. For the Gaussian case, I guess we would treat it the same as the missing flows (or we can ignore it for now, since I'm not sure it fits a continuous/Gaussian variable), so we can first try with pseudocount=0.

Overall, I think should look like something like this,

    missing_flow = heap[heap_start + GAUSS_HEAP_MISSINGFLOW]
    node_flow = heap[heap_start + GAUSS_HEAP_FLOW] + missing_flow + pseudocount

    old_mu = heap[heap_start]
    new = (heap[heap_start + GAUSS_HEAP_FLOWVALUE] + (missing_flow + pseudocount) * old_mu ) / (node_flow)

#node_flow = heap[heap_start + GAUSS_HEAP_FLOW] + missing_flow + pseudocount

# old_mu = heap[heap_start]

# TODO: How to convert this to Gaussian EM-update?
# new = (heap[heap_start + 2] + missing_flow * oldp * dist.N + pseudocount) / (node_flow * dist.N)
# new_p = oldp * inertia + new * (one(Float32) - inertia)

#new_mu = nothing

# update mu and sigma on heap
#heap[heap_start] = new_mu
#nothing
end

# Scale the three heap slots that follow the mean (heap_start+1 .. heap_start+3)
# by `rate`. The mean slot itself is left untouched. Returns `nothing`.
function clear_memory(dist::BitsGaussian, heap, rate)
    base = dist.heap_start
    for offset in 1:3
        heap[base + offset] *= rate
    end
    return nothing
end

function sample_state(dist::Union{BitsGaussian, Gaussian}, threshold, heap)
# Sample from standard normal
z = randn()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Threshold already gives you the random number you need, so we probably should not sample again. We sample all the randomness we need beforehand and pass it along to the input nodes, so here threshold is basically sampled from a uniform distribution on [0,1]. The threshold is basically the log of a uniform random variable on [0,1].

I guess for the Gaussian we might need to use the inverse CDF to make it work using threshold. If it's not easy to do, maybe we can adjust how randomness is passed to the input nodes.

Here is where we pass the random numbers to input nodes sample_state
https://github.com/Juice-jl/ProbabilisticCircuits.jl/blob/27cb093439c8db5b6e59f75567800ff92d4fffa6/src/queries/sample.jl#L174

And we sample all randomness needed before calling the kernel
https://github.com/Juice-jl/ProbabilisticCircuits.jl/blob/27cb093439c8db5b6e59f75567800ff92d4fffa6/src/queries/sample.jl#L133


# Reparameterize
return dist.sigma * z + dist.mu
end

### MAP
# No heap state is needed for the MAP state of a Gaussian; does nothing.
function init_heap_map_state!(dist::Gaussian, heap)
    return nothing
end

# No heap state is needed for the MAP log-likelihood of a Gaussian; does nothing.
function init_heap_map_loglikelihood!(dist::Gaussian, heap)
    return nothing
end

# MAP state of a Gaussian: its mean (the mode of the density). `heap` is unused.
function map_state(dist::Gaussian, heap)
    return dist.mu
end

# Log-likelihood at the MAP state, i.e. the log-density evaluated at the mean.
function map_loglikelihood(dist::Gaussian, heap)
    mode = dist.mu
    return loglikelihood(dist, mode, heap)
end

19 changes: 19 additions & 0 deletions test/helper/plain_dummy_circuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ function little_3var_binomial(firstvar=1; n = 10)
summate(multiply(n1, n2, n3))
end

# Three-component 1D Gaussian mixture over variable `firstvar`, with means
# -2, 0 and 2, a shared `sigma`, and mixture weights 0.1, 0.2 and 0.7.
function little_gmm(firstvar=1; sigma = 1)
    means = (-2.0, 0.0, 2.0)
    n1, n2, n3 = map(m -> PlainInputNode(firstvar, Gaussian(m, sigma)), means)
    return 0.1 * n1 + 0.2 * n2 + 0.7 * n3
end

# Two-component mixture over variables `firstvar` and `firstvar + 1`.
# Each component is a product of two Gaussians sharing the same mean
# (-2 for the first component, 0 for the second) and the same `sigma`.
function little_2var_gmm(firstvar=1; sigma = 1)
    leaf(var, mean) = PlainInputNode(var, Gaussian(mean, sigma))

    # Build all leaves first (x then y, component by component), then the products.
    leaf_pairs = [(leaf(firstvar, m), leaf(firstvar + 1, m)) for m in (-2.0, 0.0)]
    n1, n2 = [multiply(x, y) for (x, y) in leaf_pairs]

    return summate(n1, n2)
end

function little_4var()
circuit = IOBuffer(b"""psdd 19
Expand Down
52 changes: 52 additions & 0 deletions test/io/jpc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,58 @@ end
end
end

@testset "JPC IO tests Gaussian" begin
    pc = little_gmm()

    mktempdir() do tmp
        # Plain-text round trip, read back through each of the three call forms.
        file = "$tmp/example_gaussian.jpc"
        write(file, pc)
        for reader_args in ((), (JpcFormat(), true), (JpcFormat(), false))
            test_pc_equals(pc, read(file, ProbCircuit, reader_args...))
        end

        # Compressed
        file = "$tmp/example_gaussian.jpc.gz"
        write(file, pc)
        test_pc_equals(pc, read(file, ProbCircuit))
    end
end

@testset "JPC IO tests 2D Gaussian" begin
    pc = little_2var_gmm()

    mktempdir() do tmp
        # Plain-text round trip, read back through each of the three call forms.
        file = "$tmp/example_gaussian2.jpc"
        write(file, pc)
        for reader_args in ((), (JpcFormat(), true), (JpcFormat(), false))
            test_pc_equals(pc, read(file, ProbCircuit, reader_args...))
        end

        # Compressed
        file = "$tmp/example_gaussian2.jpc.gz"
        write(file, pc)
        test_pc_equals(pc, read(file, ProbCircuit))
    end
end


@testset "Jpc IO tests hybrid" begin

Expand Down