Skip to content

Commit

Permalink
add back_AD and test
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Sep 23, 2023
1 parent 7047e66 commit 5c843bf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
61 changes: 61 additions & 0 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,64 @@ function backAD(diag::Graph{F,W}, debug::Bool=false) where {F,W}
# return dual[rootid]
end


function back_AD(diag::Graph{F,W}) where {F,W}
dual = Dict{Int,Union{F,Graph{F,W}}}()
# println("rootID: ", diag.id)
for node in PreOrderDFS(diag)
visited = false
if haskey(dual, node.id)
node_dual = dual[node.id]
node_dual.name != "None" && continue
visited = true
end
# println("Node: ", node.id)

if node.operator == Sum
nodes_deriv = Graph[]
for sub_node in node.subgraphs
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
push!(nodes_deriv, dual[sub_node.id])
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(Graph[]; factor=sub_node.factor, weight=sub_node.weight, name="None")
push!(nodes_deriv, g_dual)
dual[sub_node.id] = g_dual
end
end
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = node.subgraph_factors
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors, factor=node.factor, weight=node.weight)
end
elseif node.operator == Prod
nodes_deriv = Graph[]
for (i, sub_node) in enumerate(node.subgraphs)
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
subgraphs = [j == i ? dual[subg.id] : g for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(Graph[]; factor=sub_node.factor, weight=sub_node.weight, name="None")
dual[sub_node.id] = g_dual
subgraphs = [j == i ? g_dual : subg for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))

end
end
println(nodes_deriv)
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = node.subgraph_factors
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; factor=node.factor, weight=node.weight)
end
end
end
return dual
end
19 changes: 18 additions & 1 deletion test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ end

@testset verbose = true "Auto Differentiation" begin
using FeynmanDiagram.ComputationalGraphs:
eval!, frontAD, backAD, node_derivative
eval!, frontAD, backAD, node_derivative, back_AD
g1 = propagator(𝑓⁻(1)𝑓⁺(2))
g2 = propagator(𝑓⁻(3)𝑓⁺(4))
g3 = propagator(𝑓⁻(5)𝑓⁺(6), factor=2.0)
Expand Down Expand Up @@ -286,6 +286,23 @@ end
# end
# end
end
@testset "back_AD" begin
F3 = g1 + g2
F2 = linear_combination([g1, g3, F3], [2, 1, 3])
# F1 = Graph([g1, F2, F3], operator=FeynmanDiagram.ComputationalGraphs.Prod())
F1 = Graph([g1, F2, F3], operator=Graphs.Prod())

dual = back_AD(F1)
leafmap = Dict{Int,Int}()
leafmap[g1.id], leafmap[g2.id], leafmap[g3.id] = 1, 2, 3
leafmap[dual[g1.id].id] = 4
leafmap[dual[g2.id].id] = 5
leafmap[dual[g3.id].id] = 6
leaf = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0] # d / d g1
@test eval!(dual[F1.id], leafmap, leaf) == 40.0
@test eval!(dual[F2.id], leafmap, leaf) == 5.0
@test eval!(dual[F3.id], leafmap, leaf) == 1.0
end
end

@testset verbose = true "Tree properties" begin
Expand Down

0 comments on commit 5c843bf

Please sign in to comment.