Skip to content

Commit

Permalink
Add support for power function at 0; add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Sep 27, 2023
1 parent e4d943c commit 32a41f1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
42 changes: 21 additions & 21 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,37 +110,37 @@ end
@generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: Number, T, N}
ex = quote
v = value(t)
v1 = ^(v[1], n)
w11 = 1
u1 = ^(v[1], n)
end
for i in 2:N
for k in 1:N
ex = quote
$ex
$(Symbol('v', i)) = +($([:((n * $(binomial(i - 2, j - 1)) -
$(binomial(i - 2, j - 2))) * $(Symbol('v', j)) *
v[$(i + 1 - j)])
for j in 1:(i - 1)]...)) / v[1]
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
end
end
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
return :(@inbounds $ex)
end

@generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: Integer, T, N}
# TODO: optimize for small powers
ex = quote
v = value(t)
v1 = ^(v[1], n)
end
for i in 2:N
subex = quote
$(Symbol('w', i, 1)) = 0
end
for k in 2:i
subex = quote
$subex
$(Symbol('w', i, k)) = +($([:((n * $(binomial(i - 2, j - 1)) -
$(binomial(i - 2, j - 2))) *
$(Symbol('w', j, k - 1)) *
v[$(i + 1 - j)])
for j in (k - 1):(i - 1)]...))
end
end
ex = quote
$ex
$(Symbol('v', i)) = +($([:((n * $(binomial(i - 2, j - 1)) -
$(binomial(i - 2, j - 2))) * $(Symbol('v', j)) *
v[$(i + 1 - j)])
for j in 1:(i - 1)]...)) / v[1]
$subex
$(Symbol('u', i)) = +($([:($(Symbol('w', i, k)) * $(Symbol('p', k)))
for k in 2:i]...))
end
end
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
return :(@inbounds $ex)
end

Expand Down
8 changes: 6 additions & 2 deletions test/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@ end
end

@testset "Binary functions" begin
some_number, another_number = 1.9, 2.6
some_number, another_number = 1.9, 5.6
for f in (*, /), order in (1, 4)
fdm = central_fdm(12, order)
closure = x -> exp(f(x, another_number))
@test derivative(closure, some_number, order)fdm(closure, some_number) rtol=1e-6
end
for f in (x -> x^7, x -> x^another_number), order in (2, 4)
for f in (x -> x^7, x -> x^another_number), order in (1, 2, 4)
fdm = central_fdm(12, order)
@test derivative(f, some_number, order)fdm(f, some_number) rtol=1e-6
end
for f in (x -> x^7, x -> x^another_number), order in (1, 2)
fdm = forward_fdm(12, order)
@test derivative(f, 0, order)fdm(f, 0) atol=1e-6
end
end

0 comments on commit 32a41f1

Please sign in to comment.