diff --git a/Project.toml b/Project.toml index 23764c2f..8a88b67b 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractGPs = "0.5.15" +AbstractGPs = "0.5.17" Bessels = "0.2.8" BlockDiagonals = "0.1.7" ChainRulesCore = "1" @@ -25,5 +25,5 @@ FillArrays = "0.13.0 - 0.13.7, 1" KernelFunctions = "0.9, 0.10.1" StaticArrays = "1" StructArrays = "0.5, 0.6" -Zygote = "0.6" +Zygote = "0.6.65" julia = "1.6" diff --git a/README.md b/README.md index f794c680..035bd8db 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,6 @@ TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Abstrac [JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE) -# Dependency Status - -In the interest of managing expectations, please note that TemporalGPs does not currently operate with the most current version of AbstractGPs / Zygote / ChainRules. I (Will) am aware of this problem, and will sort it out as soon as I have the time! - # Installation TemporalGPs.jl is registered, so simply type the following at the REPL: diff --git a/examples/approx_space_time_inference.jl b/examples/approx_space_time_inference.jl index e8ac4838..51f5d795 100644 --- a/examples/approx_space_time_inference.jl +++ b/examples/approx_space_time_inference.jl @@ -61,6 +61,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" heatmap(reshape(σ_post_marginals, N_pr, T)); layout=(1, 2), ), - "posterior.png", + "approx_space_time_inference.png", ); end diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl index eeb68f6a..ca007981 100644 --- a/examples/approx_space_time_learning.jl +++ b/examples/approx_space_time_learning.jl @@ -100,6 +100,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" heatmap(reshape(σ_post_marginals, N_pr, T)); layout=(1, 2), ), - "posterior.png", + "approx_space_time_learning.png", ); end diff --git a/examples/augmented_inference.jl b/examples/augmented_inference.jl index 5f1acc62..206e587d 100644 --- a/examples/augmented_inference.jl +++ b/examples/augmented_inference.jl @@ -1,6 +1,6 @@ using AbstractGPs using TemporalGPs -using Distributions +using Distributions: Bernoulli using StatsFuns: logistic # In this example we are showing how to work with non-Gaussian likelihoods, @@ -73,5 +73,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" plot!(plt, x_pr, f_post_samples; color=:red, alpha=0.3, label=""); plot!(plt, x, f_true; label="", lw=2.0, color=:blue); # Plot the true latent GP on top scatter!(plt, x, y; label="", markersize=1.0, alpha=1.0); # Plot the data - savefig(plt, "posterior.png"); + savefig(plt, "augmented_inference.png"); end diff --git a/examples/exact_space_time_inference.jl b/examples/exact_space_time_inference.jl index 8c329f39..b6045678 100644 --- a/examples/exact_space_time_inference.jl +++ b/examples/exact_space_time_inference.jl @@ -59,6 +59,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" heatmap(reshape(σ_post_marginals, N, T_pr)); layout=(1, 2), ), - "posterior.png", + "exact_space_time_inference.png", ); end diff --git a/examples/exact_space_time_learning.jl b/examples/exact_space_time_learning.jl index 7971037a..f9664ac3 100644 --- a/examples/exact_space_time_learning.jl +++ b/examples/exact_space_time_learning.jl @@ -39,7 +39,7 @@ end # Exact inference only works for such grids. # Times must be increasing, points in space can be anywhere. N = 50; -T = 1_000; +T = 500; points_in_space = collect(range(-3.0, 3.0; length=N)); points_in_time = RegularSpacing(0.0, 0.01, T); x = RectilinearGrid(points_in_space, points_in_time); @@ -73,7 +73,7 @@ final_params = unpack(training_results.minimizer) f_post = posterior(build_gp(final_params)(x, final_params.var_noise), y); # Specify some locations at which to make predictions. -T_pr = 1200; +T_pr = 600; points_in_time_pr = RegularSpacing(0.0, 0.01, T_pr); x_pr = RectilinearGrid(points_in_space, points_in_time_pr); @@ -93,6 +93,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" heatmap(reshape(σ_post_marginals, N, T_pr)); layout=(1, 2), ), - "posterior.png", + "exact_space_time_learning.png", ); end diff --git a/examples/exact_time_inference.jl b/examples/exact_time_inference.jl index 37661be4..ad6a0d68 100644 --- a/examples/exact_time_inference.jl +++ b/examples/exact_time_inference.jl @@ -48,5 +48,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1); plot!(plt, f_post(x_pr); ribbon_scale=3.0, label=""); plot!(x_pr, f_post_samples; color=:red, label=""); - savefig(plt, "posterior.png"); + savefig(plt, "exact_time_inference.png"); end diff --git a/examples/exact_time_learning.jl b/examples/exact_time_learning.jl index 64b79218..b4b05cb4 100644 --- a/examples/exact_time_learning.jl +++ b/examples/exact_time_learning.jl @@ -86,5 +86,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE" scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1); plot!(plt, f_post(x_pr); ribbon_scale=3.0, label=""); plot!(plt, x_pr, f_post_samples; color=:red, label=""); - savefig(plt, "posterior.png"); + savefig(plt, "exact_time_learning.png"); end diff --git a/src/models/linear_gaussian_conditionals.jl b/src/models/linear_gaussian_conditionals.jl index c342e943..0b5fe798 100644 --- a/src/models/linear_gaussian_conditionals.jl +++ b/src/models/linear_gaussian_conditionals.jl @@ -46,6 +46,7 @@ be equivalent to function predict(x::Gaussian, f::AbstractLGC) A, a, Q = get_fields(f) m, P = get_fields(x) + # Symmetric wrapper needed for numerical stability. Do not unwrap. return Gaussian(A * m + a, (A * symmetric(P)) * A' + Q) end diff --git a/src/space_time/pseudo_point.jl b/src/space_time/pseudo_point.jl index 6517e685..bcb90d5a 100644 --- a/src/space_time/pseudo_point.jl +++ b/src/space_time/pseudo_point.jl @@ -103,10 +103,11 @@ function kernel_diagonals(k::DTCSeparable, x::RegularInTime) space_kernel = k.k.l time_kernel = k.k.r time_vars = kernelmatrix_diag(time_kernel, get_times(x)) - return map( - (s_t, x_r) -> Diagonal(kernelmatrix_diag(space_kernel, x_r) * s_t), - time_vars, - x.vs, + return Diagonal.( + kernelmatrix_diag.( + Ref(space_kernel), + x.vs + ) .* time_vars ) end @@ -185,7 +186,7 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag C = \(K_space_z_chol, C__) Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C) - cs = _map((h, v) -> fill(h, length(v)), hs_t, x.vs) # This should currently be zero. + cs = fill.(hs_t, length.(x.vs)) # This should currently be zero. Hs = _map( ((I, H_t), ) -> kron(I, H_t), zip(Fill(ident_M, N), Hs_t), diff --git a/src/space_time/regular_in_time.jl b/src/space_time/regular_in_time.jl index 67d51099..c3abac1b 100644 --- a/src/space_time/regular_in_time.jl +++ b/src/space_time/regular_in_time.jl @@ -24,7 +24,16 @@ function Base.collect(x::RegularInTime) return [(x, t) for (x, t) in zip(space_inputs, time_inputs)] end -Base.getindex(x::RegularInTime, n::Int) = collect(x)[n] +function Base.getindex(x::RegularInTime, n::Int) + n ≤ 0 && throw(BoundsError(x, n)) + sum_of_lengths = 0 + for (i, v) in enumerate(x.vs) + temp = sum_of_lengths + length(v) + temp ≥ n && return (v[n - sum_of_lengths], x.ts[i]) + sum_of_lengths = temp + end + throw(BoundsError(x, n)) +end Base.show(io::IO, x::RegularInTime) = Base.show(io::IO, collect(x)) diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index 23bed859..fe9e67eb 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -28,13 +28,6 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c) # StaticArrays # # ---------------------------------------------------------------------------- # -function ProjectTo(x::SArray{S,T}) where {S, T} - return ProjectTo{SArray}(; element=_eltype_projectto(T), axes=axes(x), static_size=S) -end - -(proj::ProjectTo{SArray})(dx::SArray) = SArray{proj.static_size}(dx.data) -(proj::ProjectTo{SArray})(dx::AbstractArray) = SArray{proj.static_size}(Tuple(dx)) - function rrule(::Type{T}, x::Tuple) where {T<:SArray} SArray_rrule(Δ) = begin (NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...)) diff --git a/test/runtests.jl b/test/runtests.jl index c5ce381a..0620b706 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,11 +110,16 @@ if GROUP == "examples" Pkg.resolve() Pkg.instantiate() - include(joinpath(pkgpath, "examples", "exact_time_inference.jl")) - include(joinpath(pkgpath, "examples", "exact_time_learning.jl")) - include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl")) - include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl")) - include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl")) - include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl")) - include(joinpath(pkgpath, "examples", "augmented_inference.jl")) + function include_with_info(filename) + @info "Running examples/$filename" + include(joinpath(pkgpath, "examples", filename)) + end + + include_with_info("exact_time_inference.jl") + include_with_info("exact_time_learning.jl") + include_with_info("exact_space_time_inference.jl") + include_with_info("exact_space_time_learning.jl") + include_with_info("approx_space_time_inference.jl") + include_with_info("approx_space_time_learning.jl") + include_with_info("augmented_inference.jl") end diff --git a/test/space_time/pseudo_point.jl b/test/space_time/pseudo_point.jl index cc230e6a..8e43fa04 100644 --- a/test/space_time/pseudo_point.jl +++ b/test/space_time/pseudo_point.jl @@ -96,7 +96,7 @@ include("../models/model_test_utils.jl") validate_dims(lgssm) # The two approaches to DTC computation should be equivalent up to roundoff error. - dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, y) + dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, y) dtc_sde = dtc(fx, y, z_r) @test dtc_naive ≈ dtc_sde rtol=1e-6 @@ -150,7 +150,7 @@ include("../models/model_test_utils.jl") fx_naive = f_naive(naive_inputs_missings, 0.1) # Compute DTC using both approaches. - dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, naive_y_missings) + dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, naive_y_missings) dtc_sde = dtc(fx, y_missing, z_r) @test dtc_naive ≈ dtc_sde rtol=1e-7 atol=1e-7 diff --git a/test/space_time/regular_in_time.jl b/test/space_time/regular_in_time.jl index 22a1fc86..2b98fe01 100644 --- a/test/space_time/regular_in_time.jl +++ b/test/space_time/regular_in_time.jl @@ -10,4 +10,7 @@ using TemporalGPs: RegularInTime @test prod(size(x)) == length(collect(x)) @test all([getindex(x, n) for n in 1:length(x)] .== collect(x)) + @test_throws BoundsError x[0] + @test_throws BoundsError x[-1] + @test_throws BoundsError x[length(x) + 1] end