Skip to content

Commit

Permalink
Incorrect trace
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent e773e63 commit 73cd560
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)

# Jacobian Stabilization
function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
__f = u -> first(model((u, x), ps))
__f = u -> model((u, x), ps)
res = zero(eltype(x))
ϵ = cbrt(eps(typeof(res)))
ϵ⁻¹ = inv(ϵ)
Expand Down
9 changes: 5 additions & 4 deletions test/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end

model_type = (:deq, :skipdeq, :skipregdeq)
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
jacobian_regularizations = [nothing, AutoZygote()]
!ongpu && push!(jacobian_regularizations, AutoFiniteDiff())

@testset "Solver: $(__nameof(solver))" for solver in solvers,
mtype in model_type, jacobian_regularization in jacobian_regularizations
Expand Down Expand Up @@ -133,10 +134,10 @@ end
jacobian_regularization)
end

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
@test st.solution == DeepEquilibriumSolution()

x = randn(rng, Float32, x_size...)
x = randn(rng, Float32, x_size...) |> dev
z, st = model(x, ps, st)
z_ = DEQs.__flatten_vcat(z)

Expand All @@ -157,7 +158,7 @@ end
@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
@test st.solution == DeepEquilibriumSolution()

Expand Down
2 changes: 2 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
import LuxTestUtils: @jet
using LuxCUDA

CUDA.allowscalar(false)

__nameof(::X) where {X} = nameof(X)

__get_prng(seed::Int) = StableRNG(seed)
Expand Down

0 comments on commit 73cd560

Please sign in to comment.