From 80946d76c86b17b6ed6f48d5155ad03bca059694 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 11:07:07 +0200 Subject: [PATCH 1/4] done? --- CHANGELOG.md | 6 ++++++ Project.toml | 3 ++- src/JointEnergyModels.jl | 4 ++-- src/samplers.jl | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e56cdc3..24980b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v0.1.4]. +## Version [0.1.6] - 2024-09-09 + +### Changed + +- Now depending on new `EnergySamplers` package for energy-based sampling. [#27] + ## Version [0.1.5] - 2024-06-07 ### Changed diff --git a/Project.toml b/Project.toml index ab07394..dfc01f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "JointEnergyModels" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" authors = ["Patrick Altmeyer"] -version = "0.1.5" +version = "0.1.6" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +EnergySamplers = "f446124b-5d5e-4171-a6dd-a1d99768d3ce" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/src/JointEnergyModels.jl b/src/JointEnergyModels.jl index 127df19..fa80d52 100644 --- a/src/JointEnergyModels.jl +++ b/src/JointEnergyModels.jl @@ -1,11 +1,11 @@ module JointEnergyModels +using EnergySamplers using Flux using TaijaBase -using TaijaBase.Samplers using Reexport -@reexport import TaijaBase.Samplers: ConditionalSampler, UnconditionalSampler, JointSampler +@reexport import EnergySamplers: ConditionalSampler, UnconditionalSampler, JointSampler include("utils.jl") export _energy diff --git a/src/samplers.jl b/src/samplers.jl index e48a06b..cba3f2f 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -10,7 +10,7 @@ using Distributions Outer constructor for `ConditionalSampler`. """ -function TaijaBase.Samplers.ConditionalSampler( +function EnergySamplers.ConditionalSampler( X::Union{Tables.MatrixTable,AbstractMatrix}, y::Union{CategoricalArray,AbstractMatrix}; batch_size::Int = 1, From 96d6f1e8867fff87fabbec2e76a977af8ce151d5 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 11:11:20 +0200 Subject: [PATCH 2/4] fixing error --- .github/workflows/CI.yml | 1 - Project.toml | 6 +++--- src/model.jl | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index cf24cff..c075ffa 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - '1.10' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index dfc01f5..2d86a83 100644 --- a/Project.toml +++ b/Project.toml @@ -32,14 +32,14 @@ MLJFlux = "0.2, 0.3, 0.4.0" MLJModelInterface = "1.8" MLUtils = "0.4" ProgressMeter = "1.7" -Random = "1.6, 1.10" +Random = "1.10" Reexport = "1.2.2" StatsBase = "0.33, 0.34" Tables = "1.10" TaijaBase = "1.1.0" -Test = "1.6, 1.10" +Test = "1.10" Zygote = "0.6" -julia = "1.6, 1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/src/model.jl b/src/model.jl index 28648ec..06f69b2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,7 +1,7 @@ using ChainRulesCore using Flux using Flux.Losses: logitcrossentropy -using TaijaBase.Samplers: ImproperSGLD +using EnergySamplers: ImproperSGLD struct JointEnergyModel chain::Chain From 7fc332e1763136f5d476d153374fe451036bc68e Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 11:23:08 +0200 Subject: [PATCH 3/4] now then? --- .gitignore | 2 +- Project.toml | 3 +- test/Manifest.toml | 233 +++++++++++++++++++++++++-------------------- test/Project.toml | 7 +- 4 files changed, 140 insertions(+), 105 deletions(-) diff --git a/.gitignore b/.gitignore index 313a8c9..badef0d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ /.quarto/ -Manifest.toml +**/Manifest.toml diff --git a/Project.toml b/Project.toml index 2d86a83..1a26559 100644 --- a/Project.toml +++ b/Project.toml @@ -27,8 +27,9 @@ CategoricalArrays = "0.10" ChainRulesCore = "1.16" ComputationalResources = "0.3" Distributions = "0.25" +EnergySamplers = "1.0.0" Flux = "0.13, 0.14" -MLJFlux = "0.2, 0.3, 0.4.0" +MLJFlux = "0.2, 0.3, 0.4, 0.5" MLJModelInterface = "1.8" MLUtils = "0.4" ProgressMeter = "1.7" diff --git a/test/Manifest.toml b/test/Manifest.toml index be83031..0ef2c84 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.5" manifest_format = "2.0" project_hash = "e34d8ebf030e8e97a7b26a2e86c6cd2fc20b9053" @@ -15,6 +15,27 @@ weakdeps = ["ChainRulesCore", "Test"] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" AbstractFFTsTestExt = "Test" +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.37" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" @@ -61,16 +82,17 @@ uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" version = "0.3.9" [[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.40" +version = "0.4.3" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" BangBangDataFramesExt = "DataFrames" BangBangStaticArraysExt = "StaticArrays" BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" BangBangTypedTablesExt = "TypedTables" [deps.BangBang.weakdeps] @@ -78,6 +100,7 @@ version = "0.3.40" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [[deps.Base64]] @@ -93,12 +116,6 @@ git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.5.0" -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" @@ -131,9 +148,9 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "5ec157747036038ec70b250f578362268f0472f1" +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.68.0" +version = "1.69.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] @@ -152,16 +169,16 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" version = "0.11.5" [[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" +version = "0.3.1" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" +version = "4.16.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -182,30 +199,29 @@ version = "1.1.1+0" git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" +weakdeps = ["InverseFunctions"] [deps.CompositionsBase.extensions] CompositionsBaseInverseFunctionsExt = "InverseFunctions" - [deps.CompositionsBase.weakdeps] - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.ComputationalResources]] git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" version = "0.3.2" [[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.5" +version = "1.5.8" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] @@ -268,9 +284,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +git-tree-sha1 = "e6c693a0e4394f8fda0e51a5bdf5aef26f8235e9" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.109" +version = "0.25.111" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -293,17 +309,11 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.1" +version = "0.2.2" [[deps.FLoopsBase]] deps = ["ContextVariablesX"] @@ -322,9 +332,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -340,19 +350,21 @@ version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" +git-tree-sha1 = "fbf100b4bed74c9b6fac0ebd1031e04977d35b3b" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.15" +version = "0.14.19" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxEnzymeExt = "Enzyme" FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -368,9 +380,9 @@ weakdeps = ["StaticArrays"] [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" +git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.10" +version = "0.4.12" [[deps.Future]] deps = ["Random"] @@ -378,9 +390,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "38cb19b8a3e600e509dc36a6396ac74266d108c1" +git-tree-sha1 = "62ee71528cca49be797076a76bdc654a170a523e" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.1.1" +version = "10.3.1" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -389,10 +401,10 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.6" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" +version = "0.3.24" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools"] @@ -409,6 +421,16 @@ version = "0.3.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InverseFunctions]] +git-tree-sha1 = "2787db24f4e03daf859c6509ff87764e4182f7d1" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.16" +weakdeps = ["Dates", "Test"] + + [deps.InverseFunctions.extensions] + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" + [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" @@ -425,16 +447,16 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "bdbe8222d2f5703ad6a7019277d149ec6d78c301" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.48" +version = "0.4.53" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -443,22 +465,26 @@ uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" [[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272" +deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.19" +version = "0.9.25" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" + LinearAlgebraExt = "LinearAlgebra" + SparseArraysExt = "SparseArrays" [deps.KernelAbstractions.weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24" +git-tree-sha1 = "b351d72436ddecd27381a07c242ba27282a6c8a7" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "7.1.0" +version = "9.0.0" [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" @@ -468,9 +494,9 @@ version = "7.1.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +git-tree-sha1 = "f42bec1e12f42ec251541f6d0482d520a4638b17" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" +version = "0.0.33+0" [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" @@ -539,9 +565,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "24e5d28b2ea86b3feb6af5a5735f012d62e27b65" +git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "1.4.0" +version = "1.7.0" [deps.MLJBase.extensions] DefaultMeasuresExt = "StatisticalMeasures" @@ -550,16 +576,16 @@ version = "1.4.0" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" [[deps.MLJFlux]] -deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "72935b7de07a7f6b72fd49ecc7898dac79248d46" +deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"] +git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.4.0" +version = "0.5.1" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "88ef480f46e0506143681b3fb14d86742f3cecb1" +git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.10.0" +version = "1.11.0" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -600,10 +626,10 @@ version = "0.9.3" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.4" +version = "0.2.0" [[deps.Missings]] deps = ["DataAPI"] @@ -619,21 +645,25 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "4a83c2e01027a0bfcea28589222f2df60b2e20cb" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.17" +version = "0.9.23" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NaNMath]] @@ -743,20 +773,26 @@ version = "0.1.4" [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] -git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" +git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.0" +version = "1.10.2" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.4" +version = "2.11.0" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -797,9 +833,9 @@ version = "0.7.1" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +git-tree-sha1 = "e60724fd3beea548353984dc61c943ecddb0e29a" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.2+0" +version = "0.4.3+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -874,9 +910,9 @@ version = "0.1.15" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "9ae599cd7529cfce7fea36cf00a62cfc56f0f37c" +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.4" +version = "1.9.7" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -884,9 +920,9 @@ weakdeps = ["ChainRulesCore", "Statistics"] StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" +version = "1.4.3" [[deps.StatisticalMeasuresBase]] deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"] @@ -896,9 +932,9 @@ version = "0.1.1" [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] -git-tree-sha1 = "983c41a0ddd6c19f5607ca87271d7c7620ab5d50" +git-tree-sha1 = "542d979f6e756f13f862aa00b224f04f9e445f11" uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.3.0" +version = "3.4.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -922,15 +958,12 @@ deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Re git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "1.3.1" +weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] StatsFunsChainRulesCoreExt = "ChainRulesCore" StatsFunsInverseFunctionsExt = "InverseFunctions" - [deps.StatsFuns.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.StringManipulation]] deps = ["PrecompileTools"] git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" @@ -971,10 +1004,10 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.1" +version = "1.12.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -986,19 +1019,15 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TranscodingStreams]] -git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0" +git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.9" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] +version = "0.11.2" [[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.80" +version = "0.4.82" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -1033,9 +1062,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" +git-tree-sha1 = "2d17fabcd17e67d7625ce9c531fb9f40b7c42ce4" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.4" +version = "0.2.1" [[deps.Zlib_jll]] deps = ["Libdl"] @@ -1067,7 +1096,7 @@ version = "0.2.5" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/test/Project.toml b/test/Project.toml index 7e3f166..90ab8fc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,4 +9,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -julia = "1.6, 1.10" \ No newline at end of file +CompatHelperLocal = "0.1" +Distributions = "0.25" +Flux = "0.14" +MLJBase = "1.7" +MLJFlux = "0.5" +julia = "1.10" \ No newline at end of file From 89512f224c69601df6b1dfb889c7823a701181b0 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 11:32:37 +0200 Subject: [PATCH 4/4] tryin again --- Project.toml | 2 ++ src/mlj_flux.jl | 3 ++- test/Manifest.toml | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1a26559..3c9d579 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -32,6 +33,7 @@ Flux = "0.13, 0.14" MLJFlux = "0.2, 0.3, 0.4, 0.5" MLJModelInterface = "1.8" MLUtils = "0.4" +Optimisers = "0.3" ProgressMeter = "1.7" Random = "1.10" Reexport = "1.2.2" diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl index cbb3d4e..a041b8c 100644 --- a/src/mlj_flux.jl +++ b/src/mlj_flux.jl @@ -2,6 +2,7 @@ using ComputationalResources using Flux using MLJFlux import MLJModelInterface as MMI +using Optimisers using ProgressMeter using Random using Tables @@ -31,7 +32,7 @@ function JointEnergyClassifier( sampler::AbstractSampler; builder::B = default_builder_jem, finaliser::F = Flux.softmax, - optimiser::O = Flux.Optimise.Adam(), + optimiser::O = Optimisers.Adam(), loss::L = Flux.crossentropy, epochs::Int = 100, batch_size::Int = 100, diff --git a/test/Manifest.toml b/test/Manifest.toml index 0ef2c84..8100a00 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "e34d8ebf030e8e97a7b26a2e86c6cd2fc20b9053" +project_hash = "e5df3295d4fa5526130a8cdbd54efe647549ccfe" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"]