Skip to content

Commit

Permalink
Merge pull request #14 from dahong67/restrict-lossfunctions
Browse files Browse the repository at this point in the history
Restrict LossFunctions.jl support to DistanceLoss and MarginLoss
  • Loading branch information
dahong67 authored Oct 10, 2023
2 parents c4fa875 + 2db3cd3 commit b0b1a3b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ext/LossFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module LossFunctionsExt

using GCPDecompositions, LossFunctions

const SupportedLosses = Union{LossFunctions.DistanceLoss,LossFunctions.MarginLoss}

"""
gcp(X::Array, r, loss::LossFunctions.SupervisedLoss[, lower]) -> CPD
Expand All @@ -15,7 +17,7 @@ with respect to the loss function `loss` and return a `CPD` object.
- `loss` : loss function from LossFunctions.jl
- `lower` : lower bound for factor matrix entries, `default = -Inf`
"""
GCPDecompositions.gcp(X::Array, r, loss::LossFunctions.SupervisedLoss, lower = -Inf) =
GCPDecompositions.gcp(X::Array, r, loss::SupportedLosses, lower = -Inf) =
GCPDecompositions._gcp(
X,
r,
Expand Down

0 comments on commit b0b1a3b

Please sign in to comment.