From c89bf98b5bd1d37a5932c2a02c212b94a571fd9f Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 7 Oct 2024 10:05:17 +0200 Subject: [PATCH] Update to AutoDiffOperators v0.2 --- Project.toml | 2 +- docs/src/advanced_tutorial_lit.jl | 2 +- docs/src/tutorial_lit.jl | 2 +- src/mgvi_impl.jl | 2 +- test/test_jacobians.jl | 12 ++++++------ test/test_mgvi_impl.jl | 2 +- test/test_samplers.jl | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index a56c7c8..d135235 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f" [compat] -AutoDiffOperators = "0.1.3" +AutoDiffOperators = "0.1.3, 0.2" ChainRulesCore = "0.9.44, 0.10, 1" Distributed = "1" Distributions = "0.25" diff --git a/docs/src/advanced_tutorial_lit.jl b/docs/src/advanced_tutorial_lit.jl index 2cb3ca5..68cb9ed 100644 --- a/docs/src/advanced_tutorial_lit.jl +++ b/docs/src/advanced_tutorial_lit.jl @@ -36,7 +36,7 @@ using FFTW import ForwardDiff, Zygote using AutoDiffOperators -context = MGVIContext(ADModule(:Zygote)) +context = MGVIContext(ADSelector(Zygote)) #- Random.seed!(84612); diff --git a/docs/src/tutorial_lit.jl b/docs/src/tutorial_lit.jl index 25b46e6..ee74481 100644 --- a/docs/src/tutorial_lit.jl +++ b/docs/src/tutorial_lit.jl @@ -19,7 +19,7 @@ import Zygote # # We assume errors are normally distributed with unknown covariance, which has to be learned as well. -context = MGVIContext(ADModule(:Zygote)) +context = MGVIContext(ADSelector(Zygote)) const _x1_grid = [Float64(i)/10 for i in 1:25] const _x2_grid = [Float64(i)/10 + 0.1 for i in 1:15] diff --git a/src/mgvi_impl.jl b/src/mgvi_impl.jl index a5e4572..ecc8064 100644 --- a/src/mgvi_impl.jl +++ b/src/mgvi_impl.jl @@ -42,7 +42,7 @@ normal distribution of the same dimensionality as `init_param_point`. using Random, Distributions, MGVI import Zygote -context = MGVIContext(ADModule(:Zygote)) +context = MGVIContext(ADSelector(Zygote)) model(x::AbstractVector) = Normal(x[1], 0.2) true_param = [2.0] diff --git a/test/test_jacobians.jl b/test/test_jacobians.jl index 49cef4e..b13bb50 100644 --- a/test/test_jacobians.jl +++ b/test/test_jacobians.jl @@ -21,9 +21,9 @@ Test.@testset "test_jacobians_consistency" begin x = rand(6); l = rand(4); r = rand(6) J_ref = ForwardDiff.jacobian(f, x) - _, J1 = @inferred with_jacobian(f, x, Matrix, ADModule(:ForwardDiff)) - _, J2 = @inferred with_jacobian(f, x, LinearMap, ADModule(:Zygote)) - _, J3 = @inferred with_jacobian(f, x, LinearMap, ADModule(:ForwardDiff)) + _, J1 = @inferred with_jacobian(f, x, Matrix, ADSelector(ForwardDiff)) + _, J2 = @inferred with_jacobian(f, x, LinearMap, ADSelector(Zygote)) + _, J3 = @inferred with_jacobian(f, x, LinearMap, ADSelector(ForwardDiff)) for J in (J1, J2, J3) @test @inferred(Matrix(J)) ≈ J_ref @@ -38,9 +38,9 @@ Test.@testset "test_jacobians_consistency" begin _flat_model = MGVI.flat_params ∘ ModelPolyfit.model true_params = ModelPolyfit.true_params - _, full_jac = @inferred with_jacobian(_flat_model, true_params, Matrix, ADModule(:ForwardDiff)) - _, fwdder_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADModule(:ForwardDiff)) - _, fwdrevad_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADModule(:Zygote)) + _, full_jac = @inferred with_jacobian(_flat_model, true_params, Matrix, ADSelector(ForwardDiff)) + _, fwdder_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADSelector(ForwardDiff)) + _, fwdrevad_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADSelector(Zygote)) for i in 1:min(size(full_jac)...) vec = rand(size(full_jac, 2)) diff --git a/test/test_mgvi_impl.jl b/test/test_mgvi_impl.jl index 0687538..8c8f185 100644 --- a/test/test_mgvi_impl.jl +++ b/test/test_mgvi_impl.jl @@ -13,7 +13,7 @@ if !isdefined(Main, :ModelPolyfit) end Test.@testset "test_mgvi_optimize_step" begin - context = MGVIContext(ADModule(:Zygote)) + context = MGVIContext(ADSelector(Zygote)) model = ModelPolyfit.model true_params = ModelPolyfit.true_params diff --git a/test/test_samplers.jl b/test/test_samplers.jl index 4b6b9b0..a194a29 100644 --- a/test/test_samplers.jl +++ b/test/test_samplers.jl @@ -14,7 +14,7 @@ if :ModelPolyfit ∉ names(Main) end Test.@testset "test_cmp_residual_samplers" begin - context = MGVIContext(ADModule(:Zygote)) + context = MGVIContext(ADSelector(Zygote)) model = ModelPolyfit.model true_params = ModelPolyfit.true_params