Skip to content

Commit

Permalink
Add a test for underdetermined systems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 14, 2023
1 parent 7868499 commit 9bc19de
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,19 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
resid_bc, y)

sd_collocation = if jac_alg.nonbc_diffmode isa AbstractSparseADType
J_full_band = BandedMatrix(Ones{eltype(y)}(L + cache.M * (N - 1), cache.M * N),
(L + 1, cache.M))
__sparsity_detection_alg(__generate_sparse_jacobian_prototype(cache,
cache.problem_type, y, y, cache.M, N))
if L < cache.M
# For underdetermined problems we use sparse since we don't have banded qr
colored_matrix = __generate_sparse_jacobian_prototype(cache,

Check warning on line 300 in src/solve/mirk.jl

View check run for this annotation

Codecov / codecov/patch

src/solve/mirk.jl#L300

Added line #L300 was not covered by tests
cache.problem_type, y, y, cache.M, N)
J_full_band = nothing
__sparsity_detection_alg(ColoredMatrix(sparse(colored_matrix.M),

Check warning on line 303 in src/solve/mirk.jl

View check run for this annotation

Codecov / codecov/patch

src/solve/mirk.jl#L302-L303

Added lines #L302 - L303 were not covered by tests
colored_matrix.row_colorvec, colored_matrix.col_colorvec))
else
J_full_band = BandedMatrix(Ones{eltype(y)}(L + cache.M * (N - 1), cache.M * N),
(L + 1, cache.M + max(cache.M - L, 0)))
__sparsity_detection_alg(__generate_sparse_jacobian_prototype(cache,
cache.problem_type, y, y, cache.M, N))
end
else
J_full_band = nothing
NoSparsityDetection()
Expand Down
112 changes: 112 additions & 0 deletions test/mirk/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,115 @@ using BoundaryValueDiffEq, LinearAlgebra, Test
@test norm(vcat(resida, residb)) < 1e-2
end
end

# This is not a very meaningful problem, but it tests that our solvers are not throwing an
# error
@testset "Underconstrained BVP: Rod BVP" begin
SOLVERS = [mirk(; nlsolve) for mirk in (MIRK4, MIRK5, MIRK6),
nlsolve in (LevenbergMarquardt(), GaussNewton(), nothing)]

function hat(y)
return [0 -y[3] y[2]
y[3] 0 -y[1]
-y[2] y[1] 0]
end

function inv_hat(skew)
[skew[3, 2]; skew[1, 3]; skew[2, 1]]
end

function rod_ode!(dy, y, p, t, Kse_inv, Kbt_inv, rho, A, g)
R = reshape(@view(y[4:12]), 3, 3)
n = @view y[13:15]
m = @view y[16:18]

v = Kse_inv * R' * n
v[3] += 1.0
u = Kbt_inv * R' * m
ps = R * v
@views dy[1:3] .= ps
@views dy[4:12] .= vec(R * hat(u))
@views dy[13:15] .= -rho * A * g
@views dy[16:18] .= -hat(ps) * n
end

function bc_a!(residual, y, p)
# Extract rotations from y
R0_u = reshape(@view(y[4:12]), 3, 3)

# Extract rotations from p
R0 = reshape(@view(p[4:12]), 3, 3)

@views residual[1:3] = y[1:3] .- p[1:3]
@views residual[4:6] = inv_hat(R0_u' * R0 - R0_u * R0')
return nothing
end

function bc_b!(residual, y, p)
# Extract rotations from y
RL_u = reshape(@view(y[4:12]), 3, 3)

# Extract rotations from p
RL = reshape(@view(p[16:24]), 3, 3)

@views residual[1:3] = y[1:3] .- p[13:15]
@views residual[4:6] = inv_hat(RL_u' * RL - RL_u * RL')
return nothing
end

function bc!(residual, sol, p, t)
y1 = first(sol)
y2 = last(sol)
R0_u = reshape(@view(y1[4:12]), 3, 3)
RL_u = reshape(@view(y2[4:12]), 3, 3)

# Extract rotations from p
R0 = reshape(@view(p[4:12]), 3, 3)
RL = reshape(@view(p[16:24]), 3, 3)

@views residual[1:3] = y1[1:3] .- p[1:3]
@views residual[4:6] = inv_hat(R0_u' * R0 - R0_u * R0')
@views residual[7:9] = y2[1:3] .- p[13:15]
@views residual[10:12] = inv_hat(RL_u' * RL - RL_u * RL')

return nothing
end

# Parameters
E = 200e9
G = 80e9
r = 0.001
rho = 8000
g = [9.81; 0; 0]
L = 0.5
A = pi * r^2
I = pi * r^4 / 4
J = 2 * I
Kse = diagm([G * A, G * A, E * A])
Kbt = diagm([E * I, E * I, G * J])

# Boundary Conditions
p0 = [0; 0; 0]
R0 = vec(LinearAlgebra.I(3))
pL = [0; -0.1 * L; 0.8 * L]
RL = vec(LinearAlgebra.I(3))

# Main Simulation
tspan = (0.0, L)
rod_ode!(dy, y, p, t) = rod_ode!(dy, y, p, t, inv(Kse), inv(Kbt), rho, A, g)
y0 = vcat(p0, R0, zeros(6))
p = vcat(p0, R0, pL, RL)
prob_tp = TwoPointBVProblem(rod_ode!, (bc_a!, bc_b!), y0, tspan, p,
bcresid_prototype = (zeros(6), zeros(6)))
prob = BVProblem(BVPFunction(rod_ode!, bc!; bcresid_prototype = zeros(12)), y0, tspan,
p)

for solver in SOLVERS
@time sol = solve(prob_tp, solver; verbose = false, dt = 0.1, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
@test SciMLBase.successful_retcode(sol.retcode)
@time sol = solve(prob, solver; verbose = false, dt = 0.1, abstol = 1e-3,
reltol = 1e-3, nlsolve_kwargs = (; maxiters = 50, abstol = 1e-3, reltol = 1e-3))
@test SciMLBase.successful_retcode(sol.retcode)
end
end

0 comments on commit 9bc19de

Please sign in to comment.