Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a functor for projection #385

Merged
merged 44 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
914bd92
Sketch project implementation
willtebbutt Feb 24, 2021
06678a4
change Composite to Tangent
Jun 22, 2021
c58f974
export project
Jun 22, 2021
00020e3
make T optional
Jun 22, 2021
37f9253
add tests and Complex
Jun 22, 2021
4e1b79d
workout the edge cases
Jun 22, 2021
7dc58ee
rename dummy struct
Jun 22, 2021
3345ba9
rename project to projector
Jun 23, 2021
31d81ed
move to projector
Jun 24, 2021
2ea4845
do not close over x (other than in the general case)
Jun 24, 2021
465e1d7
update docstring
Jun 24, 2021
0a06dce
fix getproperty
Jun 24, 2021
d822b02
add to Tangent and to Symmetric
Jun 24, 2021
25a7cee
remove debug strings
Jun 24, 2021
7801e19
separate out the projector
Jun 24, 2021
9147fad
implement preproject
Jun 25, 2021
cc2f199
remove getproperty for thunks
Jun 25, 2021
2aa3859
remove to Tangent
Jun 25, 2021
44ef266
fix docstrings
Jun 25, 2021
d8848f5
project nested structs
Jun 25, 2021
88da9c6
Change from preproject to ProjectTo functor
oxinabox Jun 29, 2021
e0318b3
Make sure Arrays of Arrays etc work
oxinabox Jun 29, 2021
ce5d646
remove the special case ProjectTo(::Type{<:Number})
Jun 30, 2021
12a0db4
Merge branch 'master' into mz/projectto
Jun 30, 2021
f1a6260
add to_ prefix, add Transpose/Adjoint/SubArray
Jun 30, 2021
06268a3
add Adjoint and Transpose test
Jun 30, 2021
a981279
test Tangents with implicit zeros
Jun 30, 2021
eefd84f
throw error when ProjectTo to Tuple or NamedTuple
Jun 30, 2021
2facaea
fix transpose bug
Jun 30, 2021
9787b1b
add test for TwoFields
Jul 1, 2021
93c7489
test complex numbers too
Jul 1, 2021
e7190b2
nested where
Jul 1, 2021
233d292
fix SubArray
Jul 1, 2021
4c25f32
add Hermitian
Jul 2, 2021
029cb69
remove debug statements
Jul 2, 2021
b73e246
add Upper and LowerTriangular
Jul 2, 2021
9d665c0
PermutedDimsArray
Jul 2, 2021
030d636
Update test/projection.jl
mzgubic Jul 2, 2021
b87368f
fix docs
Jul 5, 2021
0f09ab9
JuliaFormatter
Jul 5, 2021
3a47f6f
simplify one of the PermutedDimsArray
Jul 5, 2021
ce022d5
document when to use ProjectTo
Jul 5, 2021
4106232
Apply suggestions from code review
mzgubic Jul 6, 2021
04a4e87
Update docs/Manifest.toml
mzgubic Jul 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,12 @@ function (::ProjectTo) end
# fallback (structs)
function ProjectTo(x::T) where {T}
# Generic fallback for structs, recursively make `ProjectTo`s all their fields
#println()
#@show x
#@show T
fields_nt::NamedTuple = backing(x)
#@show fields_nt
return ProjectTo{T}(map(ProjectTo, fields_nt))
end
function (project::ProjectTo{T})(dx::Tangent) where {T}
sub_projects = backing(project)
#@show sub_projects
sub_dxs = backing(canonicalize(dx))
#@show sub_dxs
_call(f, x) = f(x)
return construct(T, map(_call, sub_projects, sub_dxs))
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
end
Expand Down Expand Up @@ -95,7 +89,7 @@ ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x)))
(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx)))
(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx))

# Symmetric and Hermitian
# :data, :uplo fields
for SymHerm = (:Symmetric, :Hermitian)
@eval begin
ProjectTo(x::T) where {T<:$SymHerm} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x)))
Expand All @@ -105,16 +99,30 @@ for SymHerm = (:Symmetric, :Hermitian)
end
end

# :data field
for UL = (:UpperTriangular, :LowerTriangular)
@eval begin
ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx))
(project::ProjectTo{<:$UL})(dx::AbstractZero) = $UL(project.parent(dx))
(project::ProjectTo{<:$UL})(dx::Tangent) = $UL(project.parent(dx.data))
end
end

# Transpose
ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.parent(transpose(dx)))
(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.parent(conj(parent(dx))))
(project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx))

# Adjoint
ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx)))
(project::ProjectTo{<:Adjoint})(dx::ZeroTangent) = adjoint(project.parent(dx))
(project::ProjectTo{<:Adjoint})(dx::AbstractZero) = adjoint(project.parent(dx))

# PermutedDimsArray
ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractArray) where {T, N, perm, iperm, AA} = PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm))
(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractZero) where {T, N, perm, iperm, AA} = PermutedDimsArray{T,N,perm,iperm,AA}(project.parent(dx))
mzgubic marked this conversation as resolved.
Show resolved Hide resolved

# SubArray
ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
23 changes: 23 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,20 @@ end
@test x == ProjectTo(x)(@thunk(ZeroTangent()))
end

@testset "to $UL" for UL in (UpperTriangular, LowerTriangular)
data = [1.0+1im 2-2im; 3 4]

x = UL(data)
@test x == ProjectTo(x)(data)
@test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data))

data = [0.0+0im 0; 0 0]
x = UL(data)
@test x == ProjectTo(x)(Diagonal(zeros(2)))
@test x == ProjectTo(x)(ZeroTangent())
@test x == ProjectTo(x)(@thunk(ZeroTangent()))
end

@testset "to Transpose" begin
x = rand(ComplexF64, 3, 4)
t = transpose(x)
Expand All @@ -202,6 +216,15 @@ end
@test zeros(4, 3) == ProjectTo(a)(Tangent{Adjoint}(; parent=ZeroTangent()))
end

@testset "to PermutedDimsArray" begin
a = zeros(3, 5, 4)
b = PermutedDimsArray(a, (2, 1, 3))
bc = collect(b)

@test b == ProjectTo(b)(bc)
@test b == ProjectTo(b)(ZeroTangent())
end

@testset "to SubArray" begin
x = rand(3, 4)
sa = view(x, :, 1:2)
Expand Down