Skip to content

Commit

Permalink
chore(DataInterpolations): fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutosh-b-b committed Oct 7, 2024
1 parent b9637bd commit 2b92662
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
17 changes: 8 additions & 9 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,10 @@ function CubicSpline(u::uType,
end

function CubicSpline(u::uType,
t;
extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractMatrix}
t;
extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractMatrix}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
Expand All @@ -474,12 +474,13 @@ function CubicSpline(u::uType,
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
ax = axes(u)[1:end-1]
ax = axes(u)[1:(end - 1)]
typed_zero = zero(6(u[ax..., begin + 2] - u[ax..., begin + 1]) / h[begin + 2] -
6(u[ax..., begin + 1] - u[ax..., begin]) / h[begin + 1])
6(u[ax..., begin + 1] - u[ax..., begin]) / h[begin + 1])

h_ = reshape(h, 1, :)
d = 6*((u[ax..., 3:n+1] - u[ax..., 2:n]) ./ h_[:, 3:n+1]) - 6*((u[ax..., 2:n] - u[ax..., 1:n-1]) ./ h_[:, 2:n])
d = 6 * ((u[ax..., 3:(n + 1)] - u[ax..., 2:n]) ./ h_[:, 3:(n + 1)]) -
6 * ((u[ax..., 2:n] - u[ax..., 1:(n - 1)]) ./ h_[:, 2:n])
d = cat(typed_zero, d, typed_zero; dims = ndims(d))

z = (tA \ d')'
Expand Down Expand Up @@ -517,8 +518,6 @@ function CubicSpline(
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, assume_linear_t)
end



"""
BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true)
Expand Down
2 changes: 1 addition & 1 deletion src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function cubic_spline_parameters(u::AbstractVector, h, z, idx)
end

function cubic_spline_parameters(u::AbstractArray, h, z, idx)
ax = axes(u)[1:end-1]
ax = axes(u)[1:(end - 1)]
c₁ = (u[ax..., idx + 1] / h[idx + 1] - z[ax..., idx + 1] * h[idx + 1] / 6)
c₂ = (u[ax..., idx] / h[idx + 1] - z[ax..., idx] * h[idx + 1] / 6)
return c₁, c₂
Expand Down
6 changes: 3 additions & 3 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,11 @@ end
@test_throws DataInterpolations.ExtrapolationError A(-2.0)
@test_throws DataInterpolations.ExtrapolationError A(2.0)

@testset "AbstractMatrix" begin
t = 0.1:0.1:1.0
@testset "AbstractMatrix" begin
t = 0.1:0.1:1.0
u = [sin.(t) cos.(t)]' |> collect
c = CubicSpline(u, t)
t_test = 0.1:0.05:1.0
t_test = 0.1:0.05:1.0
u_test = reduce(hcat, c.(t_test))
@test isapprox(u_test[1, :], sin.(t_test), atol = 1e-3)
@test isapprox(u_test[2, :], cos.(t_test), atol = 1e-3)
Expand Down

0 comments on commit 2b92662

Please sign in to comment.