Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: handle zero_tangent in presence of cyclic structures (v1) #654

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7f15d11
rename files
oxinabox Aug 4, 2023
c9b4938
move functionality up to StructuralTangent
oxinabox Aug 4, 2023
a51e51e
Formatting
oxinabox Aug 4, 2023
93af90b
WIP mutable Tangent (squash me)
oxinabox Aug 21, 2023
418b5ce
wip
oxinabox Sep 15, 2023
724ba1b
First pass at something that maybe works
oxinabox Sep 15, 2023
c9a65df
accept int index
oxinabox Sep 18, 2023
06b51f5
add == and hash for MutableTangent
oxinabox Sep 26, 2023
ed3aa1d
add and test zero_tangent
oxinabox Sep 26, 2023
f45fbc7
export StructuralTangent
oxinabox Sep 28, 2023
b2bdb26
Style
oxinabox Oct 2, 2023
0438217
handle unassigned a bit more
oxinabox Oct 4, 2023
4852c91
add some more test cases to zero_tangent
oxinabox Oct 4, 2023
0f82019
style
oxinabox Oct 4, 2023
5574691
Handle Structs with undef fields
oxinabox Oct 6, 2023
e9cc221
overhaul zero_tangent and MutableTangent for type stability
oxinabox Dec 22, 2023
baea9d3
set MutableTangent setproperty! on index
oxinabox Dec 27, 2023
a27f1b6
formatting
oxinabox Dec 27, 2023
4cfce0b
handle abstract fields right in mutable tangents outside of zero tangent
oxinabox Dec 28, 2023
ad9a5af
formatting
oxinabox Dec 28, 2023
8b3d525
Add docs for forward mutation support
oxinabox Dec 28, 2023
c09ff91
use ismutabletype from Compat
oxinabox Dec 29, 2023
59fc470
wrap structural tangent tests in a common testset
oxinabox Dec 29, 2023
ade0c3d
Support types that have no tangent space in zero_tangent
oxinabox Dec 29, 2023
e068cb6
define zero_tangent for Tangent
oxinabox Jan 16, 2024
45de6a7
Add structural zero tangent code for higher order
oxinabox Jan 17, 2024
780ed05
Formatting
oxinabox Jan 17, 2024
2795872
overload show for mutable tangent
oxinabox Jan 17, 2024
7e9e778
formatting
oxinabox Jan 19, 2024
e912e46
move show code to `Common` area
oxinabox Jan 23, 2024
e478e7f
docs more consistent
oxinabox Jan 23, 2024
7d95866
Update src/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
fe63c33
Update test/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
5fbbe5b
Handle circular references with-in mutable structs
oxinabox Jan 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
handle unassigned a bit more
  • Loading branch information
oxinabox committed Jan 23, 2024
commit 0438217811ed7b0006ce270d5930336a564f7b40
28 changes: 23 additions & 5 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,9 @@ For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is d
Exactly how it should be used (e.g. is it forward-mode only?)
"""
function zero_tangent end
zero_tangent(::AbstractString) = ZeroTangent()
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
zero_tangent(primal::Array) = map(zero_tangent, primal)

zero_tangent(x::Number) = zero(x)

@generated function zero_tangent(primal)
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
zfield_exprs = map(fieldnames(primal)) do fname
Expand All @@ -119,4 +118,23 @@ zero_tangent(primal::Array) = map(zero_tangent, primal)
end
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
return :($MutableTangent{$primal}($backing_expr))
end
end

function zero_tangent(x::Array{P, N}) where {P, N}
(isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x)

# Now we need to handle nonfully assigned arrays
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...)
@inbounds for n in eachindex(y)
if isassigned(x, n)
y[n] = zero_tangent(x[n])
end
end
return y
end

guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N}
guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634
# TODO: we might be able to do better than this. even without.
25 changes: 25 additions & 0 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,29 @@ end

@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]

@testset "undef elements" begin
x = Vector{Vector{Float64}}(undef, 3)
x[2] = [1.0,2.0]
dx = zero_tangent(x)
@test dx isa Vector{Vector{Float64}}
@test length(dx) == 3
@test !isassigned(dx, 1)
@test dx[2] == [0.0, 0.0]
@test !isassigned(dx, 3)


a = Vector{MutDemo}(undef, 3)
a[2] = MutDemo(1.5)
da = zero_tangent(a)
@test !isassigned(da, 1)
@test iszero(da[2])
@test !isassigned(da, 3)


db = zero_tangent(Vector{MutDemo}(undef, 3))
@test all(ii->!isassigned(db,ii), eachindex(db))
@test length(db)==3
@test db isa Vector
end
end