Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Change matrix to_vec semantics #187

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.18"
version = "0.13.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
109 changes: 27 additions & 82 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ function to_vec(z::Complex)
return [real(z), imag(z)], Complex_from_vec
end

# Integers cannot be perturbed!
function to_vec(x::Integer)
Integer_from_vec(v) = x
return Bool[], Integer_from_vec
end

# Base case -- if x is already a Vector{<:Real} there's no conversion necessary.
to_vec(x::Vector{<:Real}) = (x, identity)

Expand All @@ -37,9 +43,9 @@ end
# chunk of the time.
function to_vec(x::T) where {T}
Base.isstructtype(T) || throw(error("Expected a struct type"))
isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types
is_singleton(x) && return (Bool[], _ -> x) # Singleton types

val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T))
val_vecs_and_backs = get_val_vecs_and_backs(x)
vals = first.(val_vecs_and_backs)
backs = last.(val_vecs_and_backs)

Expand All @@ -56,6 +62,16 @@ function to_vec(x::T) where {T}
return v, structtype_from_vec
end

# Type-stable way to determine whether a type has any fields.
@generated function is_singleton(x)
return isempty(fieldnames(x)) ? :true : :false
end

# Type-stable way to call `to_vec` on each field.
@generated function get_val_vecs_and_backs(x)
return Expr(:tuple, map(name -> :(to_vec(x.$name)), fieldnames(x))...)
end

function to_vec(x::DenseVector)
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
Expand All @@ -79,83 +95,11 @@ function to_vec(x::DenseArray)
return x_vec, Array_from_vec
end

# Some specific subtypes of AbstractArray.
function to_vec(x::Base.ReshapedArray{<:Any, 1})
x_vec, from_vec = to_vec(parent(x))
function ReshapedArray_from_vec(x_vec)
p = from_vec(x_vec)
return Base.ReshapedArray(p, x.dims, x.mi)
end

return x_vec, ReshapedArray_from_vec
end

# To return a SubArray we would endup needing to copy the `parent` of `x` in `from_vec`
# which doesn't seem particularly useful. So we just convert the view into a copy.
# we might be able to do something more performant but this seems good for now.
to_vec(x::Base.SubArray) = to_vec(copy(x))

function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
x_vec, back = to_vec(Matrix(x))
function AbstractTriangular_from_vec(x_vec)
return T(reshape(back(x_vec), size(x)))
end
return x_vec, AbstractTriangular_from_vec
end

function to_vec(x::T) where {T<:LinearAlgebra.HermOrSym}
x_vec, back = to_vec(Matrix(x))
function HermOrSym_from_vec(x_vec)
return T(back(x_vec), x.uplo)
end
return x_vec, HermOrSym_from_vec
end

function to_vec(X::Diagonal)
x_vec, back = to_vec(Matrix(X))
function Diagonal_from_vec(x_vec)
return Diagonal(back(x_vec))
end
return x_vec, Diagonal_from_vec
end

function to_vec(X::Transpose)
x_vec, back = to_vec(Matrix(X))
function Transpose_from_vec(x_vec)
return Transpose(permutedims(back(x_vec)))
end
return x_vec, Transpose_from_vec
end

function to_vec(x::Transpose{<:Any, <:AbstractVector})
x_vec, back = to_vec(Matrix(x))
Transpose_from_vec(x_vec) = Transpose(vec(back(x_vec)))
return x_vec, Transpose_from_vec
end

function to_vec(X::Adjoint)
x_vec, back = to_vec(Matrix(X))
function Adjoint_from_vec(x_vec)
return Adjoint(conj!(permutedims(back(x_vec))))
end
return x_vec, Adjoint_from_vec
end

function to_vec(x::Adjoint{<:Any, <:AbstractVector})
x_vec, back = to_vec(Matrix(x))
Adjoint_from_vec(x_vec) = Adjoint(conj!(vec(back(x_vec))))
return x_vec, Adjoint_from_vec
end

function to_vec(X::T) where {T<:PermutedDimsArray}
x_vec, back = to_vec(parent(X))
function PermutedDimsArray_from_vec(x_vec)
X_parent = back(x_vec)
return T(X_parent)
end
return x_vec, PermutedDimsArray_from_vec
end

# Factorizations

function to_vec(x::F) where {F <: SVD}
Expand All @@ -170,14 +114,6 @@ function to_vec(x::F) where {F <: SVD}
return x_vec, SVD_from_vec
end

function to_vec(x::Cholesky)
x_vec, back = to_vec(x.factors)
function Cholesky_from_vec(v)
return Cholesky(back(v), x.uplo, x.info)
end
return x_vec, Cholesky_from_vec
end

function to_vec(x::S) where {U, S <: Union{LinearAlgebra.QRCompactWYQ{U}, LinearAlgebra.QRCompactWY{U}}}
# x.T is composed of upper triangular blocks. The subdiagonals elements
# of the blocks are abitrary. We make sure to set all of them to zero
Expand All @@ -203,6 +139,11 @@ end

# Non-array data structures

function to_vec(x::Tuple{})
Tuple_from_vec(v) = ()
return Bool[], Tuple_from_vec
end

function to_vec(x::Tuple)
x_vecs_and_backs = map(to_vec, x)
x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
Expand Down Expand Up @@ -260,3 +201,7 @@ function FiniteDifferences.to_vec(t::Thunk)
Thunk_from_vec = v -> @thunk(back(v))
return v, Thunk_from_vec
end

# Things that aren't struct types and aren't differentiable.
to_vec(x::Char) = Bool[], _ -> x
to_vec(x::Symbol) = Bool[], _ -> x
9 changes: 9 additions & 0 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ function test_to_vec(x::T; check_inferred=true) where {T}
end

@testset "to_vec" begin

@testset "Integer" begin
# Under ChainRules semantics, Integers cannot be perturbed. `to_vec` is primarily a
# tool designed to work with ChainRules, so we employ the same semantics here.
test_to_vec(5)
@test length(to_vec(5)[1]) == 0
end

@testset "$T" for T in (Float32, ComplexF32, Float64, ComplexF64)
if T == Float64
test_to_vec(1.0)
Expand Down Expand Up @@ -171,6 +179,7 @@ end
end

@testset "Tuples" begin
test_to_vec(())
test_to_vec((5, 4))
test_to_vec((5, randn(T, 5)); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false)
Expand Down