From 32a41f185844bcff9e153474c5003c3114a15006 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Wed, 27 Sep 2023 15:32:57 -0400 Subject: [PATCH] Add support for power function at 0; add tests --- src/primitive.jl | 42 +++++++++++++++++++++--------------------- test/primitive.jl | 8 ++++++-- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/primitive.jl b/src/primitive.jl index 48f82f0..ffc2917 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -110,37 +110,37 @@ end @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: Number, T, N} ex = quote v = value(t) - v1 = ^(v[1], n) + w11 = 1 + u1 = ^(v[1], n) end - for i in 2:N + for k in 1:N ex = quote $ex - $(Symbol('v', i)) = +($([:((n * $(binomial(i - 2, j - 1)) - - $(binomial(i - 2, j - 2))) * $(Symbol('v', j)) * - v[$(i + 1 - j)]) - for j in 1:(i - 1)]...)) / v[1] + $(Symbol('p', k)) = ^(v[1], n - $(k - 1)) end end - ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...))) - return :(@inbounds $ex) -end - -@generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: Integer, T, N} - # TODO: optimize for small powers - ex = quote - v = value(t) - v1 = ^(v[1], n) - end for i in 2:N + subex = quote + $(Symbol('w', i, 1)) = 0 + end + for k in 2:i + subex = quote + $subex + $(Symbol('w', i, k)) = +($([:((n * $(binomial(i - 2, j - 1)) - + $(binomial(i - 2, j - 2))) * + $(Symbol('w', j, k - 1)) * + v[$(i + 1 - j)]) + for j in (k - 1):(i - 1)]...)) + end + end ex = quote $ex - $(Symbol('v', i)) = +($([:((n * $(binomial(i - 2, j - 1)) - - $(binomial(i - 2, j - 2))) * $(Symbol('v', j)) * - v[$(i + 1 - j)]) - for j in 1:(i - 1)]...)) / v[1] + $subex + $(Symbol('u', i)) = +($([:($(Symbol('w', i, k)) * $(Symbol('p', k))) + for k in 2:i]...)) end end - ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...))) + ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...))) return :(@inbounds $ex) end diff --git a/test/primitive.jl b/test/primitive.jl index 6769276..46dab5e 100644 --- a/test/primitive.jl +++ b/test/primitive.jl @@ -30,14 +30,18 @@ end end @testset "Binary functions" begin - some_number, another_number = 1.9, 2.6 + some_number, another_number = 1.9, 5.6 for f in (*, /), order in (1, 4) fdm = central_fdm(12, order) closure = x -> exp(f(x, another_number)) @test derivative(closure, some_number, order)≈fdm(closure, some_number) rtol=1e-6 end - for f in (x -> x^7, x -> x^another_number), order in (2, 4) + for f in (x -> x^7, x -> x^another_number), order in (1, 2, 4) fdm = central_fdm(12, order) @test derivative(f, some_number, order)≈fdm(f, some_number) rtol=1e-6 end + for f in (x -> x^7, x -> x^another_number), order in (1, 2) + fdm = forward_fdm(12, order) + @test derivative(f, 0, order)≈fdm(f, 0) atol=1e-6 + end end