From fcee7cf66e999df392e5edf121613a75d8418fe5 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 31 Jul 2023 20:18:54 -0400 Subject: [PATCH] Add nanmath option --- src/build_function.jl | 19 +++++++++++++++++-- test/build_function.jl | 5 +++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 457078c95..ca645757c 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -108,7 +108,13 @@ function _build_function(target::JuliaTarget, op, args...; states = LazyState(), linenumbers = true, wrap_code = nothing, - cse = false, kwargs...) + cse = false, + nanmath = false, + kwargs...) + if nanmath + states.rewrites[:nanmath] = true + end + dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) expr = if cse fun = Func(dargs, [], Code.cse(unwrap(op))) @@ -136,11 +142,16 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; checkbounds = false, states = LazyState(), linenumbers = true, - cse = false, kwargs...) + cse = false, + nanmath = false, + kwargs...) dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) + if nanmath + states.rewrites[:nanmath] = true + end expr = if cse toexpr(Func(dargs, [], Code.cse(unwrap(op))), states) else @@ -262,8 +273,12 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; fillzeros = skipzeros && !(rhss isa SparseMatrixCSC), states = LazyState(), iip_config = (true, true), + nanmath = false, parallel=nothing, cse = false, kwargs...) + if nanmath + states.rewrites[:nanmath] = true + end # We cannot switch to ShardedForm because it deadlocks with # RuntimeGeneratedFunctions dargs = map((x) -> destructure_arg(x[2], !checkbounds, diff --git a/test/build_function.jl b/test/build_function.jl index aa79e4209..18e0a574b 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -3,6 +3,11 @@ using ReferenceTests using Symbolics: value using SymbolicUtils.Code: DestructuredArgs, Func @variables a b c1 c2 c3 d e g +oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true) +@test all(isnan, eval(oop)([-1, Inf])) +out = [0, 0.0] +eval(iip)(out, [-1, Inf]) +@test all(isnan, out) # Multiple argument matrix h = [a + b + c1 + c2,