diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 0854c0e04..4afc74b32 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -43,6 +43,11 @@ jobs: - {user: SciML, repo: StochasticDelayDiffEq.jl, group: All} - {user: SciML, repo: SimpleNonlinearSolve.jl, group: All} - {user: SciML, repo: SimpleDiffEq.jl, group: All} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core1} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core2} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core3} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index 2aba2b87b..0d085b4e1 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index e4d1e371f..a0a9bc900 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -23,6 +23,7 @@ import TruncatedStacktraces import ADTypes: AbstractADType import ChainRulesCore import ZygoteRules: @adjoint +import FillArrays using Reexport using SciMLOperators diff --git a/src/solutions/zygote.jl b/src/solutions/zygote.jl index 08090bdbf..d41d07e0f 100644 --- a/src/solutions/zygote.jl +++ b/src/solutions/zygote.jl @@ -1,6 +1,6 @@ @adjoint function getindex(VA::ODESolution, i::Int) function ODESolution_getindex_pullback(Δ) - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] + Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] (Δ′, nothing) end