From efcc64bbbe7534f71998d2371c4a89b0e95bbcdb Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Tue, 10 Oct 2023 19:39:23 -0700 Subject: [PATCH] Handle `Expr(:boundscheck)` --- src/compiler/reverse.jl | 4 ++++ test/compiler.jl | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 333323e83..c72e1a5a6 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -132,6 +132,10 @@ function instrument(ir::IR) elseif isexpr(ex, :(=)) @assert ex.args[1] isa GlobalRef pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2]) + elseif isexpr(ex, :boundscheck) + # Expr(:boundscheck) now appears in common Julia code paths, so we need to handle it. + # For correctness sake, fix to true like https://github.com/dfdx/Umlaut.jl/issues/34. + pr[v] = true else ex = instrument_new!(pr, v, ex) ex = instrument_literals!(pr, v, ex) diff --git a/test/compiler.jl b/test/compiler.jl index c9b091f78..381c4bc14 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -225,3 +225,12 @@ end # issue 897 @test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] ≈ fill(0.5773502691896258, 3, 400) + +# Tests adapted from https://github.com/dfdx/Umlaut.jl/pull/35 +@eval _boundscheck_foo(x) = ifelse($(Expr(:boundscheck)), 2x, x) + +@testset "Meta Expr handling" begin + y, (dx,) = withgradient(_boundscheck_foo, 1) + @test y == 2 + @test dx == 2 +end