diff --git a/src/parametric/services/ParametricManopt.jl b/src/parametric/services/ParametricManopt.jl index e1d14e74..75d43746 100644 --- a/src/parametric/services/ParametricManopt.jl +++ b/src/parametric/services/ParametricManopt.jl @@ -7,14 +7,20 @@ using SparseArrays # using Zygote ## -function getVarIntLabelMap(vartypeslist::OrderedDict{DataType, Vector{Symbol}}) +function getVarIntLabelMap( + vartypeslist::OrderedDict{DataType, Vector{Symbol}} +) varlist_tuple = (values(vartypeslist)...,) varlabelsAP = ArrayPartition{Symbol, typeof(varlist_tuple)}(varlist_tuple) varIntLabel = OrderedDict(zip(varlabelsAP, collect(1:length(varlabelsAP)))) return varIntLabel, varlabelsAP end -function CalcFactorResidual(fg, fct::DFGFactor, varIntLabel) +function CalcFactorResidual( + fg, + fct::DFGFactor, + varIntLabel +) fac_func = getFactorType(fct) varOrder = getVariableOrder(fct) @@ -45,7 +51,11 @@ end CalcFactorResidualAP Create an `ArrayPartition` of `CalcFactorResidual`s. """ -function CalcFactorResidualAP(fg::GraphsDFG, factorLabels::Vector{Symbol}, varIntLabel::OrderedDict{Symbol, Int64}) +function CalcFactorResidualAP( + fg::GraphsDFG, + factorLabels::Vector{Symbol}, + varIntLabel::OrderedDict{Symbol, Int64} +) factypes, typedict, alltypes = getFactorTypesCount(getFactor.(fg, factorLabels)) # skip non-numeric prior (MetaPrior) @@ -300,6 +310,7 @@ function solve_RLM( faclabels = lsf(fg); is_sparse = true, finiteDiffCovariance = false, + solveKey::Symbol = :parametric, kwargs... ) @@ -312,7 +323,7 @@ function solve_RLM( #Can use varIntLabel (because its an OrderedDict), but varLabelsAP makes the ArrayPartition. p0 = map(varlabelsAP) do label - getVal(fg, label, solveKey = :parametric)[1] + getVal(fg, label; solveKey)[1] end # create an ArrayPartition{CalcFactorResidual} for faclabels @@ -389,6 +400,7 @@ function solve_RLM_conditional( separators::Vector{Symbol} = setdiff(ls(fg), frontals); is_sparse=false, finiteDiffCovariance=true, + solveKey::Symbol = :parametric, kwargs... ) is_sparse && error("Sparse solve_RLM_conditional not supported yet") @@ -420,7 +432,7 @@ function solve_RLM_conditional( all_varlabelsAP = ArrayPartition((frontal_varlabelsAP.x..., separator_varlabelsAP.x...)) all_points = map(all_varlabelsAP) do label - getVal(fg, label, solveKey = :parametric)[1] + getVal(fg, label; solveKey)[1] end p0 = ArrayPartition(all_points.x[1:length(frontal_varlabelsAP.x)]) @@ -537,7 +549,7 @@ function autoinitParametric!( ) ) end - M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; kwargs...) + M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...) val = lm_r[1] vnd.val[1] = val @@ -549,8 +561,8 @@ function autoinitParametric!( vnd.initialized = true #fill in ppe as mean Xc::Vector{Float64} = collect(getCoordinates(getVariableType(xi), val)) - ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc) - getPPEDict(xi)[:parametric] = ppe + ppe = MeanMaxPPE(solveKey, Xc, Xc, Xc) + getPPEDict(xi)[solveKey] = ppe result = vartypeslist, lm_r