Skip to content

Commit

Permalink
Fix primal evaluation (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 authored May 2, 2024
1 parent 5cacc30 commit ff266f2
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.9.1"
version = "0.9.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ end
function sumheat(heat::Heat, chkpscheme::Scheme, tsteps::Int64)
# AD: Create shadow copy for derivatives
@checkpoint_struct chkpscheme heat for i in 1:tsteps
# checkpoint_struct_for(advance, heat)
heat.Tlast .= heat.Tnext
advance(heat)
end
Expand All @@ -87,7 +86,7 @@ function heat(scheme::Scheme, tsteps::Int)
heat.Tnext[end] = 0

# Compute gradient
autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), scheme, tsteps)
autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), Const(scheme), Const(tsteps))

return heat.Tnext, dheat.Tnext[2:end-1]
end
Expand Down
4 changes: 2 additions & 2 deletions examples/optcontrol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Technique for Resilience. United States: N. p., 2016. https://www.osti.gov/biblio/1364654.

using Checkpointing
using Zygote
using Enzyme


include("optcontrolfunc.jl")
Expand Down Expand Up @@ -69,7 +69,7 @@ function muoptcontrol(scheme, steps, ::EnzymeTool)
end
return model.F[2]
end
autodiff(Enzyme.ReverseWithPrimal, foo, Duplicated(model, bmodel))
autodiff(Enzyme.Reverse, foo, Duplicated(model, bmodel))

F = model.F
L = bmodel.F
Expand Down
9 changes: 3 additions & 6 deletions src/Rules/EnzymeRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ function augmented_primal(
model,
range,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), range.val)
if needs_primal(config)
primal = func.val(body.val, alg.val, model.val, range.val)
return AugmentedReturn(primal, nothing, (model.val,))
else
return AugmentedReturn(nothing, nothing, (model.val,))
Expand Down Expand Up @@ -50,12 +50,9 @@ function augmented_primal(
model,
condition,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), condition.val)
if needs_primal(config)
return AugmentedReturn(
func.val(body.val, alg.val, model.val, condition.val),
nothing,
(model.val,),
)
return AugmentedReturn(primal, nothing, (model.val,))
else
return AugmentedReturn(nothing, nothing, (model.val,))
end
Expand Down

0 comments on commit ff266f2

Please sign in to comment.