Skip to content

Commit

Permalink
Incorrect trace
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent e773e63 commit 3700773
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ version = "6.144.0"

[[deps.DiffEqCallbacks]]
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"]
git-tree-sha1 = "d0b94b3694d55e7eedeee918e7daee9e3b873399"
git-tree-sha1 = "e48b985459d1cbe8c809de192529f1e25c3382a6"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
version = "2.35.0"
version = "2.36.0"

[deps.DiffEqCallbacks.weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
7 changes: 4 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
::Val{shapes}) where {idxs, shapes}
dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
varnames = map(_ -> gensym("x_view"), dims)
calls = [:($(varnames[i]) = view(x, $(dims[i]), :)) for i in 1:length(dims)]
calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in 1:length(dims)]
return quote
$(calls...)
return tuple($(varnames...))
Expand All @@ -14,7 +14,8 @@ __split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x
function __split_and_reshape(y::AbstractMatrix, x)
szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
counters = vcat(0, cumsum(szs)[1:(end - 1)])
return map((sz, c, xᵢ) -> reshape(view(y, (c + 1):(c + sz), :), size(xᵢ)),
# Make the data contiguous
return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))),
szs, counters, x)
end

Expand Down Expand Up @@ -95,7 +96,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)

# Jacobian Stabilization
function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
__f = u -> first(model((u, x), ps))
__f = u -> model((u, x), ps)
res = zero(eltype(x))
ϵ = cbrt(eps(typeof(res)))
ϵ⁻¹ = inv(ϵ)
Expand Down
9 changes: 5 additions & 4 deletions test/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end

model_type = (:deq, :skipdeq, :skipregdeq)
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
jacobian_regularizations = Any[nothing, AutoZygote()]
!ongpu && push!(jacobian_regularizations, AutoFiniteDiff())

@testset "Solver: $(__nameof(solver))" for solver in solvers,
mtype in model_type, jacobian_regularization in jacobian_regularizations
Expand Down Expand Up @@ -133,10 +134,10 @@ end
jacobian_regularization)
end

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
@test st.solution == DeepEquilibriumSolution()

x = randn(rng, Float32, x_size...)
x = randn(rng, Float32, x_size...) |> dev
z, st = model(x, ps, st)
z_ = DEQs.__flatten_vcat(z)

Expand All @@ -157,7 +158,7 @@ end
@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
@test st.solution == DeepEquilibriumSolution()

Expand Down
2 changes: 2 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
import LuxTestUtils: @jet
using LuxCUDA

CUDA.allowscalar(false)

__nameof(::X) where {X} = nameof(X)

__get_prng(seed::Int) = StableRNG(seed)
Expand Down

0 comments on commit 3700773

Please sign in to comment.