Skip to content

Commit

Permalink
Merge pull request #1038 from SciML/ChrisRackauckas-patch-6
Browse files Browse the repository at this point in the history
Make sure to extend
  • Loading branch information
ChrisRackauckas authored May 20, 2024
2 parents 55657af + 110dcce commit 48d12e4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ext/DiffEqBaseCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module DiffEqBaseCUDAExt

using DiffEqBase, CUDA

function ODE_DEFAULT_NORM(u::CuArray{T},t) where {T <: Union{AbstractFloat, Complex}}
sqrt(sum(DiffEqBase.sse, u; init = DiffEqBase.sse(zero(T))) / totallength(u))
function DiffEqBase.ODE_DEFAULT_NORM(u::CuArray{T},t) where {T <: Union{AbstractFloat, Complex}}
sqrt(sum(DiffEqBase.sse, u; init = DiffEqBase.sse(zero(T))) / DiffEqBase.totallength(u))
end

end
2 changes: 1 addition & 1 deletion src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x))

@inline fastpow(x::ForwardDiff.Dual, y::ForwardDiff.Dual) = x^y

sse(x::Number) = x^2
sse(x::Number) = abs2(x)
sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
totallength(x::Number) = 1
function totallength(x::ForwardDiff.Dual)
Expand Down

0 comments on commit 48d12e4

Please sign in to comment.