Skip to content

Commit

Permalink
Add JET.jl to the testing (#261)
Browse files Browse the repository at this point in the history
* Add JET.jl to the testing

* Fix two out of three errors

* Change the plot recipe a bit; but the error is not fixed yet

* Fix one of the JET issues

* Move the plot recipe into an extension to fix JET.jl complaints

* Remove `@test` as `JET.test_package` is a test already

* JuliaFormatter.jl
  • Loading branch information
nathanaelbosch authored Jan 8, 2024
1 parent 2b259e2 commit b5e36cf
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 17 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"

[weakdeps]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[extensions]
DiffEqDevToolsExt = "DiffEqDevTools"
RecipesBaseExt = "RecipesBase"

[compat]
ArrayAllocators = "0.3"
Expand Down
23 changes: 15 additions & 8 deletions src/solution_plotting.jl → ext/RecipesBaseExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
########################################################################################
# Plotting
########################################################################################
module RecipesBaseExt

using RecipesBase
using ProbNumDiffEq
using Statistics
import ProbNumDiffEq: AbstractProbODESolution
import SciMLBase: interpret_vars, getsyms

@recipe function f(
sol::AbstractProbODESolution;
idxs=nothing,
Expand All @@ -12,11 +17,11 @@
)
if vars !== nothing
Base.depwarn(
"To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.",
"To maintain consistency with solution indexing, keyword argument `vars` will be removed in a future version. Please use keyword argument `idxs` instead.",
:f; force=true)
(idxs !== nothing) &&
error(
"Simultaneously using keywords vars and idxs is not supported. Please only use idxs.",
"Simultaneously using keywords `vars` and `idxs` is not supported. Please only use idxs.",
)
idxs = vars
end
Expand All @@ -28,8 +33,8 @@
sol_rvs = sol_rvs[tstart.<=times.<=tend]
times = times[tstart.<=times.<=tend]
end
values = stack(mean.(sol_rvs))
stds = stack(std.(sol_rvs))
values = stack(mean.(sol_rvs))'
stds = stack(std.(sol_rvs))'

if isnothing(idxs)
ribbon --> ribbon_width * stds
Expand Down Expand Up @@ -62,7 +67,9 @@
end
return _x, _y, _z
else
error("Error with `vars` argument")
error("Error with `idxs` argument")
end
end
end

end
3 changes: 1 addition & 2 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ using FunctionWrappersWrappers
using TaylorSeries, TaylorIntegration
@reexport using StructArrays
using SimpleUnPack
using RecipesBase
using RecursiveArrayTools
using ForwardDiff
using ExponentialUtilities
Expand Down Expand Up @@ -75,7 +74,7 @@ include("initialization/classicsolverinit.jl")

include("solution.jl")
include("solution_sampling.jl")
include("solution_plotting.jl")
# include("solution_plotting.jl")

include("integrator_utils.jl")
include("filtering/markov_kernel.jl")
Expand Down
11 changes: 7 additions & 4 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@ function OrdinaryDiffEq.alg_cache(
A, Q, Ah, Qh, P, PI = initialize_transition_matrices(FAC, prior, dt)
F, L = to_sde(prior)
F, L = to_factorized_matrix(FAC, F), to_factorized_matrix(FAC, L)
FHG_method =
!(prior isa IWP) ? FiniteHorizonGramians.ExpAndGram{eltype(F),13}() : nothing
FHG_cache =
!(prior isa IWP) ? FiniteHorizonGramians.alloc_mem(F, L, FHG_method) : nothing
FHG_method, FHG_cache = if !(prior isa IWP)
m = FiniteHorizonGramians.ExpAndGram{eltype(F),13}()
c = FiniteHorizonGramians.alloc_mem(F, L, m)
m, c
else
nothing, nothing
end

# Measurement Model
measurement_model = make_measurement_model(f)
Expand Down
5 changes: 5 additions & 0 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ struct IsometricKroneckerProduct{T<:Number,TB<:AbstractMatrix} <:
end
IsometricKroneckerProduct(ldim::Integer, B::AbstractVector) =
IsometricKroneckerProduct(ldim, reshape(B, :, 1))
IsometricKroneckerProduct(M::AbstractMatrix) = throw(
ArgumentError(
"Can not create IsometricKroneckerProduct from the provided matrix of type $(typeof(M))",
),
)

const IKP = IsometricKroneckerProduct

Expand Down
2 changes: 1 addition & 1 deletion src/priors/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function initialize_transition_matrices(
p::AbstractODEFilterPrior,
dt,
)
error("The chosen prior can not be implemented with a $fac factorization")
error("The chosen prior can not be implemented with a $FAC factorization")
end

"""
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteHorizonGramians = "b59a298d-d283-4a37-9369-85a9f9a111a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearRegression = "92481ed7-9fb7-40fd-80f2-46fd0f076581"
Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, SafeTestsets, Aqua, TimerOutputs
using Test, SafeTestsets, Aqua, JET, TimerOutputs
using ProbNumDiffEq
using ODEProblemLibrary

Expand Down Expand Up @@ -121,13 +121,19 @@ const GROUP = get(ENV, "GROUP", "All")
end

if GROUP == "All"
@timedtestset "Aqua.jl" begin
@timedtestset "Code quality (Aqua.jl)" begin
Aqua.test_all(
ProbNumDiffEq,
ambiguities=false,
piracies=false,
)
end
@timedtestset "Code linting (JET.jl)" begin
JET.test_package(
ProbNumDiffEq;
target_defined_modules=true,
)
end
end
end

Expand Down

0 comments on commit b5e36cf

Please sign in to comment.