diff --git a/Project.toml b/Project.toml index 3d5ca18..18a2f46 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,16 @@ version = "0.1.0" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" [compat] ForwardDiff = "0.10.36" LBFGSB = "0.4.1" +LossFunctions = "0.11.1" julia = "1.6" + +[extensions] +LossFunctionsExt = "LossFunctions" + +[weakdeps] +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" diff --git a/ext/LossFunctionsExt.jl b/ext/LossFunctionsExt.jl new file mode 100644 index 0000000..46bdbd8 --- /dev/null +++ b/ext/LossFunctionsExt.jl @@ -0,0 +1,20 @@ +module LossFunctionsExt + +using GCPDecompositions, LossFunctions + +""" + gcp(X::Array, r, loss::LossFunctions.SupervisedLoss, lower]) -> CPD + +Compute an approximate rank-`r` CP decomposition of the tensor `X` +with respect to the loss function `loss` and return a `CPD` object. + +# Inputs ++ `X` : multi-dimensional tensor/array to approximate/decompose ++ `r` : number of components for the CPD ++ `loss` : loss function from LossFunctions.jl ++ `lower` : lower bound for factor matrix entries, `default = -Inf` +""" +GCPDecompositions.gcp(X::Array, r, loss::LossFunctions.SupervisedLoss, lower=-Inf) = + GCPDecompositions._gcp(X, r, (x, m) -> loss(m, x), (x, m) -> LossFunctions.deriv(loss, m, x), lower, (;)) + +end diff --git a/src/GCPDecompositions.jl b/src/GCPDecompositions.jl index 065e67d..b4aad5a 100644 --- a/src/GCPDecompositions.jl +++ b/src/GCPDecompositions.jl @@ -17,4 +17,8 @@ export gcp include("type-cpd.jl") include("gcp-opt.jl") +if !isdefined(Base, :get_extension) + include("../ext/LossFunctionsExt.jl") +end + end diff --git a/test/Manifest.toml b/test/Manifest.toml index 923e9cf..bad76a5 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.2" manifest_format = "2.0" -project_hash = "3b0347c951b114bc9938540028de5bc8d1fc29c5" +project_hash = "5b46ab02fc1fd08dc4db50c5af3027ccd5dc095b" [[deps.Adapt]] deps = ["LinearAlgebra"] @@ -39,6 +39,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.LossFunctions]] +deps = ["Markdown", "Requires", "Statistics"] +git-tree-sha1 = "df9da07efb9b05ca7ef701acec891ee8f73c99e2" +uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +version = "0.11.1" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -62,6 +68,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" @@ -69,6 +81,14 @@ version = "0.7.0" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" @@ -89,6 +109,10 @@ git-tree-sha1 = "8621ba2637b49748e2dc43ba3d84340be2938022" uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" version = "0.1.1" +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/test/Project.toml b/test/Project.toml index 2857e4c..b74a5dc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,11 @@ [deps] +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" [compat] +LossFunctions = "0.11.1" OffsetArrays = "1.12.7" TestItemRunner = "0.2.1" diff --git a/test/items/gcp-opt.jl b/test/items/gcp-opt.jl index 11a43cc..c41ede4 100644 --- a/test/items/gcp-opt.jl +++ b/test/items/gcp-opt.jl @@ -2,6 +2,7 @@ @testitem "least squares" begin using Random + using LossFunctions @testset "size(X)=$sz, rank(X)=$r" for sz in [(15, 20, 25), (30, 40, 50)], r in 1:2 Random.seed!(0) @@ -17,5 +18,8 @@ Mh = gcp(X, r) # test default (least-squares) loss @test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5 + + Mh = gcp(X, r, L2DistLoss()) # test with loss function from LossFunctions.jl (least squares) + @test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5 end end