From 998db881c199b05f3f7855fa5184ae781957c5f2 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Tue, 26 Nov 2024 15:53:26 -0500 Subject: [PATCH] Patch release for fixing division bug --- Project.toml | 2 +- examples/ode.jl | 45 --------------------------------------------- src/chainrules.jl | 4 +++- src/utils.jl | 6 ++++-- test/primitive.jl | 8 ++++---- 5 files changed, 12 insertions(+), 53 deletions(-) delete mode 100644 examples/ode.jl diff --git a/Project.toml b/Project.toml index 52f8afb..0d81ed4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaylorDiff" uuid = "b36ab563-344f-407b-a36a-4f200bebf99c" authors = ["Songchen Tan "] -version = "0.3.0" +version = "0.3.1" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" diff --git a/examples/ode.jl b/examples/ode.jl deleted file mode 100644 index 5ad2240..0000000 --- a/examples/ode.jl +++ /dev/null @@ -1,45 +0,0 @@ -using ADTypes -using DifferentiationInterface -using ModelingToolkit, DifferentialEquations -using TaylorDiff, ForwardDiff -using Enzyme, Zygote, ReverseDiff -using SciMLSensitivity - -@parameters a -@variables t x1(t) -D = Differential(t) -states = [x1] -parameters = [a] - -@named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters) -model = structural_simplify(pre_model) - -ic = Dict(x1 => 1.0) -p_true = Dict(a => 2.0) - -problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true) -soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12) -display(soln(0.5, idxs = [x1])) - -function different_time(new_ic, new_params, new_t) - #newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params) - #newprob = remake(problem, u0=new_ic, tspan = [0.0, new_t], p = new_params) - newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params) - newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0)) - new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - return (soln(new_t, idxs = [x1])) -end - -function just_t(new_t) - return different_time(ic, p_true, new_t)[1] -end -display(different_time(ic, p_true, 2e-5)) -display(just_t(0.5)) - -#display(ForwardDiff.derivative(just_t,1.0)) -display(TaylorDiff.derivative(just_t, 1.0, Val(1))) #isnan error -#display(value_and_gradient(just_t, AutoForwardDiff(), 1.0)) -#display(value_and_gradient(just_t, AutoReverseDiff(), 1.0)) -#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0)) -#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0)) -#display(value_and_gradient(just_t, AutoZygote(), 1.0)) diff --git a/src/chainrules.jl b/src/chainrules.jl index def2ee3..efd039e 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -30,7 +30,9 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} end function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P} - partials_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄) + function partials_pullback(v̄::NTuple{P, A}) + NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄) + end return partials(t), partials_pullback end diff --git a/src/utils.jl b/src/utils.jl index cef7a01..fee00eb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -91,8 +91,10 @@ function process(d, expr) @match x begin a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx]) (a_ = b_) => (push!(known_names, a); :($a = $b)) - (a_ += b_) => a in known_names ? :($a += $b) : (push!(known_names, a); :($a = $b)) - (a_ -= b_) => a in known_names ? :($a -= $b) : (push!(known_names, a); :($a = -$b)) + (a_ += b_) => a in known_names ? :($a += $b) : + (push!(known_names, a); :($a = $b)) + (a_ -= b_) => a in known_names ? :($a -= $b) : + (push!(known_names, a); :($a = -$b)) TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...)))) _ => x end diff --git a/test/primitive.jl b/test/primitive.jl index e00801a..cdd38c3 100644 --- a/test/primitive.jl +++ b/test/primitive.jl @@ -50,10 +50,10 @@ end # end end -@testset "Multi-argument functions" begin - @test derivative(x -> 1 + 1/x, 1.0, Val(1))≈-1.0 rtol=1e-6 - @test derivative(x -> (x+1)/x, 1.0, Val(1))≈-1.0 rtol=1e-6 - @test derivative(x -> x/x, 1.0, Val(1))≈ 0.0 rtol=1e-6 +@testset "Multi-argument functions" begin + @test derivative(x -> 1 + 1 / x, 1.0, Val(1))≈-1.0 rtol=1e-6 + @test derivative(x -> (x + 1) / x, 1.0, Val(1))≈-1.0 rtol=1e-6 + @test derivative(x -> x / x, 1.0, Val(1))≈0.0 rtol=1e-6 end @testset "Corner cases" begin