Skip to content

Commit

Permalink
Merge pull request #25 from dahong67/dahong67/issue24
Browse files Browse the repository at this point in the history
Setup system for constraints
  • Loading branch information
dahong67 authored Oct 13, 2023
2 parents ba84a57 + ef7df98 commit d8e9a75
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/GCPDecompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ export AbstractLoss,
PoissonLoss,
PoissonLogLoss,
UserDefinedLoss
export GCPConstraints

include("type-cpd.jl")
include("type-losses.jl")
include("type-constraints.jl")
include("gcp-opt.jl")

if !isdefined(Base, :get_extension)
Expand Down
48 changes: 38 additions & 10 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

# Main fitting function
"""
gcp(X::Array, r, loss = LeastSquaresLoss()) -> CPD
gcp(X::Array, r, loss = LeastSquaresLoss();
constraints = default_constraints(loss)) -> CPD
Compute an approximate rank-`r` CP decomposition of the tensor `X`
with respect to the loss function `loss` and return a `CPD` object.
Conventional CP corresponds to the default `LeastSquaresLoss()`.
The weights `λ` are constrained to all be one and `constraints` is a
`Tuple` of constraints on the factor matrices `U = (U[1],...,U[N])`.
Conventional CP corresponds to the default `LeastSquaresLoss()` loss
with no constraints (i.e., `constraints = ()`).
If the LossFunctions.jl package is also loaded,
`loss` can also be a loss function from that package.
Expand All @@ -15,15 +19,16 @@ to see what losses are supported.
See also: `CPD`, `AbstractLoss`.
"""
gcp(X::Array, r, loss = LeastSquaresLoss()) = _gcp(X, r, loss, (;))
gcp(X::Array, r, loss = LeastSquaresLoss(); constraints = default_constraints(loss)) =
_gcp(X, r, loss, constraints, (;))

# Choose lower bound on factor matrix entries based on the domain of the loss function
function _factor_matrix_lower_bound(loss)
# Choose constraints based on the domain of the loss function
function default_constraints(loss)
dom = domain(loss)
if dom == Interval(-Inf, +Inf)
return -Inf
return ()
elseif dom == Interval(0.0, +Inf)
return 0.0
return (GCPConstraints.LowerBound(0.0),)
else
error(
"only loss functions with a domain of `-Inf .. Inf` or `0 .. Inf` are (currently) supported",
Expand All @@ -37,14 +42,37 @@ _gcp(X::Array{TX,N}, r, func, grad, lower, lbfgsopts) where {TX,N} = _gcp(
X,
r,
UserDefinedLoss(func; deriv = grad, domain = Interval(lower, +Inf)),
(GCPConstraints.LowerBound(lower),),
lbfgsopts,
)
function _gcp(X::Array{TX,N}, r, loss, lbfgsopts) where {TX,N}
function _gcp(
X::Array{TX,N},
r,
loss,
constraints::Tuple{Vararg{GCPConstraints.LowerBound}},
lbfgsopts,
) where {TX,N}
# T = promote_type(nonmissingtype(TX), Float64)
T = Float64 # LBFGSB.jl seems to only support Float64

# Choose lower bound on factor matrix entries based on the domain of the loss
lower = _factor_matrix_lower_bound(loss)
# Compute lower bound from constraints
lower = maximum(constraint.value for constraint in constraints; init = T(-Inf))

# Error for unsupported loss/constraint combinations
dom = domain(loss)
if dom == Interval(-Inf, +Inf)
lower in (-Inf, 0.0) || error(
"only lower bound constraints of `-Inf` or `0` are (currently) supported for loss functions with a domain of `-Inf .. Inf`",
)
elseif dom == Interval(0.0, +Inf)
lower == 0.0 || error(
"only lower bound constraints of `0` are (currently) supported for loss functions with a domain of `0 .. Inf`",
)
else
error(
"only loss functions with a domain of `-Inf .. Inf` or `0 .. Inf` are (currently) supported",
)
end

# Random initialization
M0 = CPD(ones(T, r), rand.(T, size(X), r))
Expand Down
26 changes: 26 additions & 0 deletions src/type-constraints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
## Constraint types

module GCPConstraints

# Abstract type

"""
AbstractConstraint
Abstract type for GCP constraints on the factor matrices `U = (U[1],...,U[N])`.
"""
abstract type AbstractConstraint end

# Concrete types

"""
LowerBound(value::Real)
Lower-bound constraint on the entries of the factor matrices
`U = (U[1],...,U[N])`, i.e., `U[i][j,k] >= value`.
"""
struct LowerBound{T} <: AbstractConstraint
value::T
end

end
32 changes: 32 additions & 0 deletions test/items/gcp-opt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
## GCP decomposition - full optimization

@testitem "unsupported constraints" begin
using Random, IntervalSets

sz = (15, 20, 25)
r = 2
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]

# Exercise `default_constraints`
@test_throws ErrorException gcp(
X,
r,
UserDefinedLoss((x, m) -> (x - m)^2; domain = Interval(1, Inf)),
)

# Exercise `_gcp`
@test_throws ErrorException gcp(
X,
r,
LeastSquaresLoss();
constraints = (GCPConstraints.LowerBound(1),),
)
@test_throws ErrorException gcp(X, r, PoissonLoss(); constraints = ())
@test_throws ErrorException gcp(
X,
r,
UserDefinedLoss((x, m) -> (x - m)^2; domain = Interval(1, Inf));
constraints = (GCPConstraints.LowerBound(1),),
)
end

@testitem "LeastSquaresLoss" begin
using Random

Expand Down

0 comments on commit d8e9a75

Please sign in to comment.