Skip to content

Commit

Permalink
Fix arrayfuse broadcast overload
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 15, 2022
1 parent 8c5bdd5 commit d79a290
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
11 changes: 3 additions & 8 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@ end

ArrayFuse(visible::AT, hidden::AT, p) where {AT} = ArrayFuse{AT, eltype(visible), typeof(p)}(visible, hidden, p)

@inline function Base.copyto!(af::OrdinaryDiffEq.ArrayFuse{AT, T, P}, src::Base.Broadcast.Broadcasted) where {AT, T, P}
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
end

@inline function Base.copyto!(af::OrdinaryDiffEq.ArrayFuse{AT, T, P}, src::Base.Broadcast.Broadcasted{F1, Axes, F, Args}) where {AT, T, P, F1<:Base.Broadcast.AbstractArrayStyle{0}, Axes, F, Args<:Tuple}
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
@inline function Base.materialize!(af::ArrayFuse, src::Broadcast.Broadcasted)
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
end

# not recommended but good to have
Expand Down
10 changes: 5 additions & 5 deletions test/gpu/linear_lsrk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ for alg in algs
@time sol = solve(prob,alg,save_everystep=false,save_start=false,dt=0.01)

# GPU warmup
@test_broken solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
@test_broken solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
@test_broken solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
println("GPU Times for $alg")
@time @test_broken sol2 = solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)
@time sol2 = solve(prob2,alg,save_everystep=false,save_start=false,dt=0.01)

@test_broken sol[end] Array(sol2[end])
@test sol[end] Array(sol2[end])
end

0 comments on commit d79a290

Please sign in to comment.