From f8f159703186299b299225e6f79c1c6f78f76e9e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 29 Oct 2024 14:57:02 -0400 Subject: [PATCH] fix aliasing --- src/common.jl | 53 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/src/common.jl b/src/common.jl index 4146f4ef..e486046c 100644 --- a/src/common.jl +++ b/src/common.jl @@ -139,8 +139,6 @@ __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltyp function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; - alias_A = default_alias_A(alg, prob.A, prob.b), - alias_b = default_alias_b(alg, prob.A, prob.b), abstol = default_tol(real(eltype(prob.b))), reltol = default_tol(real(eltype(prob.b))), maxiters::Int = length(prob.b), @@ -149,23 +147,50 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, Pr = nothing, assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), - alias = LinearAliases(), + alias = LinearAliasSpecifier(), kwargs...) (;A, b, u0, p) = prob - if isnothing(alias.alias_A) - to_alias_A = alias_A + has_A = haskey(kwargs,:alias_A) + has_b = haskey(kwargs,:alias_b) + + if has_A || has_b + aliases = LinearAliasSpecifier() + if has_A + Base.depwarn("alias_A keyword argument is deprecated, to set `alias_A`, + please use a LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))", :alias_A) + SciMLBase.@reset aliases.alias_A = values(kwargs).alias_A + else + SciMLBase.@reset aliases.alias_A = default_alias_A(alg, prob.A, prob.b) + end + + if has_b + Base.depwarn("alias_b keyword argument is deprecated, to set `alias_b`, + please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))", :alias_b) + SciMLBase.@reset aliases.alias_b = values(kwargs).alias_b + else + SciMLBase.@reset aliases.alias_b = default_alias_b(alg, prob.A, prob.b) + end + + aliases else - to_alias_A = alias.alias_A + # If alias isa Bool, all fields of ODEAliases set to alias + if alias isa Bool + aliases = LinearAliasSpecifier(alias = alias) + elseif alias isa LinearAliasSpecifier || isnothing(alias) + aliases = alias + end + + if isnothing(aliases.alias_A) + SciMLBase.@reset aliases.alias_A = default_alias_A(alg,prob.A,prob.b) + end + if isnothing(aliases.alias_b) + SciMLBase.@reset aliases.alias_b = default_alias_b(alg,prob.A,prob.b) + end + aliases end - if isnothing(alias.alias_b) - to_alias_b = alias_b - else - to_alias_b = alias.alias_b - end - - A = if to_alias_A || A isa SMatrix + A = if aliases.alias_A || A isa SMatrix A elseif A isa Array copy(A) @@ -177,7 +202,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal) Array(b) # the solution to a linear solve will always be dense! - elseif to_alias_b || b isa SVector + elseif aliases.alias_b || b isa SVector b elseif b isa Array copy(b)