Skip to content

Commit

Permalink
Dont infer what is hard to infer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 20, 2024
1 parent 01a789b commit 1dd64ac
Showing 1 changed file with 3 additions and 28 deletions.
31 changes: 3 additions & 28 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ every solve call.
- `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or
`Val(false)` for type stability. By default this is automatically inferred based on the
size of the input and outputs, however this is type unstable for any array type that
doesn't store array size as part of type information. Note that if problem is inplace
and `bcresid_prototype` in BVPFunction is not specified, then `nlls` is assumed to be
`false`.
doesn't store array size as part of type information. If we can't reliably infer this,
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
"""
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace, nlls}
Expand Down Expand Up @@ -156,11 +155,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
if f.bcresid_prototype !== nothing
_nlls = length(f.bcresid_prototype) != length(_u0)
else
if iip
_nlls = false # Should we assume `true` instead?
else
_nlls = length(f.bc(FakeSolutionObject(u0), p, tspan)) != length(_u0)
end
_nlls = Nothing # Cannot reliably infer
end
end
else
Expand All @@ -176,26 +171,6 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
end
end

"""
isnonlinearleastsquares(prob::BVProblem)
Returns `true` if the underlying problem is a nonlinear least squares problem.
"""
@inline function isnonlinearleastsquares(::BVProblem{uType,
tType, iip, nlls}) where {uType, tType, iip, nlls}
return nlls
end

struct FakeSolutionObject{U}
u::U
end

(sol::FakeSolutionObject)(t) = sol.u
Base.length(::FakeSolutionObject) = 1
Base.firstindex(::FakeSolutionObject) = 1
Base.lastindex(::FakeSolutionObject) = 1
Base.getindex(sol::FakeSolutionObject, i::Int) = sol.u

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
Expand Down

0 comments on commit 1dd64ac

Please sign in to comment.