diff --git a/docs/src/api.md b/docs/src/api.md index 501d6d1..c333f20 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -139,4 +139,5 @@ AutoGP.particle_weights AutoGP.effective_sample_size AutoGP.covariance_kernels AutoGP.observation_noise_variances +AutoGP.decompose ``` diff --git a/src/GP.jl b/src/GP.jl index 407d2ba..3b2c2a5 100644 --- a/src/GP.jl +++ b/src/GP.jl @@ -72,6 +72,16 @@ Base.isapprox(a::Node, b::Node) = depth(::LeafNode) = 1 depth(node::BinaryOpNode) = node.depth +""" + unroll(node::Node) + +Unroll a covariance kernel into a flat Vector of all intermediate kernels. +""" +function unroll end +unroll(node::LeafNode) = [node] +unroll(node::BinaryOpNode) = vcat(unroll(node.left), unroll(node.right), node) + + @doc raw""" WhiteNoise(value) @@ -527,7 +537,7 @@ function _show_pretty(io::IO, node::LeafNode, pre, vert_bars::Tuple; first=false print(io, indent_str * "$(pretty(node))\n") end -_pretty_BinaryOpNode(node::Plus) = '\uFF0B' +_pretty_BinaryOpNode(node::Plus) = '+' _pretty_BinaryOpNode(node::Times) = '\u00D7' _pretty_BinaryOpNode(node::ChangePoint) = ("CP$((node.location,node.scale))") function _show_pretty(io::IO, node::BinaryOpNode, pre, vert_bars::Tuple; first=false, last=true) @@ -614,6 +624,7 @@ export BinaryOpNode export eval_cov export compute_cov_matrix export compute_cov_matrix_vectorized +export unroll export WhiteNoise export Constant diff --git a/src/api.jl b/src/api.jl index 79a5ef0..b1f9e04 100644 --- a/src/api.jl +++ b/src/api.jl @@ -650,3 +650,55 @@ function Base.show(df::DataFrames.DataFrame) summary=false, header_crayon=DataFrames.PrettyTables.Crayons.Crayon(), eltypes=false, rowlabel=Symbol()) end + +""" + function decompose(model::GPModel) + +Decompose each particle within `model` into its constituent kernels. +Supposing that [`num_particles`](@ref)`(model)` equals ``k``, the return +value `models::Vector{GPModel}` of `decompose` is a length-``k`` vector of +[`GPModel`](@ref) instances. + +Therefore, `models[i]` is a [`GPModel`](@ref) that represents the +decomposition of particle `i` in `model` into its constituent kernels. Each +particle in `models[i]` corresponds to a kernel fragment in the covariance +for particle `i` of `model` (i.e., one particle for each [`GP.Node`](@ref) +in the covariance kernel). + +The weights of `models[i]` are arbitrary and have no meaningful value. + +This function is particularly useful for visualizing the individual time +series structures that make up each particle of `model`. +""" +function decompose(model::GPModel) + kernels = covariance_kernels(model) + unrolled = map(GP.unroll, kernels) + @assert length(kernels) == num_particles(model) + models = Vector{GPModel}(undef, length(kernels)) + # ERROR: type GPConfig has no field WhiteNoise + # noises = Model.transform_param.( + # :noise, + # [trace[:noise] for trace in model.pf_state.traces],) + # .+ AutoGP.Model.JITTER + for (i, kernel_list::Vector{GP.Node}) in enumerate(unrolled) + # ERROR: type GPConfig has no field WhiteNoise + # Add observation noise as a WhiteNoise kernel. + # typeof(kernel_list) + # push!(kernel_list, GP.WhiteNoise(noises[i])) + # Initialize new GPModel. + models[i] = GPModel( + model.ds, model.y; + n_particles=length(kernel_list), config=model.config) + # -- Copy transforms, since add_data! may have been called on model. + models[i].ds_transform = model.ds_transform + models[i].y_transform = model.y_transform + # Force update each particle to match the kernel fragment. + for (j, trace) in enumerate(models[i].pf_state.traces) + models[i].pf_state.traces[j] = Inference.node_to_trace( + kernel_list[j], model.pf_state.traces[i]) + end + # Weights are arbitrary. + models[i].pf_state.log_weights = zeros(length(models[i].pf_state.traces)) + end + return models +end diff --git a/src/inference_utils.jl b/src/inference_utils.jl index f870650..ef4ed47 100644 --- a/src/inference_utils.jl +++ b/src/inference_utils.jl @@ -253,3 +253,17 @@ function get_observations_choicemap(trace::Gen.Trace) end return observations end + +function node_to_trace(node::Node, trace::Gen.Trace) + config = Gen.get_args(trace)[2] + choicemap_obs = get_observations_choicemap(trace) + choicemap_node = Gen.choicemap() + Gen.set_submap!(choicemap_node, :tree, node_to_choicemap(node, config)) + constraints = merge(choicemap_node, choicemap_obs) + constraints[:noise] = trace[:noise] + return Gen.generate( + Gen.get_gen_fn(trace), + Gen.get_args(trace), + constraints, + )[1] +end