From 305277afecbb15b82ccdd5a6878c0af2b4edd7a5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 09:41:33 +0200 Subject: [PATCH] Add function setladj Co-authored-by: David Widmann --- Project.toml | 11 ++- README.md | 3 +- docs/Project.toml | 1 + docs/make.jl | 2 +- docs/src/api.md | 7 ++ ext/ChangesOfVariablesInverseFunctionsExt.jl | 32 +++++++++ src/ChangesOfVariables.jl | 5 ++ src/setladj.jl | 71 ++++++++++++++++++++ test/runtests.jl | 4 +- test/test_setladj.jl | 58 ++++++++++++++++ 10 files changed, 189 insertions(+), 5 deletions(-) create mode 100644 ext/ChangesOfVariablesInverseFunctionsExt.jl create mode 100644 src/setladj.jl create mode 100644 test/test_setladj.jl diff --git a/Project.toml b/Project.toml index 52369e1..ef337a6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,15 +3,24 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.7" [deps] +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[weakdeps] +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[extensions] +ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" + [compat] +InverseFunctions = "0.1" julia = "1" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [targets] -test = ["Documenter", "ForwardDiff"] +test = ["Documenter", "InverseFunctions", "ForwardDiff"] diff --git a/README.md b/README.md index 11c636c..2c5bca0 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ changes for functions that perform a change of variables (like coordinate transformations). `ChangesOfVariables` is a very lightweight package and has no dependencies -beyond `Base`, `LinearAlgebra`, `Test`. +beyond `Base`, `LinearAlgebra` and `Test` (plus a weak depdendency on +`InverseFunctions`). ## Documentation diff --git a/docs/Project.toml b/docs/Project.toml index f656746..ef0d8b4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" diff --git a/docs/make.jl b/docs/make.jl index 9504cd5..bbb84b1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,7 +11,7 @@ using ChangesOfVariables DocMeta.setdocmeta!( ChangesOfVariables, :DocTestSetup, - :(using ChangesOfVariables); + :(using ChangesOfVariables, InverseFunctions); recursive=true, ) diff --git a/docs/src/api.md b/docs/src/api.md index 4ee897f..cee4eb3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -5,6 +5,7 @@ ```@docs with_logabsdet_jacobian NoLogAbsDetJacobian +setladj ``` ## Test utility @@ -12,3 +13,9 @@ NoLogAbsDetJacobian ```@docs ChangesOfVariables.test_with_logabsdet_jacobian ``` + +## Additional functionality + +```@docs +ChangesOfVariables.FunctionWithLADJ +``` diff --git a/ext/ChangesOfVariablesInverseFunctionsExt.jl b/ext/ChangesOfVariablesInverseFunctionsExt.jl new file mode 100644 index 0000000..76c2eb4 --- /dev/null +++ b/ext/ChangesOfVariablesInverseFunctionsExt.jl @@ -0,0 +1,32 @@ +module ChangesOfVariablesInverseFunctionsExt + +using ChangesOfVariables +using InverseFunctions + + +struct InverseFunctionWithLADJ{InvF,LADJF} <: Function + inv_f::InvF + ladjf::LADJF +end +InverseFunctionWithLADJ(::Type{InvF}, ladjf::LADJF) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},LADJF}(InvF,ladjf) +InverseFunctionWithLADJ(inv_f::InvF, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{InvF,Type{LADJF}}(inv_f,LADJF) +InverseFunctionWithLADJ(::Type{InvF}, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},Type{LADJF}}(InvF,LADJF) + +(f::InverseFunctionWithLADJ)(y) = f.inv_f(y) + +function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctionWithLADJ, y) + x = f.inv_f(y) + return x, -f.ladjf(x) +end + +InverseFunctions.inverse(f::ChangesOfVariables.FunctionWithLADJ) = InverseFunctionWithLADJ(inverse(f.f), f.ladjf) +InverseFunctions.inverse(f::InverseFunctionWithLADJ) = ChangesOfVariables.FunctionWithLADJ(inverse(f.inv_f), f.ladjf) + + +@static if isdefined(InverseFunctions, :FunctionWithInverse) + function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctions.FunctionWithInverse, x) + ChangesOfVariables.with_logabsdet_jacobian(f.f, x) + end +end + +end # module ChangesOfVariablesInverseFunctionsExt diff --git a/src/ChangesOfVariables.jl b/src/ChangesOfVariables.jl index b2b034d..daa3e0a 100644 --- a/src/ChangesOfVariables.jl +++ b/src/ChangesOfVariables.jl @@ -13,6 +13,11 @@ using LinearAlgebra using Test include("with_ladj.jl") +include("setladj.jl") include("test.jl") +@static if !isdefined(Base, :get_extension) + include("../ext/ChangesOfVariablesInverseFunctionsExt.jl") +end + end # module diff --git a/src/setladj.jl b/src/setladj.jl new file mode 100644 index 0000000..377c8e7 --- /dev/null +++ b/src/setladj.jl @@ -0,0 +1,71 @@ +# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). + + +""" + struct FunctionWithLADJ{F,LADJF} <: Function + +A function with an separate function to compute it's `logabddet(J)`. + +Do not construct directly, use [`setladj(f, ladjf)`](@ref) instead. +""" +struct FunctionWithLADJ{F,LADJF} <: Function + f::F + ladjf::LADJF +end +FunctionWithLADJ(::Type{F}, ladjf::LADJF) where {F,LADJF} = FunctionWithLADJ{Type{F},LADJF}(F,ladjf) +FunctionWithLADJ(f::F, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{F,Type{LADJF}}(f,LADJF) +FunctionWithLADJ(::Type{F}, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{Type{F},Type{LADJF}}(F,LADJF) + +(f::FunctionWithLADJ)(x) = f.f(x) + +with_logabsdet_jacobian(f::FunctionWithLADJ, x) = f.f(x), f.ladjf(x) + + +""" + setladj(f, ladjf)::Function + +Return a function that behaves like `f` in general and which has +`with_logabsdet_jacobian(f, x) = f(x), ladjf(x)`. + +Useful in cases where [`with_logabsdet_jacobian`](@ref) is not defined +for `f`, or if `f` needs to be assigned a LADJ-calculation that is +only valid within a given context, e.g. only for a +limited argument type/range that is guaranteed by the use case but +not in general, or that is optimized to a custom use case. + +For example, `CUDA.CuArray` has no `with_logabsdet_jacobian` defined, +but may be used to switch computing device for a part of a +heterogenous computing function chain. Likewise, one may want to +switch numerical precision for a part of a calculation. + +The function (wrapper) returned by `setladj` supports +[`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) +if `f` does so. + +Example: + +```jldoctest setladj +VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6 + # Increases precition before calculation exp: + foo = exp ∘ setladj(setinverse(Float64, Float32), _ -> 0) + + # A log-value from some low-precision (e.g. GPU) computation: + log_x = Float32(100) + + # f(log_x) would return Inf32 without going to Float64: + y, ladj = with_logabsdet_jacobian(foo, log_x) + + r_log_x, ladj_inv = with_logabsdet_jacobian(inverse(foo), y) + + ladj ≈ 100 ≈ -ladj_inv && r_log_x ≈ log_x +end +# output + +true +``` +""" +setladj(f, ladjf) = FunctionWithLADJ(_unwrap_f(f), ladjf) +export setladj + +_unwrap_f(f) = f +_unwrap_f(f::FunctionWithLADJ) = f.f diff --git a/test/runtests.jl b/test/runtests.jl index 035b048..d2afb51 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,14 +7,14 @@ import Documenter Test.@testset "Package ChangesOfVariables" begin include("test_test.jl") include("test_with_ladj.jl") + include("test_setladj.jl") # doctests Documenter.DocMeta.setdocmeta!( ChangesOfVariables, :DocTestSetup, - :(using ChangesOfVariables); + :(using ChangesOfVariables, InverseFunctions); recursive=true, ) Documenter.doctest(ChangesOfVariables) end # testset - diff --git a/test/test_setladj.jl b/test/test_setladj.jl new file mode 100644 index 0000000..02a6bd8 --- /dev/null +++ b/test/test_setladj.jl @@ -0,0 +1,58 @@ +# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). + +using Test +using ChangesOfVariables +using InverseFunctions + +const ChangesOfVariablesInverseFunctionsExt = if isdefined(Base, :get_extension) + Base.get_extension(ChangesOfVariables, :ChangesOfVariablesInverseFunctionsExt) +else + ChangesOfVariables.ChangesOfVariablesInverseFunctionsExt +end +const InverseFunctionWithLADJ = ChangesOfVariablesInverseFunctionsExt.InverseFunctionWithLADJ + +include("getjacobian.jl") + + +# Dummy testing type that looks like something that represents abstract zeros: +struct _Zero{T} end +_Zero(::T) where {T} = _Zero{T}() + + +@testset "setladj" begin + @test @inferred(setladj(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}} + @test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}} + @test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),Type{_Zero}} + @test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},typeof(zero)} + @test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),typeof(zero)} + + @test @inferred(InverseFunctionWithLADJ(Real, _Zero)) isa InverseFunctionWithLADJ{Type{Real},Type{_Zero}} + @test @inferred(InverseFunctionWithLADJ(widen, _Zero)) isa InverseFunctionWithLADJ{typeof(widen),Type{_Zero}} + @test @inferred(InverseFunctionWithLADJ(Real, zero)) isa InverseFunctionWithLADJ{Type{Real},typeof(zero)} + @test @inferred(InverseFunctionWithLADJ(widen, zero)) isa InverseFunctionWithLADJ{typeof(widen),typeof(zero)} + + @test @inferred(setladj(setladj(exp, x -> 0), x -> x)) isa ChangesOfVariables.FunctionWithLADJ{typeof(exp)} + ChangesOfVariables.test_with_logabsdet_jacobian(setladj(setladj(exp, x -> 0), x -> x), 1.7, getjacobian) + + x = 4.2 + y = x^2 + + f_fwd = setladj(x -> x^2, x -> log(2*x)) + f_inv = setladj(y -> sqrt(y), y -> log(inv(2*sqrt(y)))) + ChangesOfVariables.test_with_logabsdet_jacobian(f_fwd, x, getjacobian) + ChangesOfVariables.test_with_logabsdet_jacobian(f_inv, y, getjacobian) + + f = @inferred setladj(setinverse(x -> x^2, x -> sqrt(x)), x -> log(2*x)) + @test @inferred(f(x)) == y + ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian) + ChangesOfVariables.test_with_logabsdet_jacobian(inverse(f), y, getjacobian) + ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(f)), x, getjacobian) + @inferred(inverse(inverse(f))) isa ChangesOfVariables.FunctionWithLADJ + + @static if isdefined(InverseFunctions, :setinverse) + g = setinverse(f_fwd, f_inv) + ChangesOfVariables.test_with_logabsdet_jacobian(g, x, getjacobian) + ChangesOfVariables.test_with_logabsdet_jacobian(inverse(g), y, getjacobian) + ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(g)), x, getjacobian) + end +end