Errors for gradient and hessian of logabsdet of sparse matrix #719

Closed
ElOceanografo opened this issue May 28, 2023 · 3 comments · Fixed by FluxML/Zygote.jl#1432 or #730

Comments

@ElOceanografo
Contributor

ElOceanografo commented May 28, 2023

The following call throws an error:

using LinearAlgebra, SparseArrays, Zygote
Zygote.gradient(x -> logabsdet(spdiagm(x)), rand(10))
ERROR: The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please either convert your matrix to a dense matrix, e.g. by calling `Matrix` or if `A` can be factorized, use `\` on the dense identity matrix, e.g. `A \ Matrix{eltype(A)}(I, size(A)...)` restrictions of `\` on sparse lhs applies. Altenatively, `A\b` is generally preferable to `inv(A)*b`
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] inv(A::SparseMatrixCSC{Float64, Int64})
    @ SparseArrays C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\linalg.jl:1448
  [3] (::Zygote.var"#833#836"{SparseMatrixCSC{Float64, Int64}})(Δ::Tuple{Float64, Float64})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\array.jl:379
  [4] (::Zygote.var"#3123#back#837"{Zygote.var"#833#836"{SparseMatrixCSC{Float64, Int64}}})(Δ::Tuple{Float64, Float64})
    @ Zygote C:\Users\sam.urmy\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:71
  [5] Pullback
    @ C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\LinearAlgebra\src\generic.jl:1685 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(logdet), SparseMatrixCSC{Float64, Int64}}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
  [7] Pullback
    @ .\Untitled-1:54 [inlined]
  [8] (::Zygote.Pullback{Tuple{typeof(f), Vector{Float64}, NamedTuple{(:x, :n), Tuple{Vector{Float64}, Int64}}}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
  [9] Pullback
    @ .\Untitled-1:60 [inlined]
 [10] (::Zygote.Pullback{Tuple{var"#86#87", Vector{Float64}}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
 [11] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#86#87", Vector{Float64}}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:45
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:97
 [13] top-level scope
    @ Untitled-1:60

And for the Hessian:

Zygote.hessian(x -> logabsdet(spdiagm(x)), rand(10))
ERROR: StackOverflowError:
Stacktrace:
  [1] Array
    @ .\boot.jl:477 [inlined]
  [2] Array
    @ .\boot.jl:486 [inlined]
  [3] Array
    @ .\boot.jl:494 [inlined]
  [4] similar
    @ .\abstractarray.jl:882 [inlined]
  [5] similar
    @ .\abstractarray.jl:881 [inlined]
  [6] similar
    @ .\broadcast.jl:212 [inlined]
  [7] similar
    @ .\broadcast.jl:211 [inlined]
  [8] copy
    @ .\broadcast.jl:898 [inlined]
  [9] materialize
    @ .\broadcast.jl:873 [inlined]
 [10] float(S::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
    @ SparseArrays C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\sparsematrix.jl:935
 [11] lu(A::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64}; check::Bool) (repeats 16262 times)
    @ SparseArrays.UMFPACK C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\solvers\umfpack.jl:398
 [12] logabsdet(A::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
    @ LinearAlgebra C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\LinearAlgebra\src\generic.jl:1660
 [13] adjoint
    @ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\array.jl:379 [inlined]
 [14] _pullback
    @ C:\Users\sam.urmy\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66 [inlined]
 [15] _pullback
    @ C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\LinearAlgebra\src\generic.jl:1685 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::typeof(logdet), args::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
 [17] _pullback
    @ .\Untitled-1:54 [inlined]
 [18] _pullback(::Zygote.Context{false}, ::typeof(f), ::Vector{ForwardDiff.Dual{Nothing, Float64, 2}}, ::NamedTuple{(:x, :n), Tuple{Vector{Float64}, Int64}})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
 [19] _pullback
    @ .\Untitled-1:60 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::var"#88#89", args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
 [21] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:44
 [22] pullback
    @ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:42 [inlined]
 [23] gradient(f::Function, args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:96
 [24] (::Zygote.var"#121#122"{var"#88#89"})(x::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:64
 [25] forward_jacobian(f::Zygote.var"#121#122"{var"#88#89"}, x::Vector{Float64}, #unused#::Val{2})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:29
 [26] forward_jacobian(f::Function, x::Vector{Float64}; chunk_threshold::Int64)
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:44
 [27] forward_jacobian
    @ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:42 [inlined]
 [28] hessian_dual
    @ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:64 [inlined]
 [29] hessian(f::Function, x::Vector{Float64})
    @ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:62
 [30] eval
    @ .\boot.jl:370 [inlined]
 [31] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base .\loading.jl:1864
 [32] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Base .\essentials.jl:816
 [33] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base .\essentials.jl:813
 [34] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:233
 [35] (::VSCodeServer.var"#66#70"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:157
 [36] withpath(f::VSCodeServer.var"#66#70"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\repl.jl:249
 [37] (::VSCodeServer.var"#65#69"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:155
 [38] hideprompt(f::VSCodeServer.var"#65#69"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\repl.jl:38
 [39] (::VSCodeServer.var"#64#68"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:126
 [40] with_logstate(f::Function, logstate::Any)
    @ Base.CoreLogging .\logging.jl:514
 [41] with_logger
    @ .\logging.jl:626 [inlined]
 [42] (::VSCodeServer.var"#63#67"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:225
 [43] #invokelatest#2
    @ .\essentials.jl:816 [inlined]
 [44] invokelatest(::Any)
    @ Base .\essentials.jl:813
 [45] macro expansion
    @ c:\Users\sam.urmy\.vscode\extensions\julialang.language-julia-1.47.2\scripts\packages\VSCodeServer\src\eval.jl:34 [inlined]
 [46] (::VSCodeServer.var"#61#62")()
    @ VSCodeServer .\task.jl:514

Changing spdiagm to diagm makes both work as expected. I assume this is due to missing rrules for logabsdet(F::SparseArrays.UMFPACK.UmfpackLU)?

Relevant to TuringLang/DistributionsAD.jl#89

@mcabbott
Member

mcabbott commented May 29, 2023

changing spdiagm to diagm makes both work as expected

Note that it returns a tuple, you need to keep just one element:

julia> gradient(logabsdet, diagm(1:3))
ERROR: Output should be scalar; gradients are not defined for output (1.791759469228055, 1.0)

julia> gradient(x -> logabsdet(x)[1], diagm(1:3))
([1.0 0.0 0.0; 0.0 0.5 0.0; 0.0 0.0 0.3333333333333333],)

julia> gradient(x -> logabsdet(x)[1], spdiagm(1:3))
ERROR: The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please either convert your matrix to a dense matrix, e.g. by calling `Matrix` or if `A` can be factorized, use `\` on the dense identity matrix, e.g. `A \ Matrix{eltype(A)}(I, size(A)...)` restrictions of `\` on sparse lhs applies. Altenatively, `A\b` is generally preferable to `inv(A)*b`
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] inv(A::SparseMatrixCSC{Int64, Int64})
   @ SparseArrays ~/.julia/dev/julia/usr/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:1431
 [3] (::Zygote.var"#833#836"{SparseMatrixCSC{Int64, Int64}})(Δ::Tuple{Float64, Nothing})
   @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:379
...
 [8] gradient(f::Function, args::SparseMatrixCSC{Int64, Int64})
   @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface.jl:97

Here src/lib/array.jl:379 is Zygote's own rule, which takes precedence, so ChainRules (CR) isn't involved. That Zygote rule could probably be deleted.

The rule in ChainRules looks like it will have the same problem, but perhaps it can be fixed?
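As a rough illustration of one possible fix, the gradient of log|det(A)| with respect to A is inv(A)', and the error message itself suggests solving against a dense identity instead of calling inv on the sparse matrix. The sketch below does that using only the standard library. The helper name `sparse_logabsdet_grad` is hypothetical and is not part of Zygote or ChainRules. The result is a dense n×n matrix, so the memory concern raised in the error message still applies for large n.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper (not an existing Zygote/ChainRules function):
# returns inv(A)' for a square sparse matrix without calling inv(A).
function sparse_logabsdet_grad(A::SparseMatrixCSC{<:Real})
    n = LinearAlgebra.checksquare(A)
    F = lu(float(A))                          # sparse LU factorization
    X = F \ Matrix{Float64}(I, n, n)          # solve A * X = I, dense result
    return Matrix(X')                         # gradient of log|det(A)| is inv(A)'
end

sparse_logabsdet_grad(spdiagm([1.0, 2.0, 3.0]))
```

For the diagonal example above this should agree with the dense gradient returned for diagm(1:3).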

@ElOceanografo
Contributor Author

In theory, I think this line:


could be rewritten as:

∂x = (x \ g')'

@ElOceanografo
Contributor Author

I was wrong about that line of code above: g is a scalar, so I don't think it can be rewritten to use \. As I understand it, two new rules are needed to avoid having to invert a sparse matrix:

  • lu(S::SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}; check, q, control)
  • logabsdet(F::SparseArrays.UMFPACK.UmfpackLU)

In fact, the generic definition for logabsdet(A::AbstractMatrix) first takes the lu decomposition of A, then calculates the logabsdet of the factorized matrix:
https://github.com/JuliaLang/julia/blob/a7348b7aa9d99af5ce5a8314f58a690132f21fb9/stdlib/LinearAlgebra/src/generic.jl#L1660
which is basically just a sum of the logs of the absolute values along the factor's diagonal. Even for dense matrices, lu is generally faster than inv, so could removing the rule for logabsdet(A::AbstractMatrix) and letting it fall through to rules for logabsdet of an LU object possibly speed up both the sparse and dense cases?
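To make the "sum along the diagonal" point concrete, here is a minimal sketch for the dense case. The helper name `logabsdet_via_lu` is made up for illustration; it is not the LinearAlgebra implementation, only a restatement of the same idea under the assumption of a real, square, dense input.

```julia
using LinearAlgebra

# Hypothetical helper: log|det(A)| and sign(det(A)) read off a dense LU factorization.
function logabsdet_via_lu(A::AbstractMatrix{<:Real})
    F = lu(A)                         # row-pivoted LU: A[F.p, :] == F.L * F.U
    d = diag(F.U)                     # L has a unit diagonal, so only U contributes
    logabs = sum(log.(abs.(d)))
    # Each row interchange recorded in ipiv flips the sign of the determinant.
    nswaps = count(i -> F.ipiv[i] != i, eachindex(F.ipiv))
    s = prod(sign, d) * (iseven(nswaps) ? 1 : -1)
    return logabs, s
end

A = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
logabsdet_via_lu(A)
```

The pair returned here should match logabsdet(A) up to floating-point rounding. A differentiation rule written at this level would still need a pullback for the factorization itself, which is the missing piece discussed above.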
