Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tao ad #130

Merged
merged 4 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ DataFrames = "1.6"
Lehmann = "0.2"
Parameters = "0.12"
PyCall = "1"
StaticArrays = "1"
RuntimeGeneratedFunctions = "0.5"
SnoopPrecompile = "1, 2"
StaticArrays = "1"
julia = "1.6"

[extras]
Expand Down
8 changes: 5 additions & 3 deletions src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ export external_labels
# export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction

include("tree_properties.jl")
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, eldest, totaloperation

# include("operation.jl")
include("operation.jl")
export derivative
include("graphvector.jl")
include("io.jl")
# plot_tree
Expand All @@ -42,5 +43,6 @@ export prune_trivial_unary, merge_prefactors

include("optimize.jl")
export optimize!

include("eval.jl")
export evalGraph!
end
108 changes: 19 additions & 89 deletions src/computational_graph/eval.jl
Original file line number Diff line number Diff line change
@@ -1,95 +1,25 @@
@inline apply(o::Sum, diags::Vector{Diagram{W}}) where {W<:Number} = sum(d.weight for d in diags)
@inline apply(o::Prod, diags::Vector{Diagram{W}}) where {W<:Number} = prod(d.weight for d in diags)
@inline apply(o::Sum, diag::Diagram{W}) where {W<:Number} = diag.weight
@inline apply(o::Prod, diag::Diagram{W}) where {W<:Number} = diag.weight

@inline eval(d::DiagramId) = error("eval for $d has not yet implemented!")

######################### evaluator for KT representation #########################
function evalDiagNodeKT!(diag::Diagram, varK, varT, additional=nothing; eval=DiagTree.eval)
if length(diag.subdiagram) == 0
# if hasproperty(diag.id, :extK)
if (isnothing(varK) == false) && (isnothing(varT) == false)
K = varK * diag.id.extK
if isnothing(additional)
diag.weight = eval(diag.id, K, diag.id.extT, varT) * diag.factor
else
diag.weight = eval(diag.id, K, diag.id.extT, varT, additional) * diag.factor
end
elseif isnothing(varK)
if isnothing(additional)
diag.weight = eval(diag.id, diag.id.extT, varT) * diag.factor
else
diag.weight = eval(diag.id, diag.id.extT, varT, additional) * diag.factor
end
elseif isnothing(varT)
K = varK * diag.id.extK
if isnothing(additional)
diag.weight = eval(diag.id, K) * diag.factor
else
diag.weight = eval(diag.id, K, additional) * diag.factor
end
#@inline add( diags::Vector{Graph{F,W}}) where {F<:Number,W<:Number} = sum(d.weight for d in diags)
@inline apply(::Type{Sum}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = sum(d.weight*d.factor*f for (d,f) in zip(diags,factors) )
@inline apply(::Type{Prod}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = prod(d.weight*d.factor*f for (d,f) in zip(diags,factors) )
@inline apply(o::Sum, diag::Graph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Prod, diag::Graph{F,W}) where {F<:Number,W<:Number} = diag.weight

# #@inline eval(d::DiagramId) = error("eval for $d has not yet implemented!")
# #
function evalGraph!(g::Graph)
result = nothing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename evalGraph! -> eval!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

passing an additional vector "leaf" which contains the leaf weight.

eval! should have the same behavior as the compiled function

for node in PostOrderDFS(g)
if isleaf(node)
node.weight = 1.0 ##Currently set to 1 just for test. In the future, probably the whole function is not needed since we have the compiler.
else
if isnothing(additional)
diag.weight = eval(diag.id) * diag.factor
else
diag.weight = eval(diag.id, additional) * diag.factor
end
node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors)
#node.weight = add(node.subgraphs) * node.factor
end
else
diag.weight = apply(diag.operator, diag.subdiagram) * diag.factor
end
return diag.weight
end

function evalKT!(diag::Diagram, varK, varT; eval=DiagTree.eval)
for d in PostOrderDFS(diag)
evalDiagNodeKT!(d, varK, varT; eval=eval)
end
return diag.weight
end

function evalKT!(diags::Vector{Diagram{W}}, varK, varT; eval=DiagTree.eval) where {W}
for d in diags
evalKT!(d, varK, varT; eval=eval)
end
# return W[d.weight for d in diags]
end

function evalKT!(df::DataFrame, varK, varT; eval=DiagTree.eval) where {W}
for d in df[!, :diagram]
evalKT!(d, varK, varT; eval=eval)
end
# return W[d.weight for d in df[!, :Diagram]]
end

######################### generic evaluator #########################
function evalDiagNode!(diag::Diagram, vargs...; eval=DiagTree.eval)
if length(diag.subdiagram) == 0
diag.weight = eval(diag.id, vargs...) * diag.factor
else
diag.weight = apply(diag.operator, diag.subdiagram) * diag.factor
end
return diag.weight
end

function eval!(diag::Diagram, vargs...; eval=DiagTree.eval)
for d in PostOrderDFS(diag)
evalDiagNode!(d, vargs...; eval=eval)
result = node.weight * node.factor
end
return diag.weight
return result
end

function eval!(diags::Vector{Diagram{W}}, vargs...; eval=DiagTree.eval) where {W}
for d in diags
eval!(d, vargs...; eval=eval)
end
# return W[d.weight for d in diags]
function evalGraph!(g::Number)
return g
end

function eval!(df::DataFrame, vargs...; eval=DiagTree.eval) where {W}
for d in df[!, :diagram]
eval!(d, vargs...; eval=eval)
end
# return W[d.weight for d in df[!, :Diagram]]
end
60 changes: 54 additions & 6 deletions src/computational_graph/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ mutable struct Graph{F,W} # Graph
end
end

const Unity = Graph([];ftype = Float64, wtype = Float64, weight = 1.0)


function Base.isequal(a::Graph, b::Graph)
typeof(a) != typeof(b) && return false
for field in fieldnames(typeof(a))
Expand Down Expand Up @@ -213,21 +216,47 @@ function Base.:*(c1::C, g2::Graph{F,W}) where {F,W,C}
subgraph_factors=[F(c1),], type=g2.type(), operator=Prod(), ftype=F, wtype=W)
# Merge multiplicative link
if g2.operator == Prod && onechild(g2)
##when prune single child nodes, why the subgraph_factors are merged, but factor is not merged ?
g.subgraph_factors[1] *= g2.subgraph_factors[1]
g.subgraphs = g2.subgraphs
end
return g
end

function Base.:*(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W}
# Currently Prod of two green's function ignore topology
if g1.operator == Prod && onechild(g1)
g1_sub = g1.subgraphs[1]
subfactor1 = g1.subgraph_factors[1] * g1.factor
else
g1_sub = g1
subfactor1 = F(1.0)
end

if g2.operator == Prod && onechild(g2)
g2_sub = g2.subgraphs[1]
subfactor2 = g2.subgraph_factors[1]*g2.factor
else
g2_sub = g2
subfactor2 = F(1.0)
end

g = Graph([g1_sub,g2_sub];
subgraph_factors=[F(subfactor1),F(subfactor2)] , type=g2_sub.type(), operator=Prod(), ftype=F, wtype=W)
# Merge multiplicative link

return g
end

"""
function linear_combination(g1::Graph{F,W}, g2::Graph{F,W}, c1::C, c2::C) where {F,W,C}

Returns a graph representing the linear combination `c1*g1 + c2*g2`.
"""
function linear_combination(g1::Graph{F,W}, g2::Graph{F,W}, c1::C, c2::C) where {F,W,C}
@assert g1.type == g2.type "g1 and g2 are not of the same type."
@assert g1.orders == g2.orders "g1 and g2 have different orders."
@assert Set(external(g1)) == Set(external(g2)) "g1 and g2 have different external vertices."
#@assert g1.type == g2.type "g1 and g2 are not of the same type."
#@assert g1.orders == g2.orders "g1 and g2 have different orders."
#@assert Set(external(g1)) == Set(external(g2)) "g1 and g2 have different external vertices."
total_vertices = union(g1.vertices, g2.vertices)
return Graph([g1, g2]; vertices=total_vertices, external=g1.external, hasLeg=g1.hasLeg,
subgraph_factors=[F(c1), F(c2)], type=g1.type(), operator=Sum(), ftype=F, wtype=W)
Expand All @@ -241,23 +270,42 @@ end
graph representing the linear combination (𝐜 ⋅ 𝐠).
"""
function linear_combination(graphs::Vector{Graph{F,W}}, constants::Vector{C}) where {F,W,C}
@assert alleq(getproperty.(graphs, :type)) "Graphs are not all of the same type."
@assert alleq(getproperty.(graphs, :orders)) "Graphs do not all have the same order."
@assert alleq(Set.(external.(graphs))) "Graphs do not share the same set of external vertices."
#@assert alleq(getproperty.(graphs, :type)) "Graphs are not all of the same type."
#@assert alleq(getproperty.(graphs, :orders)) "Graphs do not all have the same order."
#@assert alleq(Set.(external.(graphs))) "Graphs do not share the same set of external vertices."
total_vertices = union(Iterators.flatten(vertices.(graphs)))
g1 = graphs[1]
return Graph(graphs; vertices=total_vertices, external=g1.external, hasLeg=g1.hasLeg,
subgraph_factors=constants, type=g1.type(), operator=Sum(), ftype=F, wtype=W)
end




function Base.:+(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W}
return linear_combination(g1, g2, F(1), F(1))
end


function Base.:-(g1::Graph{F,W}, g2::Graph{F,W}) where {F,W}
return linear_combination(g1, g2, F(1), F(-1))
end

# function Base.:+(c::C, g1::Graph{F,W}) where {F,W,C}
# return linear_combination(g1, Unity, F(1), F(c))
# end
# function Base.:+(g1::Graph{F,W},c::C) where {F,W,C}
# return linear_combination(g1, Unity, F(1), F(c))
# end

# function Base.:-(c::C, g1::Graph{F,W}) where {F,W,C}
# return linear_combination(Unity, g1, F(c), F(-1))
# end
# function Base.:-(g1::Graph{F,W},c::C) where {F,W,C}
# return linear_combination(g1, Unity, F(1), F(-c))
# end


"""
function feynman_diagram(subgraphs::Vector{Graph{F,W}}, topology::Vector{Vector{Int}}, perm_noleg::Union{Vector{Int},Nothing}=nothing;
factor=one(_dtype.factor), weight=zero(_dtype.weight), name="", diagtype::GraphType=GenericDiag()) where {F,W}
Expand Down
Loading
Loading