Skip to content

Commit

Permalink
remove factor
Browse files Browse the repository at this point in the history
  • Loading branch information
houpc committed Jan 9, 2024
1 parent 4dc4408 commit ee79c8f
Show file tree
Hide file tree
Showing 15 changed files with 220 additions and 220 deletions.
2 changes: 1 addition & 1 deletion src/FeynmanDiagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti
# export reducibility, connectivity
# export 𝐺ᶠ, 𝐺ᵇ, 𝐺ᵠ, 𝑊, Green2, Interaction
# export Coupling_yukawa, Coupling_phi3, Coupling_phi4, Coupling_phi6
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest
export haschildren, onechild, isleaf, isbranch, ischain, has_zero_subfactors, eldest
export relabel!, standardize_labels!, replace_subgraph!, merge_linear_combination!, merge_multi_product!, merge_chains!
export relabel, standardize_labels, replace_subgraph, merge_linear_combination, merge_multi_product, merge_chains
export open_parenthesis, open_parenthesis!, flatten_prod!, flatten_prod, flatten_sum!, flatten_sum
Expand Down
2 changes: 1 addition & 1 deletion src/computational_graph/ComputationalGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export multi_product, linear_combination, feynman_diagram, propagator, interacti


include("tree_properties.jl")
export haschildren, onechild, isleaf, isbranch, ischain, isfactorless, has_zero_subfactors, eldest, count_operation
export haschildren, onechild, isleaf, isbranch, ischain, has_zero_subfactors, eldest, count_operation

include("operation.jl")
include("io.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/computational_graph/abstractgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,11 @@ function isequiv(a::AbstractGraph, b::AbstractGraph, args...)
# elseif field == :subgraph_factors && getproperty(a, :subgraph_factors) == subgraph_factors(a) && getproperty(b, :subgraph_factors) == subgraph_factors(b)
# continue # skip subgraph_factors if already accounted for
else
getproperty(a, field) != getproperty(b, field) && return false
# getproperty(a, field) != getproperty(b, field) && return false
if getproperty(a, field) != getproperty(b, field)
# println(field)
return false
end
end
end
return true
Expand Down
2 changes: 1 addition & 1 deletion src/computational_graph/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
- `g` computational graph
"""
function Base.convert(::Type{G}, g::FeynmanGraph{F,W}) where {F,W,G<:Graph}
return Graph(g.subgraphs; subgraph_factors=g.subgraph_factors, name=g.name, operator=g.operator(), orders=g.orders, ftype=F, wtype=W, factor=g.factor, weight=g.weight)
return Graph(g.subgraphs; subgraph_factors=g.subgraph_factors, name=g.name, operator=g.operator(), orders=g.orders, ftype=F, wtype=W, weight=g.weight)
end

function Base.convert(::Type{FeynmanGraph}, g::Graph{F,W}) where {F,W}
Expand Down
36 changes: 22 additions & 14 deletions src/computational_graph/eval.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
@inline apply(::Type{Sum}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = sum(d.weight * d.factor * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Prod}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = prod(d.weight * d.factor * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Power{N}}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {N,F<:Number,W<:Number} = (diags[1].weight * diags[1].factor)^N * factors[1]
@inline apply(::Type{Sum}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = sum(d.weight * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Prod}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = prod(d.weight * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Power{N}}, diags::Vector{Graph{F,W}}, factors::Vector{F}) where {N,F<:Number,W<:Number} = (diags[1].weight)^N * factors[1]
@inline apply(o::Sum, diag::Graph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Prod, diag::Graph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Power{N}, diag::Graph{F,W}) where {N,F<:Number,W<:Number} = diag.weight

@inline apply(::Type{Sum}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = sum(d.weight * d.factor * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Prod}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = prod(d.weight * d.factor * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Power{N}}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {N,F<:Number,W<:Number} = (diags[1].weight * diags[1].factor)^N * factors[1]
@inline apply(::Type{Sum}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = sum(d.weight * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Prod}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {F<:Number,W<:Number} = prod(d.weight * f for (d, f) in zip(diags, factors))
@inline apply(::Type{Power{N}}, diags::Vector{FeynmanGraph{F,W}}, factors::Vector{F}) where {N,F<:Number,W<:Number} = (diags[1].weight)^N * factors[1]
@inline apply(o::Sum, diag::FeynmanGraph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Prod, diag::FeynmanGraph{F,W}) where {F<:Number,W<:Number} = diag.weight
@inline apply(o::Power{N}, diag::FeynmanGraph{F,W}) where {N,F<:Number,W<:Number} = diag.weight
Expand All @@ -33,26 +33,34 @@ function eval!(g::Graph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vect
else
node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors)
end
result = node.weight * node.factor
result = node.weight
end
return result
end


function eval!(g::FeynmanGraph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}()) where {F,W}
function eval!(g::FeynmanGraph{F,W}, leafmap::Dict{Int,Int}=Dict{Int,Int}(), leaf::Vector{W}=Vector{W}(); inherit=false, randseed::Int=-1) where {F,W}
result = nothing

if randseed > 0
Random.seed!(randseed)
end
for node in PostOrderDFS(g)
if isleaf(node)
if isempty(leafmap)
node.weight = 1.0
else
node.weight = leaf[leafmap[node.id]]
if !inherit
if isempty(leafmap)
if randseed < 0
node.weight = 1.0
else
node.weight = rand()
end
else
node.weight = leaf[leafmap[node.id]]
end
end
else
node.weight = apply(node.operator, node.subgraphs, node.subgraph_factors)
end
result = node.weight * node.factor
result = node.weight
end
return result
end
Expand Down
22 changes: 18 additions & 4 deletions src/computational_graph/feynmangraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ Base.:(==)(a::FeynmanProperties, b::FeynmanProperties) = Base.isequal(a, b)
Returns a copy of the given FeynmanProperties `p` modified to have no topology.
"""
drop_topology(p::FeynmanProperties) = FeynmanProperties(p.diagtype, p.vertices, [], p.external_indices, p.external_legs)

"""
mutable struct FeynmanGraph{F<:Number,W}
Expand Down Expand Up @@ -120,7 +119,14 @@ mutable struct FeynmanGraph{F<:Number,W} <: AbstractGraph # FeynmanGraph
vertices = [external_operators(g) for g in subgraphs if diagram_type(g) != Propagator]
end
properties = FeynmanProperties(typeof(diagtype), vertices, topology, external_indices, external_legs)
return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
# return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
g = new{ftype,wtype}(uid(), String(name), orders, properties, subgraphs, subgraph_factors, typeof(operator), one(ftype), weight)

if factor one(ftype)
return g
else
return new{ftype,wtype}(uid(), String(name), orders, properties, [g,], [factor,], Prod, one(ftype), weight * factor)
end
end

"""
Expand Down Expand Up @@ -150,7 +156,14 @@ mutable struct FeynmanGraph{F<:Number,W} <: AbstractGraph # FeynmanGraph
@assert length(subgraphs) == 1 "FeynmanGraph with Power operator must have one and only one subgraph."
end
# @assert allunique(subgraphs) "all subgraphs must be distinct."
return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
# return new{ftype,wtype}(uid(), name, orders, properties, subgraphs, subgraph_factors, typeof(operator), factor, weight)
g = new{ftype,wtype}(uid(), String(name), orders, properties, subgraphs, subgraph_factors, typeof(operator), one(ftype), weight)

if factor one(ftype)
return g
else
return new{ftype,wtype}(uid(), String(name), orders, properties, [g,], [factor,], Prod, one(ftype), weight * factor)
end
end

"""
Expand All @@ -165,7 +178,8 @@ mutable struct FeynmanGraph{F<:Number,W} <: AbstractGraph # FeynmanGraph
function FeynmanGraph(g::Graph{F,W}, properties::FeynmanProperties) where {F,W}
@assert length(properties.external_indices) == length(properties.external_legs)
# @assert allunique(subgraphs) "all subgraphs must be distinct."
return new{F,W}(uid(), g.name, g.orders, properties, g.subgraphs, g.subgraph_factors, g.operator, g.factor, g.weight)
# return new{F,W}(uid(), g.name, g.orders, properties, g.subgraphs, g.subgraph_factors, g.operator, g.factor, g.weight)
return new{F,W}(uid(), g.name, g.orders, properties, [FeynmanGraph(subg, subg.properties) for subg in g.subgraphs], g.subgraph_factors, g.operator, g.factor, g.weight)
end
end

Expand Down
9 changes: 7 additions & 2 deletions src/computational_graph/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mutable struct Graph{F<:Number,W} <: AbstractGraph # Graph
# @assert allunique(subgraphs) "all subgraphs must be distinct."
g = new{ftype,wtype}(uid(), String(name), orders, subgraphs, subgraph_factors, typeof(operator), one(ftype), weight, properties)

if (factor one(ftype))
if factor one(ftype)
return g
else
return new{ftype,wtype}(uid(), String(name), orders, [g,], [factor,], Prod, one(ftype), weight * factor, properties)
Expand Down Expand Up @@ -117,7 +117,12 @@ set_subgraph_factors!(g::Graph{F,W}, subgraph_factors::AbstractVector, indices::
- `f`: constant factor
"""
function constant_graph(factor=one(_dtype.factor))
return Graph([]; operator=Constant(), factor=factor, ftype=_dtype.factor, wtype=_dtype.weight, weight=one(_dtype.weight))
g = Graph([]; operator=Constant(), ftype=_dtype.factor, wtype=_dtype.weight, weight=one(_dtype.weight))
if factor one(_dtype.factor)
return g
else
return g * factor
end
end

"""
Expand Down
9 changes: 6 additions & 3 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1,
dual[key_node].subgraph_factors = node.subgraph_factors
dual[key_node].name = node.name
else
dual[key_node] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors, factor=node.factor)
# dual[key_node] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors, factor=node.factor)
dual[key_node] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors)
end
elseif node.operator == Prod
nodes_deriv = G[]
Expand All @@ -416,7 +417,8 @@ function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1,
dual[key_node].subgraph_factors = one.(eachindex(nodes_deriv))
dual[key_node].name = node.name
else
dual[key_node] = Graph(nodes_deriv; factor=node.factor)
# dual[key_node] = Graph(nodes_deriv; factor=node.factor)
dual[key_node] = Graph(nodes_deriv)
end
elseif node.operator <: Power # node with Power operator has only one subgraph!
nodes_deriv = G[]
Expand All @@ -437,7 +439,8 @@ function forwardAD_root!(graphs::AbstractVector{G}, idx::Int=1,
dual[key_node].name = node.name
dual.operator = Prod
else
dual[key_node] = Graph(nodes_deriv; subgraph_factors=[1, node.subgraph_factors[1]], operator=Prod(), factor=node.factor)
# dual[key_node] = Graph(nodes_deriv; subgraph_factors=[1, node.subgraph_factors[1]], operator=Prod(), factor=node.factor)
dual[key_node] = Graph(nodes_deriv; subgraph_factors=[1, node.subgraph_factors[1]], operator=Prod())
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/computational_graph/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ function burn_from_targetleaves!(graphs::AbstractVector{G}, targetleaves_id::Abs
verbose > 0 && println("remove all nodes connected to the target leaves via Prod operators.")

graphs_sum = linear_combination(graphs, one.(eachindex(graphs)))
ftype = typeof(factor(graphs[1]))
ftype = eltype(subgraph_factors(graphs[1]))

for leaf in Leaves(graphs_sum)
if !isdisjoint(id(leaf), targetleaves_id)
Expand Down
3 changes: 2 additions & 1 deletion src/computational_graph/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,11 @@ function merge_multi_product!(g::Graph{F,W}) where {F,W}
end
end

if length(unique_factors) == 1
if length(unique_factors) == 1 && repeated_counts[1] > 1
g.subgraphs = unique_graphs
g.subgraph_factors = unique_factors
g.operator = typeof(Power(repeated_counts[1]))
# g.operator = repeated_counts[1] == 1 ? Prod : typeof(Power(repeated_counts[1]))
else
_subgraphs = Vector{Graph{F,W}}()
for (idx, g) in enumerate(unique_graphs)
Expand Down
34 changes: 17 additions & 17 deletions src/computational_graph/tree_properties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,23 @@ function ischain(g::AbstractGraph)
return false
end

"""
function isfactorless(g)
Returns whether the graph g is factorless, i.e., has unity factor and, if applicable,
subgraph factor(s). Note that this function does not recurse through subgraphs of g, so
that one may have, e.g., `isfactorless(g) == true` but `isfactorless(eldest(g)) == false`.
# Arguments:
- `g::AbstractGraph`: graph to be analyzed
"""
function isfactorless(g::AbstractGraph)
if isleaf(g)
return isapprox_one(factor(g))
else
return all(isapprox_one.([factor(g); subgraph_factors(g)]))
end
end
# """
# function isfactorless(g)

# Returns whether the graph g is factorless, i.e., has unity factor and, if applicable,
# subgraph factor(s). Note that this function does not recurse through subgraphs of g, so
# that one may have, e.g., `isfactorless(g) == true` but `isfactorless(eldest(g)) == false`.

# # Arguments:
# - `g::AbstractGraph`: graph to be analyzed
# """
# function isfactorless(g::AbstractGraph)
# if isleaf(g)
# return isapprox_one(factor(g))
# else
# return all(isapprox_one.([factor(g); subgraph_factors(g)]))
# end
# end

"""
function has_zero_subfactors(g)
Expand Down
20 changes: 11 additions & 9 deletions src/frontend/diagtree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,18 @@ function Graph!(d::DiagTree.Diagram{W}) where {W}
push!(subgraphs, res)
end

if isempty(subgraphs)
root = ComputationalGraphs.Graph(subgraphs; subgraph_factors=ones(W, length(subgraphs)), factor=d.factor, name=String(d.name),
operator=op(d.operator), orders=d.id.order, ftype=W, wtype=W, weight=d.weight, properties=d.id)
else
tree = ComputationalGraphs.Graph(subgraphs; subgraph_factors=ones(W, length(subgraphs)),
operator=op(d.operator), orders=d.id.order, ftype=W, wtype=W, weight=d.weight)
root = ComputationalGraphs.Graph([tree,]; subgraph_factors=[d.factor,], orders=tree.orders,
ftype=W, wtype=W, weight=d.weight * d.factor)
end
# if isempty(subgraphs)
# root = ComputationalGraphs.Graph(subgraphs; subgraph_factors=ones(W, length(subgraphs)), factor=d.factor, name=String(d.name),
# operator=op(d.operator), orders=d.id.order, ftype=W, wtype=W, weight=d.weight, properties=d.id)
# else
# tree = ComputationalGraphs.Graph(subgraphs; subgraph_factors=ones(W, length(subgraphs)),
# operator=op(d.operator), orders=d.id.order, ftype=W, wtype=W, weight=d.weight)
# root = ComputationalGraphs.Graph([tree,]; subgraph_factors=[d.factor,], orders=tree.orders,
# ftype=W, wtype=W, weight=d.weight * d.factor)
# end

root = ComputationalGraphs.Graph(subgraphs; subgraph_factors=ones(W, length(subgraphs)), factor=d.factor, name=String(d.name),
operator=op(d.operator), orders=d.id.order, ftype=W, wtype=W, weight=d.weight, properties=d.id)
return root
# @assert haskey(map, root.id) == false "DiagramId already exists in map: $(root.id)"
# @assert haskey(map, tree.id) == false "DiagramId already exists in map: $(tree.id)"
Expand Down
Loading

0 comments on commit ee79c8f

Please sign in to comment.