-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support try/catch on the happy (nothrow) path #1474
Changes from 6 commits
3ddf945
8f416aa
0836d94
07ec290
c0e5ba1
9e6e63b
e00a28c
4d52ed4
d56dd27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,11 +124,6 @@ function instrument(ir::IR) | |
ex = st.expr | ||
if isexpr(ex, :foreigncall, :isdefined) | ||
continue | ||
elseif isexpr(ex, :enter, :leave) | ||
error("""try/catch is not supported. | ||
Refer to the Zygote documentation for fixes. | ||
https://fluxml.ai/Zygote.jl/latest/limitations | ||
""") | ||
elseif isexpr(ex, :(=)) | ||
@assert ex.args[1] isa GlobalRef | ||
pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2]) | ||
|
@@ -258,7 +253,7 @@ function adjointcfg(pr::Primal) | |
end | ||
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable) | ||
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb` | ||
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` | ||
push!(rb, xcall(Base, :error, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` | ||
branch!(rb, 0) | ||
end | ||
end | ||
|
@@ -279,7 +274,7 @@ xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...)) | |
|
||
function passthrough_expr(ex::Expr) | ||
# Metadata we want to preserve | ||
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true | ||
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo, :enter, :leave, :catch) && return true | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# ccalls and more that are safe to preserve/required for proper operation: | ||
# - jl_set_task_threadpoolid: added in 1.9 for @spawn | ||
isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true | ||
|
@@ -297,9 +292,14 @@ function adjoint(pr::Primal) | |
for i = 1:length(sigs[b.id]) | ||
grad(sigs[b.id][i], arguments(rb)[i]) | ||
end | ||
|
||
has_leave = false | ||
|
||
# Backprop through statements | ||
for v in reverse(keys(b)) | ||
ex = b[v].expr | ||
has_leave |= isexpr(ex, :leave) | ||
|
||
if haskey(pr.pullbacks, v) | ||
g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), | ||
line = b[v].line)) | ||
|
@@ -321,6 +321,18 @@ function adjoint(pr::Primal) | |
continue | ||
end | ||
end | ||
|
||
# This is corresponds to a catch blocks which technically | ||
# has predecessors but they are not modelled in the IRTools CFG. | ||
# We put an error message at the beginning of said block. | ||
if has_leave && isempty(predecessors(b)) && b.id != 1 | ||
_, f_stmt = first(b) | ||
li = pr.ir.lines[f_stmt.line] | ||
li = LineNumberNode(Int(li.line), li.file) | ||
pushfirst!(rb, stmt(xcall(Base, :error, | ||
"Can't differentiate function execution in catch block at $(li)."))) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think throwing the error in the catch block is right. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try finally is implemented in term of try
foo()
origin = 1
catch
origin = 2
end
if origin == 2
rethrow()
end I will improve the error message to include try/finally as another construct which is not supported when an error is thrown in the block. |
||
|
||
if b.id > 1 # Backprop through (predecessor) branch arguments | ||
gs = grad.(arguments(b)) | ||
for br in branches(rb) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -245,3 +245,81 @@ end | |
@test_nowarn g = back(1.) | ||
@test only(g) ∈ (1., 2.) | ||
end | ||
|
||
function throws_and_catches_if_x_negative(x,y) | ||
z = x + y | ||
try | ||
if x < 0. | ||
throw(DomainError("x is negative")) | ||
end | ||
z = 2z + x + y | ||
catch err | ||
@error "something went wrong" exception=(err,catch_backtrace()) | ||
end | ||
return 3z | ||
end | ||
|
||
function try_catch_finally(cond, x) | ||
|
||
try | ||
x = 2x | ||
cond && throw(DomainError()) | ||
catch | ||
x = 2x | ||
finally | ||
x = 3x | ||
end | ||
|
||
x | ||
end | ||
|
||
if VERSION >= v"1.8" | ||
# try/catch/else is invalid syntax prior to v1.8 | ||
eval(Meta.parse(""" | ||
function try_catch_else(cond, x) | ||
end | ||
""")) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this code seems unused and incomplete |
||
|
||
@testset "try/catch" begin | ||
@testset "happy path (nothrow)" begin | ||
res, (dx,dy) = withgradient(throws_and_catches_if_x_negative, 1., 2.) | ||
@test res == 3 * (2 * (1. + 2.) + 1. + 2.) | ||
@test dx == 3. * (2. + 1.) | ||
@test dy == 3. * (2. + 1.) | ||
end | ||
|
||
@testset "try/catch/finally" begin | ||
res, (_, dx,) = withgradient(try_catch_finally, false, 1.) | ||
@test res == 6. | ||
@test dx == 6. | ||
|
||
res, pull = pullback(try_catch_finally, true, 1.) | ||
@test res == 12. | ||
@test_throws ErrorException pull(1.) | ||
err = try pull(1.) catch ex; ex end | ||
@test occursin("Can't differentiate function execution in catch block", | ||
string(err)) | ||
end | ||
|
||
function foo_try(f) | ||
y = 1 | ||
try | ||
y = f() | ||
catch | ||
y | ||
end | ||
y | ||
end | ||
|
||
g, = gradient(x -> foo_try(() -> x), 1) # 1 | ||
@test g == 1. | ||
|
||
vy, pull = pullback(foo_try, () -> 0//0) # bypass because of expr | ||
@test vy === 1 | ||
@test_throws ErrorException pull(1.) | ||
|
||
err = try pull(1.) catch ex; ex end | ||
@test occursin("Can't differentiate function execution in catch block", | ||
string(err)) | ||
end | ||
Pangoraw marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believer we still need this on versions of julia that do not have
enter
andleave
?so maybe we keep this and add a
&& VERSION<v"1.10"
(is that the right bounds?)