-
-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: fix KrylovJL_GMRES with Enzyme #382
base: main
Are you sure you want to change the base?
Conversation
ext/LinearSolveEnzymeExt.jl
Outdated
using Enzyme | ||
|
||
using EnzymeCore | ||
|
||
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably add a minimum version 0.6.0 on EnzymeCore ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's already the minimum
520ed1f
to
8fc4ae3
Compare
ERROR: AssertionError: SciMLBase.LinearSolution{Float64, 1, Vector{Float64}, Float64, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}, Krylov.GmresSolver{Float64, Float64, Vector{Float64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Nothing} has mixed internal activity types
Stacktrace:
[1] active_reg(::Type{SciMLBase.LinearSolution{…}}, world::UInt64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:468
[2]
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:4125
[3] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:4390
[4] enzyme_custom_augfwd
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:4758 [inlined]
[5] (::Enzyme.Compiler.var"#199#200")(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, normalR::Ptr{…}, shadowR::Ptr{…}, tapeR::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:6650
[6] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/rbuCz/src/api.jl:141
[7] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:7726
[8] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9278
[9] codegen
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:8886 [inlined]
[10] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
[11] cached_compilation
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
[12] (::Enzyme.Compiler.var"#474#475"{…})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
[13] JuliaContext(f::Enzyme.Compiler.var"#474#475"{…})
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
[14] #s325#473
@ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
[15]
@ Enzyme.Compiler ./none:0
[16] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:600
[17] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(f3)}, ::Type{Active{…}}, ::Duplicated{Matrix{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207
[18] autodiff(::ReverseMode{…}, ::Const{…}, ::Duplicated{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:236
[19] autodiff(::ReverseMode{…}, ::typeof(f3), ::Duplicated{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
[20] top-level scope
@ REPL[1]:1
Some type information was truncated. Use `show(err)` to see complete types. |
8fc4ae3
to
d7aab0b
Compare
@wsmoses is this saying you can't have inactive things in a struct with active things? That doesn't make sense because having |
Inactive things in a struct with active things are fine. This is an error that you cannot have active things (Aka floats) in the outermost struct with duplicated things (Aka arrays), unless the struct is inside of a pointer (like mutated in a ref etc). |
But I've labelled all algs (and thus GMRES) as inactive, so why would it give me an error that it has mixed activity? |
The type that it is complaining about is |
oh yeah oops. But wait, I can't have an inactive field in the solution type? There's cases where there's |
you can have an inactive type like nothing anywhere, you cannot have an active type like a float and a duplicated type like a vector (in rev mode, in a way that crosses julia abi) |
So I am still not quite getting it. Why is this algorithm not allowed while the other ones are? I added the dispatch so that it would be treated just like the singleton ones. The LinearSolution type should be active (well duplicated) since it acts like a vector, and just ignores derivatives w.r.t. alg (which are declared always inactive). Why would changing the algorithm to the GMRES type (which can take a control vector) change anything if it's declared inactive? |
From the type signature and the definition in the error message, looks like the resid is Float64, which is active: https://github.com/SciML/SciMLBase.jl/blob/6b0a38535d530540a5b780b096d293a36100ad97/src/solutions/basic_solutions.jl#L24 |
Yes, is there a way to mark that field as inactive? |
Not at the moment |
Okay, so currently this isn't solvable without a breaking change to make |
In the alternative in the interim you could make the solution type mutable |
You could also make a special nodiff float type, mark it inactive, then store that |
d7aab0b
to
c5fe914
Compare
c5fe914
to
869f7fb
Compare
No description provided.