Skip to content

Commit

Permalink
Merge pull request #11 from alexmul1114/lossfunctions
Browse files Browse the repository at this point in the history
Add compatibility with LossFunctions.jl, test with least squares
  • Loading branch information
dahong67 authored Oct 9, 2023
2 parents cfd310f + 3ca2ec5 commit 4f3ef37
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 1 deletion.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@ version = "0.1.0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"

[compat]
ForwardDiff = "0.10.36"
LBFGSB = "0.4.1"
LossFunctions = "0.11.1"
julia = "1.6"

[extensions]
LossFunctionsExt = "LossFunctions"

[weakdeps]
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
20 changes: 20 additions & 0 deletions ext/LossFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module LossFunctionsExt

using GCPDecompositions, LossFunctions

"""
gcp(X::Array, r, loss::LossFunctions.SupervisedLoss, lower]) -> CPD
Compute an approximate rank-`r` CP decomposition of the tensor `X`
with respect to the loss function `loss` and return a `CPD` object.
# Inputs
+ `X` : multi-dimensional tensor/array to approximate/decompose
+ `r` : number of components for the CPD
+ `loss` : loss function from LossFunctions.jl
+ `lower` : lower bound for factor matrix entries, `default = -Inf`
"""
GCPDecompositions.gcp(X::Array, r, loss::LossFunctions.SupervisedLoss, lower=-Inf) =
GCPDecompositions._gcp(X, r, (x, m) -> loss(m, x), (x, m) -> LossFunctions.deriv(loss, m, x), lower, (;))

end
4 changes: 4 additions & 0 deletions src/GCPDecompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,8 @@ export gcp
include("type-cpd.jl")
include("gcp-opt.jl")

if !isdefined(Base, :get_extension)
include("../ext/LossFunctionsExt.jl")
end

end
26 changes: 25 additions & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.8.2"
manifest_format = "2.0"
project_hash = "3b0347c951b114bc9938540028de5bc8d1fc29c5"
project_hash = "5b46ab02fc1fd08dc4db50c5af3027ccd5dc095b"

[[deps.Adapt]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -39,6 +39,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.LossFunctions]]
deps = ["Markdown", "Requires", "Statistics"]
git-tree-sha1 = "df9da07efb9b05ca7ef701acec891ee8f73c99e2"
uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.11.1"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand All @@ -62,13 +68,27 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
deps = ["SHA", "Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand All @@ -89,6 +109,10 @@ git-tree-sha1 = "8621ba2637b49748e2dc43ba3d84340be2938022"
uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
version = "0.1.1"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[compat]
LossFunctions = "0.11.1"
OffsetArrays = "1.12.7"
TestItemRunner = "0.2.1"
4 changes: 4 additions & 0 deletions test/items/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

@testitem "least squares" begin
using Random
using LossFunctions

@testset "size(X)=$sz, rank(X)=$r" for sz in [(15, 20, 25), (30, 40, 50)], r in 1:2
Random.seed!(0)
Expand All @@ -17,5 +18,8 @@

Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5

Mh = gcp(X, r, L2DistLoss()) # test with loss function from LossFunctions.jl (least squares)
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end
end

0 comments on commit 4f3ef37

Please sign in to comment.