Skip to content

Commit

Permalink
Add decompose function to the AutoGP API.
Browse files Browse the repository at this point in the history
  • Loading branch information
fsaad committed Oct 29, 2024
1 parent fa91639 commit 5acd1f0
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ AutoGP.particle_weights
AutoGP.effective_sample_size
AutoGP.covariance_kernels
AutoGP.observation_noise_variances
AutoGP.decompose
```
13 changes: 12 additions & 1 deletion src/GP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ Base.isapprox(a::Node, b::Node) =
depth(::LeafNode) = 1
depth(node::BinaryOpNode) = node.depth

"""
unroll(node::Node)
Unroll a covariance kernel into a flat Vector of all intermediate kernels.
"""
function unroll end
unroll(node::LeafNode) = [node]
unroll(node::BinaryOpNode) = vcat(unroll(node.left), unroll(node.right), node)


@doc raw"""
WhiteNoise(value)
Expand Down Expand Up @@ -527,7 +537,7 @@ function _show_pretty(io::IO, node::LeafNode, pre, vert_bars::Tuple; first=false
print(io, indent_str * "$(pretty(node))\n")
end

_pretty_BinaryOpNode(node::Plus) = '\uFF0B'
_pretty_BinaryOpNode(node::Plus) = '+'
_pretty_BinaryOpNode(node::Times) = '\u00D7'
_pretty_BinaryOpNode(node::ChangePoint) = ("CP$((node.location,node.scale))")
function _show_pretty(io::IO, node::BinaryOpNode, pre, vert_bars::Tuple; first=false, last=true)
Expand Down Expand Up @@ -614,6 +624,7 @@ export BinaryOpNode
export eval_cov
export compute_cov_matrix
export compute_cov_matrix_vectorized
export unroll

export WhiteNoise
export Constant
Expand Down
52 changes: 52 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,55 @@ function Base.show(df::DataFrames.DataFrame)
summary=false, header_crayon=DataFrames.PrettyTables.Crayons.Crayon(),
eltypes=false, rowlabel=Symbol())
end

"""
function decompose(model::GPModel)
Decompose each particle within `model` into its constituent kernels.
Supposing that [`num_particles`](@ref)`(model)` equals ``k``, the return
value `models::Vector{GPModel}` of `decompose` is a length-``k`` vector of
[`GPModel`](@ref) instances.
Therefore, `models[i]` is a [`GPModel`](@ref) that represents the
decomposition of particle `i` in `model` into its constituent kernels. Each
particle in `models[i]` corresponds to a kernel fragment in the covariance
for particle `i` of `model` (i.e., one particle for each [`GP.Node`](@ref)
in the covariance kernel).
The weights of `models[i]` are arbitrary and have no meaningful value.
This function is particularly useful for visualizing the individual time
series structures that make up each particle of `model`.
"""
function decompose(model::GPModel)
kernels = covariance_kernels(model)
unrolled = map(GP.unroll, kernels)
@assert length(kernels) == num_particles(model)
models = Vector{GPModel}(undef, length(kernels))
# ERROR: type GPConfig has no field WhiteNoise
# noises = Model.transform_param.(
# :noise,
# [trace[:noise] for trace in model.pf_state.traces],)
# .+ AutoGP.Model.JITTER
for (i, kernel_list::Vector{GP.Node}) in enumerate(unrolled)
# ERROR: type GPConfig has no field WhiteNoise
# Add observation noise as a WhiteNoise kernel.
# typeof(kernel_list)
# push!(kernel_list, GP.WhiteNoise(noises[i]))
# Initialize new GPModel.
models[i] = GPModel(
model.ds, model.y;
n_particles=length(kernel_list), config=model.config)
# -- Copy transforms, since add_data! may have been called on model.
models[i].ds_transform = model.ds_transform
models[i].y_transform = model.y_transform
# Force update each particle to match the kernel fragment.
for (j, trace) in enumerate(models[i].pf_state.traces)
models[i].pf_state.traces[j] = Inference.node_to_trace(
kernel_list[j], model.pf_state.traces[i])
end
# Weights are arbitrary.
models[i].pf_state.log_weights = zeros(length(models[i].pf_state.traces))
end
return models
end
14 changes: 14 additions & 0 deletions src/inference_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,17 @@ function get_observations_choicemap(trace::Gen.Trace)
end
return observations
end

function node_to_trace(node::Node, trace::Gen.Trace)
config = Gen.get_args(trace)[2]
choicemap_obs = get_observations_choicemap(trace)
choicemap_node = Gen.choicemap()
Gen.set_submap!(choicemap_node, :tree, node_to_choicemap(node, config))
constraints = merge(choicemap_node, choicemap_obs)
constraints[:noise] = trace[:noise]
return Gen.generate(
Gen.get_gen_fn(trace),
Gen.get_args(trace),
constraints,
)[1]
end

0 comments on commit 5acd1f0

Please sign in to comment.