Skip to content

Commit

Permalink
Merge pull request #22 from dahong67/gcp-dispatch
Browse files Browse the repository at this point in the history
Change design for dispatch
  • Loading branch information
dahong67 authored Oct 12, 2023
2 parents 6c3c18d + 4b5431a commit 34db0c6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
17 changes: 5 additions & 12 deletions ext/LossFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
module LossFunctionsExt

using GCPDecompositions, LossFunctions
import GCPDecompositions: _factor_matrix_lower_bound
using IntervalSets

const SupportedLosses = Union{LossFunctions.DistanceLoss,LossFunctions.MarginLoss}

GCPDecompositions.gcp(X::Array, r, loss::SupportedLosses) = GCPDecompositions._gcp(
X,
r,
(x, m) -> loss(m, x),
(x, m) -> LossFunctions.deriv(loss, m, x),
_factor_matrix_lower_bound(loss),
(;),
)

_factor_matrix_lower_bound(::LossFunctions.DistanceLoss) = -Inf
_factor_matrix_lower_bound(::LossFunctions.MarginLoss) = -Inf
# Extend the GCPDecompositions loss interface (`value`, `deriv`, `domain`) to the
# LossFunctions.jl types collected in `SupportedLosses`.
#
# The GCPDecompositions methods take `(loss, x, m)`, whereas the LossFunctions calls
# below are given `(m, x)`; the two arguments are deliberately swapped when forwarding.
# NOTE(review): presumably `x` is a data entry and `m` the matching model entry,
# consistent with how `value`/`deriv` are called in src/gcp-opt.jl -- confirm.
GCPDecompositions.value(loss::SupportedLosses, x, m) = loss(m, x)
GCPDecompositions.deriv(loss::SupportedLosses, x, m) = LossFunctions.deriv(loss, m, x)
# Both loss families report the unbounded domain `Interval(-Inf, Inf)`.
GCPDecompositions.domain(::LossFunctions.DistanceLoss) = Interval(-Inf, Inf)
GCPDecompositions.domain(::LossFunctions.MarginLoss) = Interval(-Inf, Inf)

end
34 changes: 19 additions & 15 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@ to see what losses are supported.
See also: `CPD`, `AbstractLoss`.
"""
gcp(X::Array, r, loss::AbstractLoss = LeastSquaresLoss()) = _gcp(
X,
r,
(x, m) -> value(loss, x, m),
(x, m) -> deriv(loss, x, m),
_factor_matrix_lower_bound(loss),
(;),
)
gcp(X::Array, r, loss = LeastSquaresLoss()) = _gcp(X, r, loss, (;))

# Choose lower bound on factor matrix entries based on the domain of the loss
function _factor_matrix_lower_bound(loss)
Expand Down Expand Up @@ -52,10 +45,21 @@ function _factor_matrix_lower_bound(loss)
return min
end

function _gcp(X::Array{TX,N}, r, func, grad, lower, lbfgsopts) where {TX,N}
# Legacy entry point taking the loss as separate pieces (`func`, `grad`, `lower`).
# It wraps them in a `UserDefinedLoss` -- `func` as the loss function, `grad` as its
# `deriv` keyword, and `Interval(lower, +Inf)` as its `domain` keyword -- and forwards
# to the four-argument `_gcp(X, r, loss, lbfgsopts)` method.
# TODO: remove this `func, grad, lower` signature;
# doing so will require reworking how we do testing.
_gcp(X::Array{TX,N}, r, func, grad, lower, lbfgsopts) where {TX,N} = _gcp(
X,
r,
UserDefinedLoss(func; deriv = grad, domain = Interval(lower, +Inf)),
lbfgsopts,
)
function _gcp(X::Array{TX,N}, r, loss, lbfgsopts) where {TX,N}
# T = promote_type(nonmissingtype(TX), Float64)
T = Float64 # LBFGSB.jl seems to only support Float64

# Choose lower bound on factor matrix entries based on the domain of the loss
lower = _factor_matrix_lower_bound(loss)

# Random initialization
M0 = CPD(ones(T, r), rand.(T, size(X), r))
M0norm = sqrt(sum(abs2, M0[I] for I in CartesianIndices(size(M0))))
Expand All @@ -70,12 +74,12 @@ function _gcp(X::Array{TX,N}, r, func, grad, lower, lbfgsopts) where {TX,N}
vec_ranges = ntuple(k -> vec_cutoffs[k]+1:vec_cutoffs[k+1], Val(N))
function f(u)
U = map(range -> reshape(view(u, range), :, r), vec_ranges)
return gcp_func(CPD(ones(T, r), U), X, func)
return gcp_func(CPD(ones(T, r), U), X, loss)
end
function g!(gu, u)
U = map(range -> reshape(view(u, range), :, r), vec_ranges)
GU = map(range -> reshape(view(gu, range), :, r), vec_ranges)
gcp_grad_U!(GU, CPD(ones(T, r), U), X, grad)
gcp_grad_U!(GU, CPD(ones(T, r), U), X, loss)
return gu
end

Expand All @@ -87,18 +91,18 @@ function _gcp(X::Array{TX,N}, r, func, grad, lower, lbfgsopts) where {TX,N}
end

# Objective function and gradient (w.r.t. `M.U`)
function gcp_func(M::CPD{T,N}, X::Array{TX,N}, func) where {T,TX,N}
return sum(func(X[I], M[I]) for I in CartesianIndices(X) if !ismissing(X[I]))
function gcp_func(M::CPD{T,N}, X::Array{TX,N}, loss) where {T,TX,N}
return sum(value(loss, X[I], M[I]) for I in CartesianIndices(X) if !ismissing(X[I]))
end

function gcp_grad_U!(
GU::NTuple{N,TGU},
M::CPD{T,N},
X::Array{TX,N},
grad,
loss,
) where {T,TX,N,TGU<:AbstractMatrix{T}}
Y = [
ismissing(X[I]) ? zero(nonmissingtype(eltype(X))) : grad(X[I], M[I]) for
ismissing(X[I]) ? zero(nonmissingtype(eltype(X))) : deriv(loss, X[I], M[I]) for
I in CartesianIndices(X)
]

Expand Down

0 comments on commit 34db0c6

Please sign in to comment.