Skip to content

Commit

Permalink
Merge pull request #115 from JuliaGaussianProcesses/fix-CI
Browse files Browse the repository at this point in the history
Fix CI
  • Loading branch information
simsurace authored Oct 4, 2023
2 parents d32aaa8 + 07588b1 commit f8b8302
Show file tree
Hide file tree
Showing 16 changed files with 46 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.5.15"
AbstractGPs = "0.5.17"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7, 1"
KernelFunctions = "0.9, 0.10.1"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6"
Zygote = "0.6.65"
julia = "1.6"
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Abstrac

[JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE)

# Dependency Status

In the interest of managing expectations, please note that TemporalGPs does not currently operate with the most current version of AbstractGPs / Zygote / ChainRules. I (Will) am aware of this problem, and will sort it out as soon as I have the time!

# Installation

TemporalGPs.jl is registered, so simply type the following at the REPL:
Expand Down
2 changes: 1 addition & 1 deletion examples/approx_space_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N_pr, T));
layout=(1, 2),
),
"posterior.png",
"approx_space_time_inference.png",
);
end
2 changes: 1 addition & 1 deletion examples/approx_space_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N_pr, T));
layout=(1, 2),
),
"posterior.png",
"approx_space_time_learning.png",
);
end
4 changes: 2 additions & 2 deletions examples/augmented_inference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using AbstractGPs
using TemporalGPs
using Distributions
using Distributions: Bernoulli
using StatsFuns: logistic

# In this example we are showing how to work with non-Gaussian likelihoods,
Expand Down Expand Up @@ -73,5 +73,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
plot!(plt, x_pr, f_post_samples; color=:red, alpha=0.3, label="");
plot!(plt, x, f_true; label="", lw=2.0, color=:blue); # Plot the true latent GP on top
scatter!(plt, x, y; label="", markersize=1.0, alpha=1.0); # Plot the data
savefig(plt, "posterior.png");
savefig(plt, "augmented_inference.png");
end
2 changes: 1 addition & 1 deletion examples/exact_space_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N, T_pr));
layout=(1, 2),
),
"posterior.png",
"exact_space_time_inference.png",
);
end
6 changes: 3 additions & 3 deletions examples/exact_space_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
# Exact inference only works for such grids.
# Times must be increasing, points in space can be anywhere.
N = 50;
T = 1_000;
T = 500;
points_in_space = collect(range(-3.0, 3.0; length=N));
points_in_time = RegularSpacing(0.0, 0.01, T);
x = RectilinearGrid(points_in_space, points_in_time);
Expand Down Expand Up @@ -73,7 +73,7 @@ final_params = unpack(training_results.minimizer)
f_post = posterior(build_gp(final_params)(x, final_params.var_noise), y);

# Specify some locations at which to make predictions.
T_pr = 1200;
T_pr = 600;
points_in_time_pr = RegularSpacing(0.0, 0.01, T_pr);
x_pr = RectilinearGrid(points_in_space, points_in_time_pr);

Expand All @@ -93,6 +93,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N, T_pr));
layout=(1, 2),
),
"posterior.png",
"exact_space_time_learning.png",
);
end
2 changes: 1 addition & 1 deletion examples/exact_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
plot!(x_pr, f_post_samples; color=:red, label="");
savefig(plt, "posterior.png");
savefig(plt, "exact_time_inference.png");
end
2 changes: 1 addition & 1 deletion examples/exact_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
plot!(plt, x_pr, f_post_samples; color=:red, label="");
savefig(plt, "posterior.png");
savefig(plt, "exact_time_learning.png");
end
1 change: 1 addition & 0 deletions src/models/linear_gaussian_conditionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ be equivalent to
function predict(x::Gaussian, f::AbstractLGC)
A, a, Q = get_fields(f)
m, P = get_fields(x)

# Symmetric wrapper needed for numerical stability. Do not unwrap.
return Gaussian(A * m + a, (A * symmetric(P)) * A' + Q)
end
Expand Down
11 changes: 6 additions & 5 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ function kernel_diagonals(k::DTCSeparable, x::RegularInTime)
space_kernel = k.k.l
time_kernel = k.k.r
time_vars = kernelmatrix_diag(time_kernel, get_times(x))
return map(
(s_t, x_r) -> Diagonal(kernelmatrix_diag(space_kernel, x_r) * s_t),
time_vars,
x.vs,
return Diagonal.(
kernelmatrix_diag.(
Ref(space_kernel),
x.vs
) .* time_vars
)
end

Expand Down Expand Up @@ -185,7 +186,7 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag
C = \(K_space_z_chol, C__)
Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C)

cs = _map((h, v) -> fill(h, length(v)), hs_t, x.vs) # This should currently be zero.
cs = fill.(hs_t, length.(x.vs)) # This should currently be zero.
Hs = _map(
((I, H_t), ) -> kron(I, H_t),
zip(Fill(ident_M, N), Hs_t),
Expand Down
11 changes: 10 additions & 1 deletion src/space_time/regular_in_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@ function Base.collect(x::RegularInTime)
return [(x, t) for (x, t) in zip(space_inputs, time_inputs)]
end

Base.getindex(x::RegularInTime, n::Int) = collect(x)[n]
function Base.getindex(x::RegularInTime, n::Int)
n 0 && throw(BoundsError(x, n))
sum_of_lengths = 0
for (i, v) in enumerate(x.vs)
temp = sum_of_lengths + length(v)
temp n && return (v[n - sum_of_lengths], x.ts[i])
sum_of_lengths = temp
end
throw(BoundsError(x, n))
end

Base.show(io::IO, x::RegularInTime) = Base.show(io::IO, collect(x))

Expand Down
7 changes: 0 additions & 7 deletions src/util/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c)
# StaticArrays #
# ---------------------------------------------------------------------------- #

function ProjectTo(x::SArray{S,T}) where {S, T}
return ProjectTo{SArray}(; element=_eltype_projectto(T), axes=axes(x), static_size=S)
end

(proj::ProjectTo{SArray})(dx::SArray) = SArray{proj.static_size}(dx.data)
(proj::ProjectTo{SArray})(dx::AbstractArray) = SArray{proj.static_size}(Tuple(dx))

function rrule(::Type{T}, x::Tuple) where {T<:SArray}
SArray_rrule(Δ) = begin
(NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...))
Expand Down
19 changes: 12 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ if GROUP == "examples"
Pkg.resolve()
Pkg.instantiate()

include(joinpath(pkgpath, "examples", "exact_time_inference.jl"))
include(joinpath(pkgpath, "examples", "exact_time_learning.jl"))
include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl"))
include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl"))
include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl"))
include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl"))
include(joinpath(pkgpath, "examples", "augmented_inference.jl"))
function include_with_info(filename)
@info "Running examples/$filename"
include(joinpath(pkgpath, "examples", filename))
end

include_with_info("exact_time_inference.jl")
include_with_info("exact_time_learning.jl")
include_with_info("exact_space_time_inference.jl")
include_with_info("exact_space_time_learning.jl")
include_with_info("approx_space_time_inference.jl")
include_with_info("approx_space_time_learning.jl")
include_with_info("augmented_inference.jl")
end
4 changes: 2 additions & 2 deletions test/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ include("../models/model_test_utils.jl")
validate_dims(lgssm)

# The two approaches to DTC computation should be equivalent up to roundoff error.
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, y)
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, y)
dtc_sde = dtc(fx, y, z_r)
@test dtc_naive dtc_sde rtol=1e-6

Expand Down Expand Up @@ -150,7 +150,7 @@ include("../models/model_test_utils.jl")
fx_naive = f_naive(naive_inputs_missings, 0.1)

# Compute DTC using both approaches.
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, naive_y_missings)
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, naive_y_missings)
dtc_sde = dtc(fx, y_missing, z_r)
@test dtc_naive dtc_sde rtol=1e-7 atol=1e-7

Expand Down
3 changes: 3 additions & 0 deletions test/space_time/regular_in_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ using TemporalGPs: RegularInTime
@test prod(size(x)) == length(collect(x))

@test all([getindex(x, n) for n in 1:length(x)] .== collect(x))
@test_throws BoundsError x[0]
@test_throws BoundsError x[-1]
@test_throws BoundsError x[length(x) + 1]
end

2 comments on commit f8b8302

@simsurace
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/91700

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.5 -m "<description of version>" f8b8302d33c4394f6b48cbdd4861fbf667c1fd54
git push origin v0.6.5

Please sign in to comment.