Skip to content

Commit

Permalink
Add nanmath option
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Aug 1, 2023
1 parent d0b1724 commit fcee7cf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,13 @@ function _build_function(target::JuliaTarget, op, args...;
states = LazyState(),
linenumbers = true,
wrap_code = nothing,
cse = false, kwargs...)
cse = false,
nanmath = false,
kwargs...)
if nanmath
states.rewrites[:nanmath] = true
end

dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
expr = if cse
fun = Func(dargs, [], Code.cse(unwrap(op)))
Expand Down Expand Up @@ -136,11 +142,16 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
checkbounds = false,
states = LazyState(),
linenumbers = true,
cse = false, kwargs...)
cse = false,
nanmath = false,
kwargs...)

dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))

if nanmath
states.rewrites[:nanmath] = true
end
expr = if cse
toexpr(Func(dargs, [], Code.cse(unwrap(op))), states)
else
Expand Down Expand Up @@ -262,8 +273,12 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
iip_config = (true, true),
nanmath = false,
parallel=nothing, cse = false, kwargs...)

if nanmath
states.rewrites[:nanmath] = true
end
# We cannot switch to ShardedForm because it deadlocks with
# RuntimeGeneratedFunctions
dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Expand Down
5 changes: 5 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ using ReferenceTests
using Symbolics: value
using SymbolicUtils.Code: DestructuredArgs, Func
@variables a b c1 c2 c3 d e g
oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true)
@test all(isnan, eval(oop)([-1, Inf]))
out = [0, 0.0]
eval(iip)(out, [-1, Inf])
@test all(isnan, out)

# Multiple argument matrix
h = [a + b + c1 + c2,
Expand Down

0 comments on commit fcee7cf

Please sign in to comment.