Skip to content

Commit

Permalink
Fix mixture sampling and deconv for static arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Affie committed Oct 11, 2023
1 parent 43dcfa3 commit 6f7878c
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 6 deletions.
11 changes: 10 additions & 1 deletion src/Factors/Mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,19 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
#out memory should be right size first
length(cf.factor.labels) != N ? resize!(cf.factor.labels, N) : nothing
cf.factor.labels .= rand(cf.factor.diversity, N)
M = cf.manifold

# mixture needs to be refactored so let's make it worse :-)
if cf.factor.mechanics isa AbstractPrior
samplef = samplePoint
elseif cf.factor.mechanics isa AbstractRelative
samplef = sampleTangent
end

for i = 1:N
mixComponent = cf.factor.components[cf.factor.labels[i]]
# measurements relate to the factor's manifold (either tangent vector or manifold point)
setPointsMani!(smpls[i], rand(mixComponent, 1))
setPointsMani!(smpls, samplef(M, mixComponent), i)
end

# TODO only does first element of meas::Tuple at this stage, see #1099
Expand Down
8 changes: 8 additions & 0 deletions src/entities/AliasScalarSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ struct AliasingScalarSampler
end
end

function sampleTangent(
M::AbstractDecoratorManifold,
z::AliasingScalarSampler,
p = getPointIdentity(M),
)
return hat(M, p, SVector{manifold_dimension(M)}(rand(z)))
end

function rand!(ass::AliasingScalarSampler, smpls::Array{Float64})
StatsBase.alias_sample!(ass.domain, ass.weights, smpls)
return nothing
Expand Down
2 changes: 1 addition & 1 deletion src/services/DeconvUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function approxDeconv(
measurement[idx] = ts
else
ts = _solveLambdaNumeric(fcttype, hypoObj, res_, measurement[idx], islen1)
copyto!(target_smpl, ts)
measurement[idx] = ts
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/services/NumericalCalculations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ function _solveLambdaNumericMeas(
M = getManifold(variableType)#fcttype.M
# the variable is a manifold point, we are working on the tangent plane in optim for now.
ϵ = getPointIdentity(variableType)
X0c = vee(M, ϵ, u0)
X0c = zeros(manifold_dimension(M))

function cost(Xc)
X = hat(M, ϵ, Xc)
Expand Down
3 changes: 2 additions & 1 deletion test/testSpecialEuclidean2Mani.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ pred, meas = approxDeconv(fg, :x0x1f1)

p_t = map(x->x.x[1], pred)
m_t = map(x->x.x[1], meas)
p_θ = map(x->x.x[2][2], pred)
#TODO why is angle wrapping around? (after SA update?)
p_θ = map(x->Manifolds.sym_rem(x.x[2][2]), pred)
m_θ = map(x->x.x[2][2], meas)

@test isapprox(mean(p_θ), 0.1, atol=0.02)
Expand Down
4 changes: 2 additions & 2 deletions test/testUseMsgLikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ fT = getFactorType(fct)

M = getManifold(fT)
X = sampleTangent(M, fT.Z)
@test X isa Vector{<:Real}
@test X isa AbstractVector{<:Real}

z = sampleFactor(fct)[1]
@test z isa Vector{<:Real}
@test z isa AbstractVector{<:Real}

##

Expand Down

0 comments on commit 6f7878c

Please sign in to comment.