Skip to content

Commit

Permalink
Update to AutoDiffOperators v0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Oct 7, 2024
1 parent 91b2a6d commit c89bf98
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"

[compat]
AutoDiffOperators = "0.1.3"
AutoDiffOperators = "0.1.3, 0.2"
ChainRulesCore = "0.9.44, 0.10, 1"
Distributed = "1"
Distributions = "0.25"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/advanced_tutorial_lit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using FFTW
import ForwardDiff, Zygote
using AutoDiffOperators

context = MGVIContext(ADModule(:Zygote))
context = MGVIContext(ADSelector(Zygote))

#-
Random.seed!(84612);
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorial_lit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import Zygote
#
# We assume errors are normally distributed with unknown covariance, which has to be learned as well.

context = MGVIContext(ADModule(:Zygote))
context = MGVIContext(ADSelector(Zygote))

const _x1_grid = [Float64(i)/10 for i in 1:25]
const _x2_grid = [Float64(i)/10 + 0.1 for i in 1:15]
Expand Down
2 changes: 1 addition & 1 deletion src/mgvi_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ normal distribution of the same dimensionality as `init_param_point`.
using Random, Distributions, MGVI
import Zygote
context = MGVIContext(ADModule(:Zygote))
context = MGVIContext(ADSelector(Zygote))
model(x::AbstractVector) = Normal(x[1], 0.2)
true_param = [2.0]
Expand Down
12 changes: 6 additions & 6 deletions test/test_jacobians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ Test.@testset "test_jacobians_consistency" begin
x = rand(6); l = rand(4); r = rand(6)
J_ref = ForwardDiff.jacobian(f, x)

_, J1 = @inferred with_jacobian(f, x, Matrix, ADModule(:ForwardDiff))
_, J2 = @inferred with_jacobian(f, x, LinearMap, ADModule(:Zygote))
_, J3 = @inferred with_jacobian(f, x, LinearMap, ADModule(:ForwardDiff))
_, J1 = @inferred with_jacobian(f, x, Matrix, ADSelector(ForwardDiff))
_, J2 = @inferred with_jacobian(f, x, LinearMap, ADSelector(Zygote))
_, J3 = @inferred with_jacobian(f, x, LinearMap, ADSelector(ForwardDiff))

for J in (J1, J2, J3)
@test @inferred(Matrix(J)) J_ref
Expand All @@ -38,9 +38,9 @@ Test.@testset "test_jacobians_consistency" begin
_flat_model = MGVI.flat_params ModelPolyfit.model
true_params = ModelPolyfit.true_params

_, full_jac = @inferred with_jacobian(_flat_model, true_params, Matrix, ADModule(:ForwardDiff))
_, fwdder_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADModule(:ForwardDiff))
_, fwdrevad_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADModule(:Zygote))
_, full_jac = @inferred with_jacobian(_flat_model, true_params, Matrix, ADSelector(ForwardDiff))
_, fwdder_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADSelector(ForwardDiff))
_, fwdrevad_jac = @inferred with_jacobian(_flat_model, true_params, LinearMap, ADSelector(Zygote))

for i in 1:min(size(full_jac)...)
vec = rand(size(full_jac, 2))
Expand Down
2 changes: 1 addition & 1 deletion test/test_mgvi_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ if !isdefined(Main, :ModelPolyfit)
end

Test.@testset "test_mgvi_optimize_step" begin
context = MGVIContext(ADModule(:Zygote))
context = MGVIContext(ADSelector(Zygote))

model = ModelPolyfit.model
true_params = ModelPolyfit.true_params
Expand Down
2 changes: 1 addition & 1 deletion test/test_samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if :ModelPolyfit ∉ names(Main)
end

Test.@testset "test_cmp_residual_samplers" begin
context = MGVIContext(ADModule(:Zygote))
context = MGVIContext(ADSelector(Zygote))

model = ModelPolyfit.model
true_params = ModelPolyfit.true_params
Expand Down

0 comments on commit c89bf98

Please sign in to comment.