From f776b85aa38acb92ce530e70706446fb09cfc276 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Sat, 2 Sep 2023 08:11:25 +0000 Subject: [PATCH] clean up matrix APIs --- benchmark/Manifest.toml | 256 ++++++++++++++++++++++++++++++++++++---- src/derivative.jl | 51 ++++---- 2 files changed, 264 insertions(+), 43 deletions(-) diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 88ce710..73dc86b 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.4" +julia_version = "1.9.3" manifest_format = "2.0" project_hash = "b3e5f4cf27d760c93b2634d045748b8e8637f186" @@ -10,10 +10,15 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" version = "0.2.1" [[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +deps = ["LinearAlgebra"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" @@ -25,6 +30,10 @@ deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.6.2" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -41,6 +50,22 @@ git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" version = "7.4.11" + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -62,6 +87,20 @@ git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" version = "0.3.39" + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -140,10 +179,14 @@ uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" version = "0.1.4" [[deps.ChangesOfVariables]] -deps = ["InverseFunctions", "LinearAlgebra", "Test"] +deps = ["LinearAlgebra", "Test"] git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.8" +weakdeps = ["InverseFunctions"] + + [deps.ChangesOfVariables.extensions] + ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -163,20 +206,28 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] +deps = ["UUIDs"] git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "4.9.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.1+0" +version = "1.0.5+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" [[deps.ConcreteStructs]] git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" @@ -195,6 +246,14 @@ git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.3" + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.ContextVariablesX]] deps = ["Compat", "Logging", "UUIDs"] git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" @@ -228,7 +287,9 @@ version = "0.1.2" [[deps.DelimitedFiles]] deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" [[deps.DiffResults]] deps = ["StaticArraysCore"] @@ -290,10 +351,15 @@ version = "0.1.1" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +deps = ["LinearAlgebra", "Random"] git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "1.6.1" +weakdeps = ["SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" [[deps.Flux]] deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] @@ -301,11 +367,23 @@ git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" version = "0.13.17" + [deps.Flux.extensions] + AMDGPUExt = "AMDGPU" + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + [[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" [[deps.FunctionWrappers]] git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" @@ -411,6 +489,12 @@ git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" version = "0.9.8" + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" @@ -462,14 +546,20 @@ version = "1.10.2+0" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" version = "0.3.26" +weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -486,6 +576,27 @@ git-tree-sha1 = "78fecc38a73321df15161a481864fce75b66ae84" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" version = "0.5.4" + [deps.Lux.extensions] + LuxComponentArraysExt = "ComponentArrays" + LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] + LuxComponentArraysTrackerExt = ["ComponentArrays", "Tracker"] + LuxComponentArraysZygoteExt = ["ComponentArrays", "Zygote"] + LuxFluxTransformExt = "Flux" + LuxLuxAMDGPUExt = "LuxAMDGPU" + LuxLuxCUDAExt = "LuxCUDA" + LuxTrackerExt = "Tracker" + LuxZygoteExt = "Zygote" + + [deps.Lux.weakdeps] + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.LuxCore]] deps = ["DocStringExtensions", "Functors", "Random", "Setfield"] git-tree-sha1 = "f2dafe0ddcecf06247b40dbf336acd14e0adce6d" @@ -498,12 +609,41 @@ git-tree-sha1 = "e67d2206f6f05f534dccbed1df2b60e452ce4d0d" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" version = "0.1.7" + [deps.LuxDeviceUtils.extensions] + LuxDeviceUtilsComponentArraysExt = "ComponentArrays" + LuxDeviceUtilsFillArraysExt = "FillArrays" + LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" + LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" + LuxDeviceUtilsMetalExt = "Metal" + LuxDeviceUtilsZygoteExt = "Zygote" + + [deps.LuxDeviceUtils.weakdeps] + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.LuxLib]] deps = ["ChainRulesCore", "KernelAbstractions", "Markdown", "NNlib", "PackageExtensionCompat", "Random", "Reexport", "Statistics"] git-tree-sha1 = "06e1f04441a8835413b48c84c016313c16e1687b" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" version = "0.3.2" + [deps.LuxLib.extensions] + LuxLibForwardDiffExt = "ForwardDiff" + LuxLibLuxCUDAExt = "LuxCUDA" + LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] + LuxLibReverseDiffExt = "ReverseDiff" + LuxLibTrackerExt = "Tracker" + + [deps.LuxLib.weakdeps] + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -534,7 +674,7 @@ version = "1.1.7" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" +version = "2.28.2+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -553,7 +693,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" +version = "2022.10.11" [[deps.MultivariatePolynomials]] deps = ["ChainRulesCore", "DataStructures", "LinearAlgebra", "MutableArithmetics"] @@ -573,6 +713,12 @@ git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.8.21" + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + [[deps.NNlibCUDA]] deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" @@ -604,7 +750,7 @@ version = "0.2.4" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" +version = "0.3.21+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -641,10 +787,10 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.2" [[deps.PackageExtensionCompat]] -deps = ["Requires", "TOML"] git-tree-sha1 = "f9b1e033c2b1205cf30fd119f4e50881316c1923" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" version = "1.0.1" +weakdeps = ["Requires", "TOML"] [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] @@ -658,9 +804,9 @@ uuid = "570af359-4316-4cb7-8c74-252c00c2016b" version = "1.1.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" +version = "1.9.2" [[deps.PkgBenchmark]] deps = ["BenchmarkTools", "Dates", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Pkg", "Printf", "TerminalLoggers", "UUIDs"] @@ -673,6 +819,10 @@ deps = ["Adapt", "ArrayInterface", "ForwardDiff", "Requires"] git-tree-sha1 = "f739b1b3cc7b9949af3b35089931f2b58c289163" uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" version = "0.4.12" +weakdeps = ["ReverseDiff"] + + [deps.PreallocationTools.extensions] + PreallocationToolsReverseDiffExt = "ReverseDiff" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -743,6 +893,16 @@ git-tree-sha1 = "7ed35fb5f831aaf09c2d7c8736d44667a1afdcb0" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" version = "2.38.7" + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" @@ -811,14 +971,18 @@ uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" version = "1.1.1" [[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" [[deps.SplittablesBase]] deps = ["Setfield", "Test"] @@ -827,10 +991,14 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" version = "1.6.2" +weakdeps = ["Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" @@ -840,6 +1008,7 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -863,6 +1032,11 @@ version = "0.6.15" deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + [[deps.SymbolicIndexingInterface]] deps = ["DocStringExtensions"] git-tree-sha1 = "f8ab052bfcbdb9b48fad2c80c873aa0d0344dfe5" @@ -878,7 +1052,7 @@ version = "1.2.0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" +version = "1.0.3" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -895,7 +1069,7 @@ version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.1" +version = "1.10.0" [[deps.TaylorDiff]] deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SymbolicUtils", "Zygote"] @@ -909,6 +1083,12 @@ git-tree-sha1 = "50718b4fc1ce20cecf28d85215028c78b4d875c2" uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea" version = "0.15.2" + [deps.TaylorSeries.extensions] + TaylorSeriesIAExt = "IntervalArithmetic" + + [deps.TaylorSeries.weakdeps] + IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" + [[deps.TerminalLoggers]] deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] git-tree-sha1 = "f133fab380933d042f6796eda4e130272ba520ca" @@ -931,6 +1111,12 @@ git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" version = "0.2.26" + [deps.Tracker.extensions] + TrackerPDMatsExt = "PDMats" + + [deps.Tracker.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + [[deps.TranscodingStreams]] deps = ["Random", "Test"] git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" @@ -943,6 +1129,20 @@ git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" version = "0.4.78" + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + [[deps.TruncatedStacktraces]] deps = ["InteractiveUtils", "MacroTools", "Preferences"] git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" @@ -987,7 +1187,7 @@ version = "0.1.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+3" +version = "1.2.13+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -995,6 +1195,16 @@ git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" version = "0.6.63" + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" @@ -1008,9 +1218,9 @@ uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" version = "1.1.0" [[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.1+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/src/derivative.jl b/src/derivative.jl index 9819157..c5f2b0b 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -4,51 +4,62 @@ export derivative """ derivative(f, x::T, order::Int64) - derivative(f, x::AbstractMatrix{T}, order::Int64) derivative(f, x::T, ::Val{N}) - derivative(f, x::AbstractMatrix{T}, ::Val{N}) -Computes `order`-th derivative of `f` w.r.t. `x`. +Computes `order`-th derivative of `f` w.r.t. scalar `x`. derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) - derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N}) + +Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. + + derivative(f, x::AbstractMatrix{T}, order::Int64) + derivative(f, x::AbstractMatrix{T}, ::Val{N}) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N}) -Computes `order`-th directional derivative of `f` w.r.t. `x` in direction `l`. +Shorthand notations for multiple calculations. +For a M-by-N matrix, calculate the directional derivative for each column. +For a 1-by-N matrix (row vector), calculate the derivative for each scalar. """ function derivative end -@inline function derivative(f, x::Union{T, AbstractMatrix{T}}, - order::Int64) where {T <: Number} +# Convenience wrappers for converting orders to value types +# and forward work to core APIs + +@inline function derivative(f, x, order::Int64) derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::Union{AbstractVector{T}, AbstractMatrix{T}}, - l::AbstractVector{S}, order::Int64) where {T <: Number, S <: Number} +@inline function derivative(f, x, l, order::Int64) derivative(f, x, l, Val{order + 1}()) end +# Core APIs + +# Added to help Zygote infer types +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) + @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end -@inline function derivative(f, x::AbstractMatrix{<:Number}, N::Val) - size(x)[1] != 1 && @warn "x is not a row vector." - mapcols(u -> derivative(f, u[1], N), x) -end - -# Need to rewrite like this to help Zygote infer types -make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) - @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, vN::Val{N}) where {T <: Number, S <: Number, N} - t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) + t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) + # equivalent to map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end -@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, - vN::Val{N}) where {T <: Number, N} +# shorthand notations for matrices + +@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N} + size(x)[1] != 1 && @warn "x is not a row vector." + mapcols(u -> derivative(f, u[1], vN), x) +end + +@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, + vN::Val{N}) where {T <: Number, S <: Number, N} mapcols(u -> derivative(f, u, l, vN), x) end