Skip to content

Commit

Permalink
Merge pull request #941 from JuliaSymbolics/myb/nanmath
Browse files Browse the repository at this point in the history
Add nanmath option
  • Loading branch information
ChrisRackauckas authored Mar 12, 2024
2 parents f39e633 + f601457 commit d501299
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ function _build_function(target::JuliaTarget, op, args...;
states = LazyState(),
linenumbers = true,
wrap_code = nothing,
cse = false, kwargs...)
cse = false,
nanmath = true,
kwargs...)

states.rewrites[:nanmath] = nanmath
dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
expr = if cse
fun = Func(dargs, [], Code.cse(unwrap(op)))
Expand Down Expand Up @@ -136,11 +140,14 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
checkbounds = false,
states = LazyState(),
linenumbers = true,
cse = false, kwargs...)
cse = false,
nanmath = true,
kwargs...)

dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))

states.rewrites[:nanmath] = nanmath
expr = if cse
conv(Func(dargs, [], Code.cse(unwrap(op))), states)
else
Expand Down Expand Up @@ -280,8 +287,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
iip_config = (true, true),
nanmath = true,
parallel=nothing, cse = false, kwargs...)

states.rewrites[:nanmath] = nanmath
# We cannot switch to ShardedForm because it deadlocks with
# RuntimeGeneratedFunctions
dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Expand Down
5 changes: 5 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ using ReferenceTests
using Symbolics: value
using SymbolicUtils.Code: DestructuredArgs, Func
@variables a b c1 c2 c3 d e g
oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true)
@test all(isnan, eval(oop)([-1, Inf]))
out = [0, 0.0]
eval(iip)(out, [-1, Inf])
@test all(isnan, out)

# Multiple argument matrix
h = [a + b + c1 + c2,
Expand Down

0 comments on commit d501299

Please sign in to comment.