Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
lpawela committed Feb 8, 2023
1 parent b9c9a1e commit 03da490
Show file tree
Hide file tree
Showing 8 changed files with 505 additions and 467 deletions.
141 changes: 74 additions & 67 deletions src/SpinGlassNetworks.jl
Original file line number Diff line number Diff line change
@@ -1,92 +1,99 @@
module SpinGlassNetworks
using LabelledGraphs
using LightGraphs
using MetaGraphs # TODO: remove that
using CSV
using DocStringExtensions
using LinearAlgebra
using Base.Cartesian
using LabelledGraphs
using LightGraphs
using MetaGraphs # TODO: remove that
using CSV
using DocStringExtensions
using LinearAlgebra
using Base.Cartesian

import Base.Prehashed
import Base.Prehashed

export unique_neighbors
export unique_neighbors


unique_neighbors(ig::LabelledGraph, i::Int) = filter(j -> j > i, neighbors(ig, i))
unique_neighbors(ig::LabelledGraph, i::Int) = filter(j -> j > i, neighbors(ig, i))

@generated function unique_dims(A::AbstractArray{T,N}, dim::Integer) where {T,N}
quote
1 <= dim <= $N || return copy(A)
hashes = zeros(UInt, axes(A, dim))
@generated function unique_dims(A::AbstractArray{T,N}, dim::Integer) where {T,N}
quote
1 <= dim <= $N || return copy(A)
hashes = zeros(UInt, axes(A, dim))

# Compute hash for each row
k = 0
@nloops $N i A d->(if d == dim; k = i_d; end) begin
@inbounds hashes[k] = hash(hashes[k], hash((@nref $N A i)))
# Compute hash for each row
k = 0
@nloops $N i A d -> (
if d == dim
k = i_d
end
) begin
@inbounds hashes[k] = hash(hashes[k], hash((@nref $N A i)))
end

# Collect index of first row for each hash
uniquerow = similar(Array{Int}, axes(A, dim))
firstrow = Dict{Prehashed,Int}()
for k = axes(A, dim)
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
end
uniquerows = collect(values(firstrow))
# Collect index of first row for each hash
uniquerow = similar(Array{Int}, axes(A, dim))
firstrow = Dict{Prehashed,Int}()
for k in axes(A, dim)
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
end
uniquerows = collect(values(firstrow))

# Check for collisions
collided = falses(axes(A, dim))
@inbounds begin
@nloops $N i A d->(if d == dim
# Check for collisions
collided = falses(axes(A, dim))
@inbounds begin
@nloops $N i A d -> (
if d == dim
k = i_d
j_d = uniquerow[k]
else
j_d = i_d
end) begin
if (@nref $N A j) != (@nref $N A i)
collided[k] = true
end
end
) begin
if (@nref $N A j) != (@nref $N A i)
collided[k] = true
end
end
end

if any(collided)
nowcollided = similar(BitArray, axes(A, dim))
while any(collided)
# Collect index of first row for each collided hash
empty!(firstrow)
for j = axes(A, dim)
collided[j] || continue
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
end
for v values(firstrow)
push!(uniquerows, v)
end
if any(collided)
nowcollided = similar(BitArray, axes(A, dim))
while any(collided)
# Collect index of first row for each collided hash
empty!(firstrow)
for j in axes(A, dim)
collided[j] || continue
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
end
for v values(firstrow)
push!(uniquerows, v)
end

# Check for collisions
fill!(nowcollided, false)
@nloops $N i A d->begin
if d == dim
k = i_d
j_d = uniquerow[k]
(!collided[k] || j_d == k) && continue
else
j_d = i_d
end
end begin
if (@nref $N A j) != (@nref $N A i)
nowcollided[k] = true
end
# Check for collisions
fill!(nowcollided, false)
@nloops $N i A d -> begin
if d == dim
k = i_d
j_d = uniquerow[k]
(!collided[k] || j_d == k) && continue
else
j_d = i_d
end
end begin
if (@nref $N A j) != (@nref $N A i)
nowcollided[k] = true
end
(collided, nowcollided) = (nowcollided, collided)
end
(collided, nowcollided) = (nowcollided, collided)
end

(@nref $N A d->d == dim ? sort!(uniquerows) : (axes(A, d))), indexin(uniquerow, uniquerows)
end

(@nref $N A d -> d == dim ? sort!(uniquerows) : (axes(A, d))),
indexin(uniquerow, uniquerows)
end
end

include("states.jl")
include("ising.jl")
include("spectrum.jl")
include("lattice.jl")
include("factor.jl")
include("states.jl")
include("ising.jl")
include("spectrum.jl")
include("lattice.jl")
include("factor.jl")
end # module
52 changes: 29 additions & 23 deletions src/factor.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
export factor_graph, rank_reveal, projectors, split_into_clusters, decode_factor_graph_state


function split_into_clusters(ig::LabelledGraph{S, T}, assignment_rule) where {S, T}
cluster_id_to_verts = Dict(
i => T[] for i in values(assignment_rule)
)
function split_into_clusters(ig::LabelledGraph{S,T}, assignment_rule) where {S,T}
cluster_id_to_verts = Dict(i => T[] for i in values(assignment_rule))

for v in vertices(ig)
push!(cluster_id_to_verts[assignment_rule[v]], v)
end

Dict(
i => first(cluster(ig, verts)) for (i, verts) cluster_id_to_verts
)
Dict(i => first(cluster(ig, verts)) for (i, verts) cluster_id_to_verts)
end


function factor_graph(
ig::IsingGraph,
num_states_cl::Int;
spectrum::Function=full_spectrum,
cluster_assignment_rule::Dict{Int, T} # e.g. square lattice
spectrum::Function = full_spectrum,
cluster_assignment_rule::Dict{Int,T}, # e.g. square lattice
) where {T}
ns = Dict(i => num_states_cl for i Set(values(cluster_assignment_rule)))
factor_graph(
ig,
ns,
spectrum=spectrum,
cluster_assignment_rule=cluster_assignment_rule
spectrum = spectrum,
cluster_assignment_rule = cluster_assignment_rule,
)
end

function factor_graph(
ig::IsingGraph,
num_states_cl::Dict{T, Int};
spectrum::Function=full_spectrum,
cluster_assignment_rule::Dict{Int, T} # e.g. square lattice
num_states_cl::Dict{T,Int};
spectrum::Function = full_spectrum,
cluster_assignment_rule::Dict{Int,T}, # e.g. square lattice
) where {T}
L = maximum(values(cluster_assignment_rule))
fg = LabelledGraph{MetaDiGraph}(sort(unique(values(cluster_assignment_rule))))

for (v, cl) split_into_clusters(ig, cluster_assignment_rule)
sp = spectrum(cl, num_states=get(num_states_cl, v, basis_size(cl)))
sp = spectrum(cl, num_states = get(num_states_cl, v, basis_size(cl)))
set_props!(fg, v, Dict(:cluster => cl, :spectrum => sp))
end

Expand All @@ -52,13 +48,18 @@ function factor_graph(

if !isempty(outer_edges)
en = inter_cluster_energy(
get_prop(fg, v, :spectrum).states, J, get_prop(fg, w, :spectrum).states
get_prop(fg, v, :spectrum).states,
J,
get_prop(fg, w, :spectrum).states,
)
pl, en = rank_reveal(en, :PE)
en, pr = rank_reveal(en, :EP)
add_edge!(fg, v, w)
set_props!(
fg, v, w, Dict(:outer_edges => outer_edges, :pl => pl, :en => en, :pr => pr)
fg,
v,
w,
Dict(:outer_edges => outer_edges, :pl => pl, :en => en, :pr => pr),
)
end
end
Expand All @@ -67,13 +68,18 @@ end

function factor_graph(
ig::IsingGraph;
spectrum::Function=full_spectrum,
cluster_assignment_rule::Dict{Int, T}
spectrum::Function = full_spectrum,
cluster_assignment_rule::Dict{Int,T},
) where {T}
factor_graph(ig, Dict{T, Int}(), spectrum=spectrum, cluster_assignment_rule=cluster_assignment_rule)
factor_graph(
ig,
Dict{T,Int}(),
spectrum = spectrum,
cluster_assignment_rule = cluster_assignment_rule,
)
end

function rank_reveal(energy, order=:PE)
function rank_reveal(energy, order = :PE)
@assert order (:PE, :EP)
dim = order == :PE ? 1 : 2

Expand All @@ -85,7 +91,7 @@ function rank_reveal(energy, order=:PE)
P = zeros(size(E, 2), size(energy, 2))
end

for (i, elements) enumerate(eachslice(P, dims=dim))
for (i, elements) enumerate(eachslice(P, dims = dim))
elements[idx[i]] = 1
end

Expand All @@ -94,7 +100,7 @@ end


function decode_factor_graph_state(fg, state::Vector{Int})
ret = Dict{Int, Int}()
ret = Dict{Int,Int}()
for (i, vert) zip(state, vertices(fg))
spins = get_prop(fg, vert, :cluster).labels
states = get_prop(fg, vert, :spectrum).states
Expand Down
46 changes: 21 additions & 25 deletions src/ising.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using LabelledGraphs

export ising_graph, rank_vec, cluster, rank, nodes, basis_size, biases, couplings, IsingGraph, prune
export ising_graph,
rank_vec, cluster, rank, nodes, basis_size, biases, couplings, IsingGraph, prune

const Instance = Union{String, Dict}
const Instance = Union{String,Dict}


unique_nodes(ising_tuples) = sort(collect(Set(Iterators.flatten((i, j) for (i, j, _) ising_tuples))))
unique_nodes(ising_tuples) =
sort(collect(Set(Iterators.flatten((i, j) for (i, j, _) ising_tuples))))

const IsingGraph = LabelledGraph{MetaGraph{Int64, Float64}, Int64}
const IsingGraph = LabelledGraph{MetaGraph{Int64,Float64},Int64}

"""
$(TYPEDSIGNATURES)
Expand All @@ -20,14 +22,14 @@ Store extra information
"""
function ising_graph(
instance::Instance;
sgn::Number=1.0,
rank_override::Dict{Int, Int}=Dict{Int, Int}()
sgn::Number = 1.0,
rank_override::Dict{Int,Int} = Dict{Int,Int}(),
)
# load the Ising instance
if instance isa String
ising = CSV.File(instance, types = [Int, Int, Float64], header=0, comment = "#")
ising = CSV.File(instance, types = [Int, Int, Float64], header = 0, comment = "#")
else
ising = [ (i, j, J) for ((i, j), J) instance ]
ising = [(i, j, J) for ((i, j), J) instance]
end

nodes = unique_nodes(ising)
Expand All @@ -51,13 +53,7 @@ function ising_graph(
end


set_prop!(
ig,
:rank,
Dict{Int, Int}(
v => get(rank_override, v, 2) for v in vertices(ig)
)
)
set_prop!(ig, :rank, Dict{Int,Int}(v => get(rank_override, v, 2) for v in vertices(ig)))

ig
end
Expand All @@ -78,13 +74,8 @@ end
cluster(ig::IsingGraph, verts) = induced_subgraph(ig, collect(verts))

function inter_cluster_edges(ig::IsingGraph, cl1::IsingGraph, cl2::IsingGraph)
verts1, verts2 = vertices(cl1), vertices(cl2)

outer_edges = [
LabelledEdge(i, j)
for i vertices(cl1), j vertices(cl2)
if has_edge(ig, i, j)
]
outer_edges =
[LabelledEdge(i, j) for i vertices(cl1), j vertices(cl2) if has_edge(ig, i, j)]

J = zeros(nv(cl1), nv(cl2))
for e outer_edges
Expand All @@ -94,14 +85,19 @@ function inter_cluster_edges(ig::IsingGraph, cl1::IsingGraph, cl2::IsingGraph)
outer_edges, J
end

function prune(ig::IsingGraph)
function prune(ig::IsingGraph)
to_keep = vcat(
findall(!iszero, degree(ig)),
findall(x->iszero(degree(ig, x)) && !isapprox(get_prop(ig, x, :h), 0, atol=1e-14), vertices(ig))
findall(
x ->
iszero(degree(ig, x)) &&
!isapprox(get_prop(ig, x, :h), 0, atol = 1e-14),
vertices(ig),
),
)

gg = ig[ig.labels[to_keep]]
labels = collect(vertices(gg.inner_graph))
reverse_label_map = Dict(i => i for i=1:nv(gg.inner_graph))
reverse_label_map = Dict(i => i for i = 1:nv(gg.inner_graph))
LabelledGraph(labels, gg.inner_graph, reverse_label_map)
end
Loading

0 comments on commit 03da490

Please sign in to comment.