Skip to content

Commit

Permalink
Some ParametricManopt.jl cleanup while reading
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann authored Mar 8, 2024
1 parent a962e9c commit 952e41a
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/parametric/services/ParametricManopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ using SparseArrays
# using Zygote

##
function getVarIntLabelMap(vartypeslist::OrderedDict{DataType, Vector{Symbol}})
function getVarIntLabelMap(
vartypeslist::OrderedDict{DataType, Vector{Symbol}}
)
varlist_tuple = (values(vartypeslist)...,)
varlabelsAP = ArrayPartition{Symbol, typeof(varlist_tuple)}(varlist_tuple)
varIntLabel = OrderedDict(zip(varlabelsAP, collect(1:length(varlabelsAP))))
return varIntLabel, varlabelsAP
end

function CalcFactorResidual(fg, fct::DFGFactor, varIntLabel)
function CalcFactorResidual(
fg,
fct::DFGFactor,
varIntLabel
)
fac_func = getFactorType(fct)
varOrder = getVariableOrder(fct)

Expand Down Expand Up @@ -45,7 +51,11 @@ end
CalcFactorResidualAP
Create an `ArrayPartition` of `CalcFactorResidual`s.
"""
function CalcFactorResidualAP(fg::GraphsDFG, factorLabels::Vector{Symbol}, varIntLabel::OrderedDict{Symbol, Int64})
function CalcFactorResidualAP(
fg::GraphsDFG,
factorLabels::Vector{Symbol},
varIntLabel::OrderedDict{Symbol, Int64}
)
factypes, typedict, alltypes = getFactorTypesCount(getFactor.(fg, factorLabels))

# skip non-numeric prior (MetaPrior)
Expand Down Expand Up @@ -300,6 +310,7 @@ function solve_RLM(
faclabels = lsf(fg);
is_sparse = true,
finiteDiffCovariance = false,
solveKey::Symbol = :parametric,
kwargs...
)

Expand All @@ -312,7 +323,7 @@ function solve_RLM(

#Can use varIntLabel (because its an OrderedDict), but varLabelsAP makes the ArrayPartition.
p0 = map(varlabelsAP) do label
getVal(fg, label, solveKey = :parametric)[1]
getVal(fg, label; solveKey)[1]
end

# create an ArrayPartition{CalcFactorResidual} for faclabels
Expand Down Expand Up @@ -389,6 +400,7 @@ function solve_RLM_conditional(
separators::Vector{Symbol} = setdiff(ls(fg), frontals);
is_sparse=false,
finiteDiffCovariance=true,
solveKey::Symbol = :parametric,
kwargs...
)
is_sparse && error("Sparse solve_RLM_conditional not supported yet")
Expand Down Expand Up @@ -420,7 +432,7 @@ function solve_RLM_conditional(
all_varlabelsAP = ArrayPartition((frontal_varlabelsAP.x..., separator_varlabelsAP.x...))

all_points = map(all_varlabelsAP) do label
getVal(fg, label, solveKey = :parametric)[1]
getVal(fg, label; solveKey)[1]
end

p0 = ArrayPartition(all_points.x[1:length(frontal_varlabelsAP.x)])
Expand Down Expand Up @@ -537,7 +549,7 @@ function autoinitParametric!(
)
)
end
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; kwargs...)
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...)

val = lm_r[1]
vnd.val[1] = val
Expand All @@ -549,8 +561,8 @@ function autoinitParametric!(
vnd.initialized = true
#fill in ppe as mean
Xc::Vector{Float64} = collect(getCoordinates(getVariableType(xi), val))
ppe = MeanMaxPPE(:parametric, Xc, Xc, Xc)
getPPEDict(xi)[:parametric] = ppe
ppe = MeanMaxPPE(solveKey, Xc, Xc, Xc)
getPPEDict(xi)[solveKey] = ppe

result = vartypeslist, lm_r

Expand Down

0 comments on commit 952e41a

Please sign in to comment.