From 6c86635be27d3d2c95a5d0a448ae005cc4a0225a Mon Sep 17 00:00:00 2001
From: Lionel Zoubritzky <lionel.zoubritzky@gmail.com>
Date: Mon, 29 Jan 2024 19:14:04 +0100
Subject: [PATCH] Extend parse_cgd support to general .cgd files

---
 src/input.jl        | 266 ++++++++++++++++++++++++++++++++++++++++----
 src/minimization.jl |  13 ++-
 2 files changed, 254 insertions(+), 25 deletions(-)

diff --git a/src/input.jl b/src/input.jl
index aec728f..d444a73 100644
--- a/src/input.jl
+++ b/src/input.jl
@@ -353,6 +353,207 @@ function parse_arcs(path, lesser_priority=["epinet"])
     return flag, dict
 end
 
+
+function expand_symmetry_cgd!(nodes::Vector{SVector{N,Float64}}, equivalents::Vector{EquivalentPosition{Float64}}, connectivity::Vector{Int}) where N
+    for idx in 1:length(nodes) # range fixed on entry: appended images are not revisited
+        base = nodes[idx]
+        lifted = N==2 ? SA[base[1], base[2], 0.0] : base # pad 2D positions with a zero third coordinate
+        for op in equivalents
+            image = round.(op(lifted); digits=10) # transform -1e-15 into 0.0
+            image = image .- floor.(image) # keep the fractional part only
+            candidate = N==2 ? SA[image[1], image[2]] : image
+            if !(minimum(Base.Fix1(dist2, candidate), nodes) < 1e-8) # no known node at that position
+                push!(nodes, candidate)
+                push!(connectivity, connectivity[idx]) # the image reuses the connectivity of its source
+            end
+        end
+    end
+    nothing
+end
+
+struct ClosestSymmetricImage{N,T} <: Function # zero-argument callable returning a PeriodicVertex{N}
+    nodes::Vector{SVector{N,Float64}} # positions among which the image is searched
+    x::PeriodicVertex{N} # vertex to transform: `x.v` indexes `nodes`, `x.ofs` is added to the position
+    eq::EquivalentPosition{T} # operation applied to the position of `x`
+end
+function (csi::ClosestSymmetricImage{N,T})() where {N,T}
+    vertex = csi.x
+    start = csi.nodes[vertex.v] .+ vertex.ofs # position of the vertex, offset included
+    lifted = N==2 ? SA[start[1], start[2], 0.0] : start
+    image = round.(csi.eq(lifted); digits=10) # transform -1e-15 into 0.0
+    shift = floor.(Int, image) # integer part of the image, returned as its offset
+    reduced = image .- shift
+    target = N==2 ? SA[reduced[1], reduced[2]] : reduced
+    gap2, idx = findmin(Base.Fix1(dist2, target), csi.nodes)
+    @assert gap2 < 1e-9 # the image is expected to coincide with one of the known nodes
+    PeriodicVertex{N}(idx, N==2 ? SA[shift[1], shift[2]] : shift)
+end
+
+function expand_symmetry_cgd!(edgelist::Vector{PeriodicEdge{N}}, nodes::Vector{SVector{N,Float64}}, equivalents::Vector{EquivalentPosition{Float64}}) where N
+    # one cache per operation, storing the image already computed for each vertex
+    caches = [Dict{PeriodicVertex{N},PeriodicVertex{N}}() for _ in equivalents]
+    for k in 1:length(edgelist) # range fixed on entry: appended edges are not revisited
+        src, dst = edgelist[k]
+        origin = PeriodicVertex{N}(src)
+        for (op, cache) in zip(equivalents, caches)
+            src_v, src_ofs = get!(ClosestSymmetricImage(nodes, origin, op), cache, origin)
+            dst_v, dst_ofs = get!(ClosestSymmetricImage(nodes, dst, op), cache, dst)
+            push!(edgelist, PeriodicEdge{N}(src_v, PeriodicVertex{N}(dst_v, dst_ofs .- src_ofs)))
+        end
+    end
+end
+
+function retrieve_node(nodes::Vector{SVector{N,Float64}}, pos::SVector{N,Float64}) where N
+    cleaned = round.(pos; digits=10) # transform -1e-15 into 0.0
+    cellshift = floor.(Int, cleaned) # integer part of the position, returned as the offset
+    return find_nearest(nodes, cleaned .- cellshift), cellshift
+end
+
+# Parse one CRYSTAL section of a .cgd file. `iterator` yields the lines that follow the
+# "CRYSTAL" keyword and is consumed up to the matching "END". On success, push
+# `(name, graph)` to `ret` and return the dimension (2 or 3); return 0 if the section is
+# skipped. A line without keyword continues the preceding NODE/ATOM or EDGE statement.
+# The two edge lists are scratch buffers: the one used is emptied before returning.
+function _parse_cgd_crystal!(iterator, ret, current_edgelist2D, current_edgelist3D, path)
+    current_name = ""
+    local nodes#::Union{Vector{SVector{2,Float64}}, Vector{SVector{3,Float64}}}
+    local current_edgelist
+    local PE
+    names = Dict{String,Int}()
+    connectivity = Int[]
+    hall = 1
+    hallflag = false
+    cell = Cell{Float64}()
+    lastkeyw = ""
+    dimension = 0
+    for _l in iterator
+        l = strip(lowercase(_l))
+        comment = findfirst('#', l)
+        if !isnothing(comment)
+            l = strip(@view l[1:prevind(l, comment)])
+        end
+        isempty(l) && continue
+        splits = split(l; limit=2)
+        keyw = first(splits)
+        if keyw == "name" || keyw == "id" || keyw == "key"
+            current_name = last(splits)
+        elseif keyw == "group"
+            hall = PeriodicGraphEmbeddings.find_hall_number("", last(splits), 0, false)
+            if hall == 1 && last(splits) != "p1"
+                @info lazy"Unreadable symmetry $(last(splits)): skipping $current_name"
+                @goto skip
+            end
+            if cell != Cell{Float64}()
+                cell = Cell{Float64}(hall, cell_parameters(cell.mat)...)
+            end
+            hallflag = true
+        elseif keyw == "cell"
+            _a, _b, unk... = split(last(splits))
+            a = parse(Float64, _a); b = parse(Float64, _b)
+            if length(unk) == 1
+                if hall != 1
+                    @info lazy"2-dimensional nets with symmetries are not currently accepted: skipping $current_name"
+                    @goto skip
+                end
+                γ = parse(Float64, only(unk))
+                c = 10.0
+                α = β = 90.0
+                dimension = 2
+                current_edgelist = current_edgelist2D
+                PE = PeriodicEdge2D
+                nodes = SVector{2,Float64}[]
+            else
+                c, α, β, γ = parse.(Float64, unk)
+                dimension = 3
+                current_edgelist = current_edgelist3D
+                PE = PeriodicEdge3D
+                nodes = SVector{3,Float64}[]
+            end
+            cell = Cell{Float64}(hall, (a, b, c), (α, β, γ))
+        elseif keyw == "node" || keyw == "atom"
+            @label handle_node
+            if !@isdefined(nodes)
+                @error "Error while parsing $path at \"$current_name\": missing cell"
+                dimension = length(split(last(splits))) - 2 # payload is "name connectivity coordinates..."
+                if dimension == 3
+                    current_edgelist = current_edgelist3D
+                    PE = PeriodicEdge3D
+                    nodes = SVector{3,Float64}[]
+                    cell = Cell{Float64}(hall, (NaN, NaN, NaN), (NaN, NaN, NaN))
+                end
+            end
+            @assert @isdefined(nodes) && hallflag && (!isempty(nodes) || cell != Cell{Float64}())
+            if length(splits) == 2
+                name, conn, _pos... = split(last(splits))
+                haskey(names, name) && error(lazy"Error while parsing $path at \"$current_name\": Multiple vertices have the same name $name")
+                push!(connectivity, parse(Int, conn))
+                newnode = round.(parse.(Float64, dimension == 2 ? SVector{2}(_pos) : SVector{3}(_pos)); digits=10)
+                push!(nodes, newnode .- floor.(newnode))
+                names[name] = length(nodes) # register the name, otherwise duplicates are never detected
+            end # else, the line is just "node" hence the following are the actual nodes
+            lastkeyw = "node"
+        elseif keyw == "edge"
+            @label handle_edge
+            @assert @isdefined(nodes) && !isempty(nodes)
+            if length(splits) == 2
+                _content = split(last(splits))
+                isempty(current_edgelist) && expand_symmetry_cgd!(nodes, cell.equivalents, connectivity)
+                push!(current_edgelist, if length(_content) == 2
+                    PE(parse.(Int, _content)..., dimension==2 ? (0,0) : (0,0,0))
+                elseif length(_content) == 1+dimension
+                    srcA = parse(Int, _content[1])
+                    dstA_v = parse.(Float64, (dimension==2 ? SVector{2,SubString{String}} : SVector{3,SubString{String}})(@view _content[2:end]))
+                    dstA, ofsA = retrieve_node(nodes, dstA_v)
+                    PE(srcA, dstA, ofsA)
+                else
+                    @assert length(_content) == 2*dimension
+                    srcB_v = parse.(Float64, (dimension==2 ? SVector{2,SubString{String}} : SVector{3,SubString{String}})(@view _content[1:dimension]))
+                    dstB_v = parse.(Float64, (dimension==2 ? SVector{2,SubString{String}} : SVector{3,SubString{String}})(@view _content[(dimension+1):end]))
+                    srcB, ofsBstart = retrieve_node(nodes, srcB_v)
+                    dstB, ofsBstop = retrieve_node(nodes, dstB_v)
+                    PE(srcB, dstB, ofsBstop .- ofsBstart)
+                end)
+            end
+            lastkeyw = "edge"
+        elseif keyw == "end"
+            expand_symmetry_cgd!(current_edgelist, nodes, cell.equivalents)
+            g = dimension == 2 ? PeriodicGraph2D(current_edgelist2D) : PeriodicGraph3D(current_edgelist3D)
+            empty!(current_edgelist)
+            if nv(g) != length(nodes)
+                @error lazy"Error while parsing $path at \"$current_name\": Found a graph with only $(nv(g)) nodes instead of expected $(length(nodes))"
+                return 0 # "END" is already consumed: jumping to `skip` would swallow the next section
+            end
+            for (i, conn) in enumerate(connectivity)
+                if degree(g, i) != conn
+                    @error lazy"Error while parsing $path at \"$current_name\": Invalid connectivity of $(degree(g, i)) found instead of $conn for node $i"
+                    return 0 # "END" is already consumed, see above
+                end
+            end
+            push!(ret, (current_name, g))
+            return dimension
+        else
+            if lastkeyw == "node"
+                splits = split(string("node ", l); limit=2) # re-insert the implied keyword so that the payload is parsed
+                @goto handle_node
+            elseif lastkeyw == "edge"
+                splits = split(string("edge ", l); limit=2) # re-insert the implied keyword so that the payload is parsed
+                @goto handle_edge
+            end
+            @error lazy"Error while parsing $path at \"$current_name\": Unknown key: \"$keyw\""
+            @goto skip
+        end
+    end
+    error(lazy"Error while parsing $path at \"$current_name\": Starting \"CRYSTAL\" missing its corresponding \"END\"")
+
+    @label skip
+    empty!(current_edgelist2D) # drop partially read edges so that the next section starts clean
+    empty!(current_edgelist3D)
+    for l2 in iterator
+        lowercase(strip(l2)) == "end" && break
+    end
+    return 0
+end
+
 function parse_cgd_lines!(edgelist::Vector{PeriodicEdge{D}}, iterator) where D
     for _l in iterator
         l = strip(_l)
@@ -364,30 +565,19 @@ function parse_cgd_lines!(edgelist::Vector{PeriodicEdge{D}}, iterator) where D
     nothing
 end
 
-"""
-    parse_cgd(path::AbstractString)
-
-Parse a .cgd Systre configuration data file such as the one used by the RCSR.
-Return a list of `id => g` where `id` is the name of the structure and `g` is its
-`PeriodicGraph`.
-
-Only support PERIODIC_GRAPH inputs.
-"""
-function parse_cgd(path::AbstractString)
-    ret = Tuple{String,Union{PeriodicGraph2D,PeriodicGraph3D}}[]
-    current_name = ""
-    current_edgelist2D = PeriodicEdge2D[]
-    current_edgelist3D = PeriodicEdge3D[]
+function _parse_cgd_periodicgraph!(iterator, ret, current_edgelist2D, current_edgelist3D, path)
     edgeready = false
-    iterator = eachline(path)
+    current_name = ""
     for _l in iterator
         l = strip(_l)
+        comment = findfirst('#', l)
+        if !isnothing(comment)
+            l = strip(@view l[1:prevind(l, comment)])
+        end
         isempty(l) && continue
         splits = split(l; limit=2)
         keyw = lowercase(first(splits))
-        if keyw == "periodic_graph"
-            @assert isempty(current_name) && isempty(current_edgelist2D) && isempty(current_edgelist3D) && !edgeready
-        elseif keyw == "name" || keyw == "id" || keyw == "key"
+        if keyw == "name" || keyw == "id" || keyw == "key"
             current_name = last(splits)
         elseif keyw == "edges"
             edgeready = true
@@ -399,18 +589,52 @@ function parse_cgd(path::AbstractString)
                 parse_cgd_lines!(current_edgelist2D, iterator)
                 push!(ret, (current_name, PeriodicGraph2D(current_edgelist2D)))
                 empty!(current_edgelist2D)
+                return 2
             else
                 @assert length(ofs) == 3
                 push!(current_edgelist3D, PeriodicEdge3D(src, dst, SVector{3,Int}(ofs)))
                 parse_cgd_lines!(current_edgelist3D, iterator)
                 push!(ret, (current_name, PeriodicGraph3D(current_edgelist3D)))
                 empty!(current_edgelist3D)
+                return 3
             end
-            current_name = ""
-            edgeready = false
+        else
+            error(lazy"Error while parsing $path at \"$current_name\": Unknown key: \"$keyw\"")
         end
     end
-    return ret
+    error(lazy"Error while parsing $path at \"$current_name\": Starting \"CRYSTAL\" missing its corresponding \"END\"")
+end
+
+"""
+    parse_cgd(path::AbstractString)
+
+Parse a .cgd Systre configuration data file such as the one used by the RCSR.
+Return a list of `(id, g)` tuples where `id` is the name of the structure and `g` is
+its `PeriodicGraph`. Both PERIODIC_GRAPH and CRYSTAL sections are supported.
+"""
+function parse_cgd(path::AbstractString)
+    ret = Tuple{String,Union{PeriodicGraph2D,PeriodicGraph3D}}[]
+    current_edgelist2D = PeriodicEdge2D[]
+    current_edgelist3D = PeriodicEdge3D[]
+    # `open ... do` guarantees that the file is closed, even when a section parser throws
+    open(path) do io
+        iterator = eachline(io) # shared with the section parsers, which consume their own lines
+        for _l in iterator
+            # lowercase, discard what follows a '#', then trim the surrounding whitespace
+            l = strip(first(split(lowercase(_l), '#'; limit=2)))
+            isempty(l) && continue
+            if l == "periodic_graph"
+                @assert isempty(current_edgelist2D) && isempty(current_edgelist3D)
+                _parse_cgd_periodicgraph!(iterator, ret, current_edgelist2D, current_edgelist3D, path)
+            elseif l == "crystal"
+                @assert isempty(current_edgelist2D) && isempty(current_edgelist3D)
+                _parse_cgd_crystal!(iterator, ret, current_edgelist2D, current_edgelist3D, path)
+            else
+                error(lazy"Error while parsing $path: Misplaced \"$l\"")
+            end
+        end
+    end
+    ret
+end
 
 function parse_atom_name(name::AbstractString)
diff --git a/src/minimization.jl b/src/minimization.jl
index 89523c4..943547b 100644
--- a/src/minimization.jl
+++ b/src/minimization.jl
@@ -347,14 +347,19 @@ function max_nearest(c::Crystal{Nothing}, trans)
     return i_max, max
 end
 
+function dist2(x::SVector{D,T}, y::SVector{D,T}) where {D,T} # squared Euclidean distance
+    total = zero(T)
+    for (xk, yk) in zip(x, y) # coordinates are visited in order
+        total += (xk - yk)^2
+    end
+    return total
+end
+
 function find_nearest(l::Vector{SVector{D,T}}, pos::SVector{D,T}) where {D,T}
     minr2 = Inf
     mini = 0
     for (i, x) in enumerate(l)
-        r2 = zero(T)
-        for j in 1:D
-            r2 += (pos[j] - x[j])^2
-        end
+        r2 = dist2(pos, x)
         if r2 < minr2
             minr2 = r2
             mini = i