From aa6b2f963d78e5f7dac8d5360330ddd0cbf22896 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 21 Jul 2024 09:13:44 -0700 Subject: [PATCH 1/5] first pass at organizing code --- DEC/Project.toml | 1 - DEC/src/DEC.jl | 18 +- DEC/src/OperatorStorage.jl | 4 - DEC/src/Roe.jl | 292 ------------------------ DEC/src/eggstraction.jl | 39 ---- DEC/src/models/ThDEC/EGraph.jl | 83 +++++++ DEC/src/{ => models/ThDEC}/Luke.jl | 21 +- DEC/src/{ => models/ThDEC}/Signature.jl | 42 ++-- DEC/src/models/ThDEC/ThDEC.jl | 128 +++++++++++ DEC/src/models/module.jl | 9 + DEC/src/roe/RoeUtility.jl | 222 ++++++++++++++++++ DEC/src/{ => roe}/SSAExtract.jl | 49 +++- DEC/src/roe/module.jl | 11 + DEC/src/{ => util}/HashColor.jl | 0 DEC/src/util/module.jl | 9 + DEC/tests/DEC.jl | 126 +++------- DEC/tests/Roe.jl | 60 +++++ DEC/tests/SSAExtract.jl | 60 ++++- DEC/tests/Signature.jl | 25 ++ DEC/tests/runtests.jl | 17 ++ 20 files changed, 737 insertions(+), 479 deletions(-) delete mode 100644 DEC/src/OperatorStorage.jl delete mode 100644 DEC/src/Roe.jl delete mode 100644 DEC/src/eggstraction.jl create mode 100644 DEC/src/models/ThDEC/EGraph.jl rename DEC/src/{ => models/ThDEC}/Luke.jl (89%) rename DEC/src/{ => models/ThDEC}/Signature.jl (65%) create mode 100644 DEC/src/models/ThDEC/ThDEC.jl create mode 100644 DEC/src/models/module.jl create mode 100644 DEC/src/roe/RoeUtility.jl rename DEC/src/{ => roe}/SSAExtract.jl (73%) create mode 100644 DEC/src/roe/module.jl rename DEC/src/{ => util}/HashColor.jl (100%) create mode 100644 DEC/src/util/module.jl create mode 100644 DEC/tests/Roe.jl create mode 100644 DEC/tests/Signature.jl create mode 100644 DEC/tests/runtests.jl diff --git a/DEC/Project.toml b/DEC/Project.toml index f8e2621a..9ce6c275 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -26,6 +26,5 @@ Decapodes = "0.5.5" GeometryBasics = "0.4.11" MLStyle = "0.4.17" OrdinaryDiffEq = "6.86.0" -Random = "1.11.0" Reexport = "1.2.2" StructEquality = "2.1.0" diff --git a/DEC/src/DEC.jl b/DEC/src/DEC.jl index 432a258b..f913060c 100644 --- a/DEC/src/DEC.jl +++ b/DEC/src/DEC.jl @@ -1,4 +1,7 @@ module DEC + +using Reexport + using MLStyle using Reexport using StructEquality @@ -6,15 +9,14 @@ import Metatheory using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH import Metatheory: extract! -import Base: +, - -import Base: * +import Base: +, -, * -include("HashColor.jl") -include("Signature.jl") -include("Roe.jl") -include("SSAExtract.jl") -include("Luke.jl") +include("util/module.jl") # Pretty-printing +include("roe/module.jl") # Checking signature for DEC operations +include("models/module.jl") # manipulating SSAs +@reexport using .Util @reexport using .SSAExtract +@reexport using .Models -end # module DEC +end diff --git a/DEC/src/OperatorStorage.jl b/DEC/src/OperatorStorage.jl deleted file mode 100644 index eeb9000b..00000000 --- a/DEC/src/OperatorStorage.jl +++ /dev/null @@ -1,4 +0,0 @@ -struct OperatorStorage - hodge::Tuple{} -end - diff --git a/DEC/src/Roe.jl b/DEC/src/Roe.jl deleted file mode 100644 index 8869b817..00000000 --- a/DEC/src/Roe.jl +++ /dev/null @@ -1,292 +0,0 @@ -@struct_hash_equal struct RootVar - name::Symbol - idx::Int - sort::Sort -end - -struct Roe - variables::Vector{RootVar} - graph::EGraph{Expr, Sort} - function Roe() - new(RootVar[], EGraph{Expr, Sort}()) - end -end - -struct Var{S} - roe::Roe - id::Id -end - -function EGraphs.make(g::EGraph{Expr, Sort}, n::Metatheory.VecExpr) - op = EGraphs.get_constant(g,Metatheory.v_head(n)) - if op isa RootVar - op.sort - elseif op isa Number - Scalar() - else - op((g[arg].data for arg in Metatheory.v_children(n))...) - end -end - -function EGraphs.join(s1::Sort, s2::Sort) - if s1 == s2 - s1 - else - error("Cannot equate two nodes with different sorts") - end -end - -function extract!(v::Var, f=EGraphs.astsize) - extract!(v.roe.graph, f, v.id) -end - -function rootvarcrayon(v::RootVar) - lightnessrange = (50., 100.) - HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) -end - -function Base.show(io::IO, v::RootVar) - if get(io, :color, true) - crayon = rootvarcrayon(v) - print(io, crayon, "$(v.name)") - print(io, inv(crayon)) - else - print(io, "$(v.name)#$(v.idx)") - end -end - -function fix_functions(e) - @match e begin - s::Symbol => s - Expr(:call, f::Function, args...) => - Expr(:call, nameof(f), fix_functions.(args)...) - Expr(head, args...) => - Expr(head, fix_functions.(args)...) - _ => e - end -end - -function getexpr(v::Var) - e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) - fix_functions(e) -end - -function Base.show(io::IO, v::Var) - print(io, getexpr(v)) -end - -function fresh!(roe::Roe, sort::Sort, name::Symbol) - v = RootVar(name, length(roe.variables), sort) - push!(roe.variables, v) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) - Var{sort}(roe, EGraphs.add!(roe.graph, n, false)) -end - -@nospecialize -function inject_number!(roe::Roe, x::Number) - x = Float64(x) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) - Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) -end - -@nospecialize -function addcall!(g::EGraph, head, args) - ar = length(args) - n = Metatheory.v_new(ar) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) - Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) - for i in Metatheory.v_children_range(n) - @inbounds n[i] = args[i - VECEXPR_META_LENGTH] - end - EGraphs.add!(g, n, false) -end - -function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) - v1.roe === v2.roe || error("Cannot equate variables from different graphs") - union!(v1.roe.graph, v1.id, v2.id) -end - -≐(v1::Var, v2::Var) = equate!(v1, v2) - -@nospecialize -function derivative_cost(allowed_roots) - function cost(n::Metatheory.VecExpr, op, costs) - if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) - Inf - else - Metatheory.astsize(n, op, costs) - end - end -end - - -@nospecialize -function +(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot add variables from different graphs") - s = s1 + s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, +, (v1.id, v2.id))) -end - -@nospecialize -+(v::Var, x::Number) = +(v, inject_number!(v.roe, x)) - -@nospecialize -+(x::Number, v::Var) = +(inject_number!(v.roe, x), v) - -@nospecialize -function -(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe == v2.roe || error("Cannot subtract variables from different graphs") - s = s1 - s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, -, (v1.id, v2.id))) -end - -@nospecialize --(v::Var{s}) where {s} = Var{s}(v.roe, addcall!(v.roe.graph, -, (v.id,))) - -@nospecialize --(v::Var, x::Number) = -(v, inject_number!(v.roe, x)) - -@nospecialize --(x::Number, v::Var) = -(inject_number!(v.roe, x), v) - -@nospecialize -function *(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot multiply variables from different graphs") - s = s1 * s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, *, (v1.id, v2.id))) -end - -@nospecialize -*(v::Var, x::Number) = *(v, inject_number!(v.roe, x)) - -@nospecialize -*(x::Number, v::Var) = *(inject_number!(v.roe, x), v) - -@nospecialize -function ∧(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot wedge variables from different graphs") - s = s1 ∧ s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, ∧, (v1.id, v2.id))) -end - -@nospecialize -function ∂ₜ(v::Var{s}) where {s} - Var{s}(v.roe, addcall!(v.roe.graph, ∂ₜ, (v.id,))) -end - -@nospecialize -function d(v::Var{s}) where {s} - s′ = d(s) - Var{s′}(v.roe, addcall!(v.roe.graph, d, (v.id,))) -end - - -@nospecialize -function ★(v::Var{s}) where {s} - s′ = ★(s) - Var{s′}(v.roe, addcall!(v.roe.graph, ★, (v.id,))) -end - -Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) - - -""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction - -Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. - -Example: given a diffusivity constant a, the heat equation can be written as: -``` - ∂ₜ u = a * Laplacian(u) -``` -would return (u, a). - -A limitation of this function can be demonstrated here: given the model - ``` - ∂ₜ a = a + b - ∂ₜ b = a + b - ``` - we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, - the extractor would bypass it completely. -""" -function vfield(model, operator_lookup::Dict{TypedApplication, Any}) - roe = Roe() - (state_vars, param_vars) = model(roe) - length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") - state_rootvars = map(state_vars) do x - rv = extract!(x) - rv isa RootVar || error("all state variables must be RootVars") - rv - end - param_rootvars = map(param_vars) do p - rv = extract!(p) - rv isa RootVar || error("all param variables must be RootVars") - rv - end - - u = :u - p = :p - du = :du - - rootvar_lookup = - Dict{RootVar, Union{Expr, Symbol}}( - [ - [rv => :($(u)) for (i, rv) in enumerate(state_rootvars)]; - [rv => :($(p)) for (i, rv) in enumerate(param_rootvars)] - ] - ) - - cost = derivative_cost(Set([state_rootvars; param_rootvars])) - - extractor = EGraphs.Extractor(roe.graph, cost, Float64) - - function term_select(id) - EGraphs.find_best_node(extractor, id) - end - - - ssa = SSAExtract.SSA() - - derivative_vars = map(state_vars) do v - SSAExtract.extract_ssa!(roe.graph, ssa, (∂ₜ(v)).id, term_select) - end - - toexpr(v::SSAExtract.SSAVar) = Symbol("tmp%$(v.idx)") - - function toexpr(expr::SSAExtract.SSAExpr) - if expr.fn isa RootVar - rootvar_lookup[expr.fn] - elseif expr.fn isa Number - expr.fn - else - op = operator_lookup[TypedApplication(expr.fn, first.(expr.args))] - if op isa Tuple - op = op[1] - end - Expr(:call, *, op, toexpr.(last.(expr.args))...) - end - end - - ssalines = map(enumerate(ssa.statements)) do (i, expr) - :($(toexpr(SSAExtract.SSAVar(i))) = $(toexpr(expr))) - end - - set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) - :($(du) .= $(toexpr(v))) - end - - eval( - quote - ($du, $u, $p, _) -> begin - $(ssalines...) - $(set_derivative_stmts...) - end - end - ) -end \ No newline at end of file diff --git a/DEC/src/eggstraction.jl b/DEC/src/eggstraction.jl deleted file mode 100644 index 9bd330da..00000000 --- a/DEC/src/eggstraction.jl +++ /dev/null @@ -1,39 +0,0 @@ - -""" -An EGraph has a field `symcache::Dict{Any, Vector{EClassId}}` -whose keys appear to be just symbols (what else could be there?) - -To extract a value `a`, we could just index by symcache[a]. -However, variables could share names. - -Consider this -``` -G = EGraph() -ec1 = addexpr!(G, :(f(a, b))) -ec2 = addexpr!(G, :(f(a, c))) -``` -which -""" -keys(G.symcache) -""" -returns -``` -:a => [1] -:b => [2] -:f => [3, 5] -:c => [4] -``` -Threading the val for :f into `G.classes` -""" -G.classes[:f] -""" -returns the EClass -``` -EClass 3 ([ENode(call, f, Expr, [1,2])], ) -``` -How do we convert an EClass back to an expression? -""" - -""" -``memo::Dict{AbstractENode, EClassId}`` -""" \ No newline at end of file diff --git a/DEC/src/models/ThDEC/EGraph.jl b/DEC/src/models/ThDEC/EGraph.jl new file mode 100644 index 00000000..df8757f9 --- /dev/null +++ b/DEC/src/models/ThDEC/EGraph.jl @@ -0,0 +1,83 @@ +using ...DEC: Var, addcall! + +import Base: +, -, * + +# These operations create calls on a common egraph. We validate the signature by dispatching the operation on the types using methods we defined in Signature. + +@nospecialize +function +(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot add variables from different graphs") + s = s1 + s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, +, (v1.id, v2.id))) +end +export + + +@nospecialize ++(v::Var, x::Number) = +(v, inject_number!(v.roe, x)) + +@nospecialize ++(x::Number, v::Var) = +(inject_number!(v.roe, x), v) + +@nospecialize +function -(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe == v2.roe || error("Cannot subtract variables from different graphs") + s = s1 - s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, -, (v1.id, v2.id))) +end +export - + +@nospecialize +-(v::Var{s}) where {s} = Var{s}(v.roe, addcall!(v.roe.graph, -, (v.id,))) + +@nospecialize +-(v::Var, x::Number) = -(v, inject_number!(v.roe, x)) + +@nospecialize +-(x::Number, v::Var) = -(inject_number!(v.roe, x), v) + +@nospecialize +function *(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot multiply variables from different graphs") + s = s1 * s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, *, (v1.id, v2.id))) +end +export * + +@nospecialize +*(v::Var, x::Number) = *(v, inject_number!(v.roe, x)) + +@nospecialize +*(x::Number, v::Var) = *(inject_number!(v.roe, x), v) + +@nospecialize +function ∧(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot wedge variables from different graphs") + s = s1 ∧ s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, ∧, (v1.id, v2.id))) +end +export ∧ + +@nospecialize +function ∂ₜ(v::Var{s}) where {s} + Var{s}(v.roe, addcall!(v.roe.graph, ∂ₜ, (v.id,))) +end +export ∂ₜ + +@nospecialize +function d(v::Var{s}) where {s} + s′ = d(s) + Var{s′}(v.roe, addcall!(v.roe.graph, d, (v.id,))) +end +export d + +@nospecialize +function ★(v::Var{s}) where {s} + s′ = ★(s) + Var{s′}(v.roe, addcall!(v.roe.graph, ★, (v.id,))) +end +export ★ + +Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) +export Δ + +# end diff --git a/DEC/src/Luke.jl b/DEC/src/models/ThDEC/Luke.jl similarity index 89% rename from DEC/src/Luke.jl rename to DEC/src/models/ThDEC/Luke.jl index 31a14e98..22affa0b 100644 --- a/DEC/src/Luke.jl +++ b/DEC/src/models/ThDEC/Luke.jl @@ -1,17 +1,12 @@ import Decapodes using StructEquality -@struct_hash_equal struct TypedApplication - fn::Function - sorts::Vector{Sort} -end -const TA = TypedApplication +""" precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} -function Base.show(io::IOBuffer, ta::TA) - print(io, Expr(:call, ta.fn, ta.sorts...)) -end +Given a matrix and a hodge star (DiagonalHodge() or GeometricHodge()), this returns a lookup dictionary between operators (as TypedApplications) and their corresponding matrices. +""" function precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} Dict{TypedApplication, Any}( # Regular Hodge Stars @@ -20,8 +15,10 @@ function precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} TA(★, Sort[PrimalForm(2)]) => Decapodes.dec_mat_hodge(2, sd, hodge), # Inverse Hodge Stars - TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(1, sd, hodge), # why is this 1??? - TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), # Special since Geo is a solver + TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(1, sd, hodge), + # why is this 1??? + TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), + # Special since Geo is a solver TA(★, Sort[DualForm(2)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), # Differentials @@ -70,6 +67,6 @@ function precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} # # Averaging Operator # :avg₀₁ => Decapodes.dec_avg₀₁(sd) - # :neg => x -> -1 .* x ) -end \ No newline at end of file +end + diff --git a/DEC/src/Signature.jl b/DEC/src/models/ThDEC/Signature.jl similarity index 65% rename from DEC/src/Signature.jl rename to DEC/src/models/ThDEC/Signature.jl index 89328c92..f517554b 100644 --- a/DEC/src/Signature.jl +++ b/DEC/src/models/ThDEC/Signature.jl @@ -1,9 +1,18 @@ -@data Sort begin +using ...DEC: AbstractSort, SortError + +using MLStyle + +import Base: +, -, * + +# Define the sorts in your theory. +# For the DEC, we work with Scalars and Forms, graded objects which can also be primal or dual. +@data Sort <: AbstractSort begin Scalar() Form(dim::Int, isdual::Bool) end export Scalar, Form +dim(f::Form) = f.dim duality(f::Form) = f.isdual ? "dual" : "primal" PrimalForm(i::Int) = Form(i, false) @@ -12,33 +21,37 @@ export PrimalForm DualForm(i::Int) = Form(i, true) export DualForm -struct SortError <: Exception - message::String -end +Base.show(io::IO, ω::Form) = print(io, ω.isdual ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") + +## OPERATIONS @nospecialize function +(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(i1, isdual1), Form(i2, isdual2)) => - if (i1 == i2) && (isdual1 == isdual2) - Form(i1, isdual1) - else - throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) - end + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end end end +# Type-checking inverse of addition follows addition -(s1::Sort, s2::Sort) = +(s1, s2) +# Negation is valid -(s::Sort) = s @nospecialize function *(s1::Sort, s2::Sort) @match (s1, s2) begin (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) end end @@ -82,3 +95,4 @@ function ★(s::Sort) Form(i, isdual) => Form(2 - i, !isdual) end end + diff --git a/DEC/src/models/ThDEC/ThDEC.jl b/DEC/src/models/ThDEC/ThDEC.jl new file mode 100644 index 00000000..3c2b48a1 --- /dev/null +++ b/DEC/src/models/ThDEC/ThDEC.jl @@ -0,0 +1,128 @@ +module ThDEC + +using ...DEC: TypedApplication, TA, Roe, RootVar +using ...DEC.SSAExtract + +using Metatheory: VecExpr +using Metatheory.EGraphs + +include("Signature.jl") # verify the signature holds +include("EGraph.jl") # overload DEC operations to act on roe (egraphs) +include("Luke.jl") # represent operations as matrices + +@nospecialize +""" derivative_cost(allowed_roots)::Function + +Returns a function `cost(n::Metatheory.VecExpr, op, costs)` which sets the cost of operations to Inf if they are either ∂ₜ or forbidden RootVars. Otherwise it computes the astsize. + +""" +function derivative_cost(allowed_roots) + function cost(n::VecExpr, op, costs) + if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) + Inf + else + astsize(n, op, costs) + end + end +end +export derivative_cost + + +""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction + +Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. + +Example: given a diffusivity constant a, the heat equation can be written as: +``` + ∂ₜ u = a * Δ(u) +``` +would return (u, a). + +A limitation of this function can be demonstrated here: given the model + ``` + ∂ₜ a = a + b + ∂ₜ b = a + b + ``` + we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, + the extractor would bypass it completely. +""" +function vfield(model, operator_lookup::Dict{TA, Any}=Dict{TA, Any}()) + roe = Roe() + (state_vars, param_vars) = model(roe) + length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") + state_rootvars = map(state_vars) do x + rv = extract!(x) + rv isa RootVar || error("all state variables must be RootVars") + rv + end + param_rootvars = map(param_vars) do p + rv = extract!(p) + rv isa RootVar || error("all param variables must be RootVars") + rv + end + + u = :u + p = :p + du = :du + + rootvar_lookup = + Dict{RootVar, Union{Expr, Symbol}}( + [ + [rv => :($(u)) for (i, rv) in enumerate(state_rootvars)]; + [rv => :($(p)) for (i, rv) in enumerate(param_rootvars)] + ] + ) + + cost = derivative_cost(Set([state_rootvars; param_rootvars])) + + extractor = EGraphs.Extractor(roe.graph, cost, Float64) + + function term_select(id) + EGraphs.find_best_node(extractor, id) + end + + ssa = SSA() + + derivative_vars = map(state_vars) do v + extract_ssa!(roe.graph, ssa, (∂ₜ(v)).id, term_select) + end + + toexpr(v::SSAVar) = Symbol("tmp%$(v.idx)") + + function toexpr(expr::SSAExpr) + if expr.fn isa RootVar + rootvar_lookup[expr.fn] + elseif expr.fn isa Number + expr.fn + else + op = operator_lookup[TypedApplication(expr.fn, first.(expr.args))] + # Decapodes dec_* functions yield a tuple of both in-place and out-of-place function. + # We choose the first. + if op isa Tuple + op = op[1] + end + Expr(:call, *, op, toexpr.(last.(expr.args))...) + end + end + + ssalines = map(enumerate(ssa.statements)) do (i, expr) + :($(toexpr(SSAVar(i))) = $(toexpr(expr))) + end + + set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) + :($(du) .= $(toexpr(v))) + end + + eval( + quote + ($du, $u, $p, _) -> begin + $(ssalines...) + $(set_derivative_stmts...) + end + end + ) +end +export vfield + + +end diff --git a/DEC/src/models/module.jl b/DEC/src/models/module.jl new file mode 100644 index 00000000..117b5f05 --- /dev/null +++ b/DEC/src/models/module.jl @@ -0,0 +1,9 @@ +module Models + +using Reexport + +include("ThDEC/ThDEC.jl") # the theory of the DEC + +@reexport using .ThDEC + +end diff --git a/DEC/src/roe/RoeUtility.jl b/DEC/src/roe/RoeUtility.jl new file mode 100644 index 00000000..81e0b7ce --- /dev/null +++ b/DEC/src/roe/RoeUtility.jl @@ -0,0 +1,222 @@ +# """ +# Defines Roe, a struct which acts as a wrapper for e-graph typed in the Sorts of a given theory, as well as functions for manipulating it. +# """ +# module RoeUtility + +# using ..SSAExtract: SSA, SSAVar, SSAExpr, extract_ssa! +using ..Util.HashColor + +using StructEquality +import Metatheory +using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH +using MLStyle +using Reexport + +""" +Sorts in each theory are subtypes of this abstract type. +""" +abstract type AbstractSort end +export AbstractSort + +""" TypedApplication + +Struct containing a Function and the vector of Sorts it requires. +""" +@struct_hash_equal struct TypedApplication + head::Function + sorts::Vector{AbstractSort} +end +export TypedApplication + +const TA = TypedApplication +export TA + +Base.show(io::IO, ta::TA) = print(io, Expr(:call, nameof(ta.head), ta.sorts...)) + +struct SortError <: Exception + message::String +end +export SortError + +""" RootVar + +A childless node on an e-graph. + +""" +@struct_hash_equal struct RootVar + name::Symbol + idx::Int + sort::AbstractSort +end +export RootVar + +""" Roe + +Struct for storing an EGraph and its variables. + +Roe is the name for lobster eggs. "Egg" is the name of a Rust implementation of e-graphs, by which Metatheory.jl is inspired by. Lobsters are part of the family Decapodes, which is also the name of the AlgebraicJulia package which motivated this package. Hence, Roe. +""" +struct Roe + variables::Vector{RootVar} + graph::EGraph{Expr, AbstractSort} + function Roe() + new(RootVar[], EGraph{Expr, AbstractSort}()) + end +end +export Roe + +""" + +A struct containing a Roe and the Id of a variable in that EGraph. The type parameter for this struct is the variable it represents. + +""" +struct Var{S} + roe::Roe + id::Id +end + +function EGraphs.make(g::EGraph{Expr, AbstractSort}, n::Metatheory.VecExpr) + op = EGraphs.get_constant(g,Metatheory.v_head(n)) + if op isa RootVar + op.sort + elseif op isa Number + Scalar() + else + op((g[arg].data for arg in Metatheory.v_children(n))...) + end +end +export make + +function EGraphs.join(s1::AbstractSort, s2::AbstractSort) + if s1 == s2 + s1 + else + error("Cannot equate two nodes with different sorts") + end +end +export join + +function extract!(v::Var, f=EGraphs.astsize) + extract!(v.roe.graph, f, v.id) +end +export extract! + +function rootvarcrayon(v::RootVar) + lightnessrange = (50., 100.) + HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) +end + +function Base.show(io::IO, v::RootVar) + if get(io, :color, true) + crayon = rootvarcrayon(v) + print(io, crayon, "$(v.name)") + print(io, inv(crayon)) + else + print(io, "$(v.name)#$(v.idx)") + end +end + +""" fix_functions(e)::Union{Symbol, Expr} + +Traverses the AST of an expression, replacing the head of :call expressions to its name, a Symbol. +""" +function fix_functions(e) + @match e begin + s::Symbol => s + Expr(:call, f::Function, args...) => + Expr(:call, nameof(f), fix_functions.(args)...) + Expr(head, args...) => + Expr(head, fix_functions.(args)...) + _ => e + end +end + +""" getexpr(v::Var)::Union{Symbol, Expr} + +Extracts an expression (::Var) from its Roe. + +""" +function getexpr(v::Var) + e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) + fix_functions(e) +end +export getexpr + +function Base.show(io::IO, v::Var) + print(io, getexpr(v)) +end + +""" fresh!(roe::Roe, sort::AbstractSort, name::Symbol)::Var{sort} + +Creates a new ("fresh") variable in a Roe given a sort and a name. + +Example: +``` +fresh!(roe, Form(0), :Temp) +``` + +""" +function fresh!(roe::Roe, sort::AbstractSort, name::Symbol) + v = RootVar(name, length(roe.variables), sort) + push!(roe.variables, v) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) + Var{sort}(roe, EGraphs.add!(roe.graph, n, false)) +end +export fresh! + + +@nospecialize +""" inject_number!(roe::Roe, x::Number)::Var{Scalar()} + +Adds a number to the Roe as a EGraph constant. + +""" +function inject_number!(roe::Roe, x::Number) + x = Float64(x) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) + Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) +end +export inject_number! + +@nospecialize +""" addcall!(g::EGraph, head, args):: + +Adds a call to an EGraph. + +""" +function addcall!(g::EGraph, head, args) + ar = length(args) + n = Metatheory.v_new(ar) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) + Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) + for i in Metatheory.v_children_range(n) + @inbounds n[i] = args[i - VECEXPR_META_LENGTH] + end + EGraphs.add!(g, n, false) +end +export addcall! + +""" equate!(v1::Var{s1}, v2::Var{s2})::EGraph + +Asserts that two variables of the same e-graph are the same. This is done by returning the union of the variable ids with the e-graph. +""" +function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) + v1.roe === v2.roe || error("Cannot equate variables from different graphs") + union!(v1.roe.graph, v1.id, v2.id) +end +export equate! + +""" +Infix synonym for `equate!` +""" +≐(v1::Var, v2::Var) = equate!(v1, v2) +export ≐ + +# end diff --git a/DEC/src/SSAExtract.jl b/DEC/src/roe/SSAExtract.jl similarity index 73% rename from DEC/src/SSAExtract.jl rename to DEC/src/roe/SSAExtract.jl index ea6126fc..bab8ea43 100644 --- a/DEC/src/SSAExtract.jl +++ b/DEC/src/roe/SSAExtract.jl @@ -1,22 +1,51 @@ module SSAExtract +# +using ..DEC: AbstractSort, TypedApplication, TA, Roe, RootVar + +# other dependencies using MLStyle +using Metatheory: VecExpr using Metatheory.EGraphs -using ..DEC: Sort using StructEquality +""" SSAVar + +A wrapper for the index of a SSAVar +""" @struct_hash_equal struct SSAVar idx::Int end +export SSAVar function Base.show(io::IO, v::SSAVar) print(io, "%", v.idx) end +""" SSAExpr + +A wrapper for a function (::Any) and its args (::Vector{Tuple{Sort, SSAVar}}). + +Example: the equation +``` + a = 1 + b +``` +may have an SSA dictionary +``` + %1 => a + %2 => +(%1, %3) + %3 => b +``` +and so `+` would have +``` +SSAExpr(+, [(Scalar(), SSAVar(1)), (Scalar(), SSAVar(2))]) +``` +""" @struct_hash_equal struct SSAExpr fn::Any - args::Vector{Tuple{Sort, SSAVar}} + args::Vector{Tuple{AbstractSort, SSAVar}} end +export SSAExpr function Base.show(io::IO, e::SSAExpr) print(io, e.fn) @@ -26,6 +55,9 @@ function Base.show(io::IO, e::SSAExpr) end """ + +Struct defining Static Single-Assignment information for a given roe. + Advantages of SSA form: 1. We can preallocate each matrix @@ -38,6 +70,7 @@ struct SSA new(Dict{Id, SSAVar}(), SSAExpr[]) end end +export SSA function Base.show(io::IO, ssa::SSA) println(io, "SSA: ") @@ -46,20 +79,29 @@ function Base.show(io::IO, ssa::SSA) end end +""" add_stmt!(ssa::SSA, id::Id, expr::SSAExpr)::SSAVar + +Given an SSA, add onto the assignment_lookup an SSAExpr. + +""" function add_stmt!(ssa::SSA, id::Id, expr::SSAExpr) push!(ssa.statements, expr) v = SSAVar(length(ssa.statements)) ssa.assignment_lookup[id] = v v end +export add_stmt! +# TODO is this idempotent? function hasid(ssa::SSA, id::Id) haskey(ssa.assignment_lookup, id) end +export hasid function getvar(ssa::SSA, id::Id) ssa.assignment_lookup[id] end +export getvar """ extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar @@ -74,6 +116,7 @@ The closure parameters control the behavior of this function. This closure selects, given an id in an EGraph, the term that we want to use in order to compute a value for that id + """ function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select)::SSAVar if hasid(ssa, id) @@ -91,4 +134,4 @@ function extract_ssa!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=b extract_ssa!(g, ssa, id, term_select) end -end \ No newline at end of file +end diff --git a/DEC/src/roe/module.jl b/DEC/src/roe/module.jl new file mode 100644 index 00000000..150ad779 --- /dev/null +++ b/DEC/src/roe/module.jl @@ -0,0 +1,11 @@ +# module RoeUtility + +using Reexport + +include("RoeUtility.jl") # vfield depends on SSAExtract +include("SSAExtract.jl") + +# @reexport using .RoeUtility +@reexport using .SSAExtract + +# end diff --git a/DEC/src/HashColor.jl b/DEC/src/util/HashColor.jl similarity index 100% rename from DEC/src/HashColor.jl rename to DEC/src/util/HashColor.jl diff --git a/DEC/src/util/module.jl b/DEC/src/util/module.jl new file mode 100644 index 00000000..8c8d51c5 --- /dev/null +++ b/DEC/src/util/module.jl @@ -0,0 +1,9 @@ +module Util + +using Reexport + +include("HashColor.jl") + +@reexport using .HashColor + +end diff --git a/DEC/tests/DEC.jl b/DEC/tests/DEC.jl index 01cf16ff..5649ce69 100644 --- a/DEC/tests/DEC.jl +++ b/DEC/tests/DEC.jl @@ -1,94 +1,22 @@ module TestDEC +# AlgebraicJulia dependencies using DEC -using DEC: Roe, SortError, d, fresh!, ∂ₜ, ∧, Δ, ≐, ★ +import DEC.ThDEC: Δ + +# other dependencies using Test using Metatheory.EGraphs +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} -@test Scalar() + Scalar() == Scalar() -@test Scalar() + PrimalForm(1) == PrimalForm(1) -@test PrimalForm(2) + Scalar() == PrimalForm(2) -@test_throws SortError PrimalForm(1) + PrimalForm(2) - -# Scalar Multiplication -@test Scalar() * Scalar() == Scalar() -@test Scalar() * PrimalForm(1) == PrimalForm(1) -@test PrimalForm(2) * Scalar() == PrimalForm(2) -@test_throws SortError PrimalForm(2) * PrimalForm(1) - -# Exterior Product -@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) - -roe = Roe() - -a = fresh!(roe, Scalar(), :a) -b = fresh!(roe, Scalar(), :b) - -x = a + b -y = a + b - -@test x == y -@test roe.graph[(a+b).id].data == Scalar() - -ω = fresh!(roe, PrimalForm(1), :ω) -η = fresh!(roe, PrimalForm(0), :η) - -@test ω ∧ η isa DEC.Var{PrimalForm(1)} -@test ω ∧ η == ω ∧ η - -@test_throws SortError x ≐ ω - -ω ≐ (ω ∧ η) - -∂ₜ(a) ≐ 3 * a + 5 - -EGraphs.extract!(∂ₜ(a), DEC.derivative_cost([DEC.extract!(a)])) - -function lotka_volterra(pode) - α = fresh!(pode, Scalar(), :α) - β = fresh!(pode, Scalar(), :β) - γ = fresh!(pode, Scalar(), :γ) - - w = fresh!(pode, Scalar(), :w) - s = fresh!(pode, Scalar(), :s) - - ∂ₜ(s) ≐ α * s - β * w * s - ∂ₜ(w) ≐ - γ * w - β * w * s - - ([w, s], [α, β, γ]) -end +# plotting +using CairoMakie -(ssa, derivative_vars) = DEC.vfield(lotka_volterra) - -basicprinted(x; color=false) = sprint(show, x; context=(:color=>color)) - -@test basicprinted(ssa) == """ -SSA: - %1 = γ#2 - %2 = -(%1::Scalar(),) - %3 = w#3 - %4 = *(%2::Scalar(), %3::Scalar()) - %5 = β#1 - %6 = *(%5::Scalar(), %3::Scalar()) - %7 = s#4 - %8 = *(%6::Scalar(), %7::Scalar()) - %9 = -(%4::Scalar(), %8::Scalar()) - %10 = α#0 - %11 = *(%10::Scalar(), %7::Scalar()) - %12 = -(%11::Scalar(), %8::Scalar()) -""" - -function transitivity(pode) - w = fresh!(pode, Scalar(), :w) - ∂ₜ(w) ≐ 1 * w - ∂ₜ(w) ≐ 2 * w - w -end -_w = transitivity(roe) -# picks whichever expression it happens to visit first -EGraphs.extract!((∂ₜ(_w)), DEC.derivative_cost([DEC.extract!(_w)])) - -## HEAT EQUATION +## 1-D HEAT EQUATION function heat_equation(pode) u = fresh!(pode, PrimalForm(0), :u) @@ -98,29 +26,28 @@ function heat_equation(pode) ([u], []) end -using CombinatorialSpaces -using GeometryBasics -using OrdinaryDiffEq -Point2D = Point2{Float64} -Point3D = Point3{Float64} - -rect = triangulated_grid(100, 100, 1, 1, Point3D) -d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect) -subdivide_duals!(d_rect, Circumcenter()) +# initialize primal and dual meshes. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); -operator_lookup = DEC.precompute_matrices(d_rect, DiagonalHodge()) +# precompule matrices from operators in the DEC theory. +operator_lookup = ThDEC.precompute_matrices(d_rect, DiagonalHodge()) -vf = DEC.vfield(heat_equation, operator_lookup) +# produce a vector field +vf = vfield(heat_equation, operator_lookup) -U = first.(d_rect[:point]) +# +U = first.(d_rect[:point]); +# TODO component arrays constants_and_parameters = () tₑ = 500.0 @info("Precompiling Solver") -prob = ODEProblem(vf, U, (0, tₑ), constants_and_parameters) -soln = solve(prob, Tsit5()) +prob = ODEProblem(vf, U, (0, tₑ), constants_and_parameters); +soln = solve(prob, Tsit5()); function save_dynamics(save_file_name) time = Observable(0.0) @@ -129,7 +56,6 @@ function save_dynamics(save_file_name) ax = CairoMakie.Axis(f[1,1], title = @lift("Heat at time $($time)")) gmsh = mesh!(ax, rect, color=h, colormap=:jet, colorrange=extrema(soln(tₑ))) - #Colorbar(f[1,2], gmsh, limits=extrema(soln(tₑ).h)) Colorbar(f[1,2], gmsh) timestamps = range(0, tₑ, step=5.0) record(f, save_file_name, timestamps; framerate = 15) do t @@ -137,4 +63,4 @@ function save_dynamics(save_file_name) end end -end \ No newline at end of file +end diff --git a/DEC/tests/Roe.jl b/DEC/tests/Roe.jl new file mode 100644 index 00000000..f1a04a4a --- /dev/null +++ b/DEC/tests/Roe.jl @@ -0,0 +1,60 @@ +module TestRoe + +using Test +using Metatheory.EGraphs + +# Test question: are function calls in our theory both idempotent and correctly typing expressions? + +# Instantiate a new Roe with two variables of type Var{Scalar} +roe = Roe() +a = fresh!(roe, Scalar(), :a) +b = fresh!(roe, Scalar(), :b) + +# Write the same expresison twice but with different variable bindings. We expect that each `+` dispatches its Var{S} method defined in Roe/RoeFunctions.jl, which adds a new call to the egraph. +x = a + b +y = a + b + +# We expect that `+` is idempotent; addcall! checks if the + call is already present in the egraph with the two ids for `a` and `b`. +@test x == y + +# We also check that the type of (a+b) is a Scalar. +@test roe.graph[(a+b).id].data == Scalar() + +# Test question: + +# Now we define two primal forms. +ω = fresh!(roe, PrimalForm(1), :ω) +η = fresh!(roe, PrimalForm(0), :η) + +# Is the wedge product of a 0-form and 1-form a 1-form? +@test ω ∧ η isa DEC.Var{PrimalForm(1)} + +# Is the addcall! function idempotent? +@test ω ∧ η == ω ∧ η + +@test_throws SortError x ≐ ω + +# Assert that ω is the same as the expression ω∧η +ω ≐ (ω ∧ η) + +# Test question: can we extract a term from the e-graph? + +# Assert to the egraph that ∂ₜ(a) is 3*a + 5 +∂ₜ(a) ≐ 3 * a + 5 + +EGraphs.extract!(∂ₜ(a), DEC.derivative_cost([DEC.extract!(a)])) + +# Test question: given a model with a partial derivative defined by two expressions with the same astsize, which expression is extracted? + +function transitivity(roe) + w = fresh!(roe, Scalar(), :w) + ∂ₜ(w) ≐ 1 * w + ∂ₜ(w) ≐ 2 * w + w +end +_w = transitivity(roe) +# picks whichever expression it happens to visit first +EGraphs.extract!((∂ₜ(_w)), DEC.derivative_cost([DEC.extract!(_w)])) + + +end diff --git a/DEC/tests/SSAExtract.jl b/DEC/tests/SSAExtract.jl index 5a125761..0bd1d87c 100644 --- a/DEC/tests/SSAExtract.jl +++ b/DEC/tests/SSAExtract.jl @@ -1,13 +1,61 @@ +# TODO under construction module TestSSAExtract +# AlgebraicJulia dependencies +using DEC: AbstractSort +import DEC.ThDEC + +# other dependencies using Test +using LinearAlgebra using Metatheory -using DEC -pode = Decapode() +# Test question: SSA + +function lotka_volterra(roe) + α = fresh!(roe, Scalar(), :α) + β = fresh!(roe, Scalar(), :β) + γ = fresh!(roe, Scalar(), :γ) + + w = fresh!(roe, Scalar(), :w) + s = fresh!(roe, Scalar(), :s) + + ∂ₜ(s) ≐ α * s - β * w * s + ∂ₜ(w) ≐ - γ * w - β * w * s + + ([w, s], [α, β, γ]) +end + +# a model ThRing => GL(ℝ) is necessary here +ops = Dict{TA, Any}( + TA(-, AbstractSort[Scalar()]) => -I, + TA(-, AbstractSort[Scalar(), Scalar()]) => I, + TA(*, AbstractSort[Scalar(), Scalar()]) => I,) + +(ssa, derivative_vars) = vfield(lotka_volterra, ops) + +basicprinted(x; color=false) = sprint(show, x; context=(:color=>color)) -a = fresh!(pode, Scalar(), :a) -b = fresh!(pode, Scalar(), :b) +@test basicprinted(ssa) == """ +SSA: + %1 = γ#2 + %2 = -(%1::Scalar(),) + %3 = w#3 + %4 = *(%2::Scalar(), %3::Scalar()) + %5 = β#1 + %6 = *(%5::Scalar(), %3::Scalar()) + %7 = s#4 + %8 = *(%6::Scalar(), %7::Scalar()) + %9 = -(%4::Scalar(), %8::Scalar()) + %10 = α#0 + %11 = *(%10::Scalar(), %7::Scalar()) + %12 = -(%11::Scalar(), %8::Scalar()) +""" + +roe = Roe() + +a = fresh!(roe, Scalar(), :a) +b = fresh!(roe, Scalar(), :b) ssa = SSAExtract.SSA() @@ -15,8 +63,8 @@ function term_select(g::EGraph, id::Id) g[id].nodes[1] end -extract_ssa!(pode.graph, ssa, (a + b).id, term_select) +extract_ssa!(roe.graph, ssa, (a + b).id, term_select) ssa -end \ No newline at end of file +end diff --git a/DEC/tests/Signature.jl b/DEC/tests/Signature.jl new file mode 100644 index 00000000..015b4860 --- /dev/null +++ b/DEC/tests/Signature.jl @@ -0,0 +1,25 @@ +module TestSignature + +using DEC + +using Test + +# ## SIGNATURE TESTS + +# Addition +@test Scalar() + Scalar() == Scalar() +@test Scalar() + PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) + Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(1) + PrimalForm(2) + +# Scalar Multiplication +@test Scalar() * Scalar() == Scalar() +@test Scalar() * PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) * Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(2) * PrimalForm(1) + +# Exterior Product +@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) +@test_throws SortError PrimalForm(1) ∧ Scalar() + +end diff --git a/DEC/tests/runtests.jl b/DEC/tests/runtests.jl new file mode 100644 index 00000000..7b00e4da --- /dev/null +++ b/DEC/tests/runtests.jl @@ -0,0 +1,17 @@ +using Test + +@testset "Signature" begin + include("Signature.jl") +end + +@testset "SSA Extraction" begin + include("SSAExtract.jl") +end + +@testset "Roe Utilities" begin + include("Roe.jl") +end + +@testset "ThDEC" begin + include("DEC.jl") +end From f15f27bbe8d6169ec5372ee2c35b2acb3fc978dd Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 23 Jul 2024 16:15:09 -0400 Subject: [PATCH 2/5] submitting for cleanup --- DEC/docs/make.jl | 57 +++++++++++++++++++++++++++++++ DEC/src/models/ThDEC/EGraph.jl | 4 +++ DEC/src/models/ThDEC/Signature.jl | 20 ++++++++++- DEC/src/models/ThDEC/ThDEC.jl | 31 +++++++++++++---- DEC/src/roe/SSAExtract.jl | 4 +-- DEC/tests/DEC.jl | 28 ++++++++++++++- 6 files changed, 133 insertions(+), 11 deletions(-) create mode 100644 DEC/docs/make.jl diff --git a/DEC/docs/make.jl b/DEC/docs/make.jl new file mode 100644 index 00000000..6cd2cd1e --- /dev/null +++ b/DEC/docs/make.jl @@ -0,0 +1,57 @@ +using Documenter +using Literate +using Distributed + +using DEC + +using CairoMakie + +# Set Literate.jl config if not being compiled on recognized service. +# config = Dict{String,String}() +# if !(haskey(ENV, "GITHUB_ACTIONS") || haskey(ENV, "GITLAB_CI")) +# config["nbviewer_root_url"] = "https://nbviewer.jupyter.org/github/AlgebraicJulia/DEC.jl/blob/gh-pages/dev" +# config["repo_root_url"] = "https://github.com/AlgebraicJulia/Decapodes.jl/blob/main/docs" +# end + +const literate_dir = joinpath(@__DIR__, "..", "examples") +const generated_dir = joinpath(@__DIR__, "src", "examples") + +@info "Building literate files" +for (root, dirs, files) in walkdir(literate_dir) + out_dir = joinpath(generated_dir, relpath(root, literate_dir)) + pmap(files) do file + f,l = splitext(file) + if l == ".jl" && !startswith(f, "_") + Literate.markdown(joinpath(root, file), out_dir; + config=config, documenter=true, credit=false) + Literate.notebook(joinpath(root, file), out_dir; + execute=true, documenter=true, credit=false) + end + end +end +@info "Completed literate" + +pages = Any[] +push!(pages, "DEC.jl" => "index.md") +push!(pages, "Library Reference" => "api.md") + +@info "Building Documenter.jl docs" +makedocs( + modules = [Decapodes], + format = Documenter.HTML( + assets = ["assets/analytics.js"], + ), + remotes = nothing, + sitename = "DEC.jl", + doctest = false, + checkdocs = :none, + pages = pages) + + +@info "Deploying docs" +deploydocs( + target = "build", + repo = "github.com/AlgebraicJulia/DEC.jl.git", + branch = "gh-pages", + devbranch = "main" +) diff --git a/DEC/src/models/ThDEC/EGraph.jl b/DEC/src/models/ThDEC/EGraph.jl index df8757f9..837b52bd 100644 --- a/DEC/src/models/ThDEC/EGraph.jl +++ b/DEC/src/models/ThDEC/EGraph.jl @@ -80,4 +80,8 @@ export ★ Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) export Δ +# ι +# ♯ +# ♭ + # end diff --git a/DEC/src/models/ThDEC/Signature.jl b/DEC/src/models/ThDEC/Signature.jl index f517554b..f0949ee1 100644 --- a/DEC/src/models/ThDEC/Signature.jl +++ b/DEC/src/models/ThDEC/Signature.jl @@ -21,7 +21,9 @@ export PrimalForm DualForm(i::Int) = Form(i, true) export DualForm -Base.show(io::IO, ω::Form) = print(io, ω.isdual ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +function Base.show(io::IO, ω::Form) + print(io, ω.isdual ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +end ## OPERATIONS @@ -96,3 +98,19 @@ function ★(s::Sort) end end +function ι(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i1, isdual1), Form(i2, isdual2)) => + if i1 == 1 && i2 ∈ [1,2] && isdual1 == isdual2 + Form(i2 - 1, isdual2) + else + # TODO fix this error message + throw(SortError("Cannot take the interior product of these forms.")) + end + (Scalar(), _) || (_, Scalar()) => throw(SortError("Cannot take the interior product involving scalars")) + end +end + +function ♯(s::Sort) end + +function ♭(s::Sort) end diff --git a/DEC/src/models/ThDEC/ThDEC.jl b/DEC/src/models/ThDEC/ThDEC.jl index 3c2b48a1..4e1d16be 100644 --- a/DEC/src/models/ThDEC/ThDEC.jl +++ b/DEC/src/models/ThDEC/ThDEC.jl @@ -90,12 +90,29 @@ function vfield(model, operator_lookup::Dict{TA, Any}=Dict{TA, Any}()) toexpr(v::SSAVar) = Symbol("tmp%$(v.idx)") function toexpr(expr::SSAExpr) - if expr.fn isa RootVar - rootvar_lookup[expr.fn] - elseif expr.fn isa Number - expr.fn + @match expr.head begin + ::RootVar => rootvar_lookup[expr.head] + ::Number => expr.head + _ => begin + op = operator_lookup[TA(expr.head, first.(expr.args))] + if op isa Tuple + op = op[1] + end + Expr(:call, *, op, toexpr.(last.(expr.args))...) + end + end + end + + function _toexpr(expr::SSAExpr) + if expr.head isa RootVar + rootvar_lookup[expr.head] + elseif expr.head isa Number + expr.head + elseif expr.head == :* + else - op = operator_lookup[TypedApplication(expr.fn, first.(expr.args))] + @info expr.args + op = operator_lookup[TypedApplication(expr.head, first.(expr.args))] # Decapodes dec_* functions yield a tuple of both in-place and out-of-place function. # We choose the first. if op isa Tuple @@ -106,11 +123,11 @@ function vfield(model, operator_lookup::Dict{TA, Any}=Dict{TA, Any}()) end ssalines = map(enumerate(ssa.statements)) do (i, expr) - :($(toexpr(SSAVar(i))) = $(toexpr(expr))) + :($(toexpr(SSAVar(i))) = $(toexpr(expr))) end set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) - :($(du) .= $(toexpr(v))) + :($(du) .= $(toexpr(v))) end eval( diff --git a/DEC/src/roe/SSAExtract.jl b/DEC/src/roe/SSAExtract.jl index bab8ea43..f7eeadd2 100644 --- a/DEC/src/roe/SSAExtract.jl +++ b/DEC/src/roe/SSAExtract.jl @@ -42,13 +42,13 @@ SSAExpr(+, [(Scalar(), SSAVar(1)), (Scalar(), SSAVar(2))]) ``` """ @struct_hash_equal struct SSAExpr - fn::Any + head::Any args::Vector{Tuple{AbstractSort, SSAVar}} end export SSAExpr function Base.show(io::IO, e::SSAExpr) - print(io, e.fn) + print(io, e.head) if length(e.args) > 0 print(io, Expr(:tuple, (Expr(:(::), v, sort) for (sort, v) in e.args)...)) end diff --git a/DEC/tests/DEC.jl b/DEC/tests/DEC.jl index 5649ce69..aeea45bf 100644 --- a/DEC/tests/DEC.jl +++ b/DEC/tests/DEC.jl @@ -2,11 +2,12 @@ module TestDEC # AlgebraicJulia dependencies using DEC -import DEC.ThDEC: Δ +import DEC.ThDEC: Δ # conflicts with CombinatorialSpaces # other dependencies using Test using Metatheory.EGraphs +using ComponentArrays using CombinatorialSpaces using GeometryBasics using OrdinaryDiffEq @@ -63,4 +64,29 @@ function save_dynamics(save_file_name) end end +## 1-D HEAT EQUATION WITH DIFFUSIVITY + +function new_heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + + ∂ₜ(u) ≐ k * Δ(u) + + ([u], [k]) end + +# we can reuse the mesh and operator lookup +_vf = vfield(new_heat_equation, operator_lookup) + +# we can reuse the initial condition U + +# +constants_and_parameters = ComponentArray(k=0.5,); + +t0 = 50 + +@info("Precompiling solver") +prob = ODEProblem(_vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + + From 94ef1ce3d6abf617ad6081d4e4e8bdb375c6b2af Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 25 Jul 2024 10:23:36 -0400 Subject: [PATCH 3/5] completing code review --- DEC/Project.toml | 1 + DEC/docs/literate/heatequation.jl | 74 +++++++ DEC/docs/literate/tutorial.jl | 52 +++++ DEC/docs/make.jl | 2 +- DEC/src/DEC.jl | 95 ++++++++- DEC/src/SSAs.jl | 141 +++++++++++++ DEC/src/models/ThDEC/EGraph.jl | 87 -------- DEC/src/models/ThDEC/Luke.jl | 72 ------- DEC/src/models/ThDEC/Signature.jl | 116 ----------- DEC/src/models/ThDEC/ThDEC.jl | 145 ------------- DEC/src/roe.jl | 257 ++++++++++++++++++++++++ DEC/src/roe/RoeUtility.jl | 222 -------------------- DEC/src/roe/SSAExtract.jl | 137 ------------- DEC/src/roe/module.jl | 11 - DEC/src/theories/ThDEC/ThDEC.jl | 15 ++ DEC/src/theories/ThDEC/roe_overloads.jl | 56 ++++++ DEC/src/theories/ThDEC/semantics.jl | 73 +++++++ DEC/src/theories/ThDEC/signature.jl | 155 ++++++++++++++ DEC/src/{models => theories}/module.jl | 2 +- DEC/src/util/Plotting.jl | 20 ++ DEC/src/util/module.jl | 2 + DEC/src/vfield.jl | 156 ++++++++++++++ DEC/{tests => test}/SSAExtract.jl | 2 +- DEC/test/ThDEC/ThDEC.jl | 19 ++ DEC/test/ThDEC/model.jl | 73 +++++++ DEC/{tests/Roe.jl => test/ThDEC/roe.jl} | 0 DEC/test/ThDEC/semantics.jl | 0 DEC/test/ThDEC/signature.jl | 40 ++++ DEC/{tests => test}/runtests.jl | 2 +- DEC/tests/DEC.jl | 92 --------- DEC/tests/Signature.jl | 25 --- 31 files changed, 1228 insertions(+), 916 deletions(-) create mode 100644 DEC/docs/literate/heatequation.jl create mode 100644 DEC/docs/literate/tutorial.jl create mode 100644 DEC/src/SSAs.jl delete mode 100644 DEC/src/models/ThDEC/EGraph.jl delete mode 100644 DEC/src/models/ThDEC/Luke.jl delete mode 100644 DEC/src/models/ThDEC/Signature.jl delete mode 100644 DEC/src/models/ThDEC/ThDEC.jl create mode 100644 DEC/src/roe.jl delete mode 100644 DEC/src/roe/RoeUtility.jl delete mode 100644 DEC/src/roe/SSAExtract.jl delete mode 100644 DEC/src/roe/module.jl create mode 100644 DEC/src/theories/ThDEC/ThDEC.jl create mode 100644 DEC/src/theories/ThDEC/roe_overloads.jl create mode 100644 DEC/src/theories/ThDEC/semantics.jl create mode 100644 DEC/src/theories/ThDEC/signature.jl rename DEC/src/{models => theories}/module.jl (85%) create mode 100644 DEC/src/util/Plotting.jl create mode 100644 DEC/src/vfield.jl rename DEC/{tests => test}/SSAExtract.jl (96%) create mode 100644 DEC/test/ThDEC/ThDEC.jl create mode 100644 DEC/test/ThDEC/model.jl rename DEC/{tests/Roe.jl => test/ThDEC/roe.jl} (100%) create mode 100644 DEC/test/ThDEC/semantics.jl create mode 100644 DEC/test/ThDEC/signature.jl rename DEC/{tests => test}/runtests.jl (88%) delete mode 100644 DEC/tests/DEC.jl delete mode 100644 DEC/tests/Signature.jl diff --git a/DEC/Project.toml b/DEC/Project.toml index 9ce6c275..8c492bb8 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -15,6 +15,7 @@ Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" [compat] diff --git a/DEC/docs/literate/heatequation.jl b/DEC/docs/literate/heatequation.jl new file mode 100644 index 00000000..03b1789d --- /dev/null +++ b/DEC/docs/literate/heatequation.jl @@ -0,0 +1,74 @@ +# Load AlgebraicJulia dependencies +using DEC +import DEC.ThDEC: Δ # conflicts with CombinatorialSpaces + +# load other dependencies +using ComponentArrays +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} +using CairoMakie + +## Here we define the 1D heat equation model with one state variable and no parameters. That is, given an e-graph "roe," we define `u` to be a primal 0-form. The root variable carries a reference to the e-graph which it resides in. We then assert that the time derivative of the state is just its Laplacian. We return the state variable. +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# Since this is a model in the DEC, we need to initialize the primal and dual meshes. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# Now that we have a dual mesh, we can associate operators in our theory with precomputed matrices from Decapodes.jl. +op_lookup = ThDEC.precompute_matrices(d_rect, DiagonalHodge()) + +# Now we produce a "vector field" function which, given a model and operators in a theory, returns a function to be passed to the ODESolver. In stages, this function +# +# 1) extracts the Root Variables (state or parameter term) and runs the extractor along the e-graph, +# 2) extracts the derivative terms from the model into an SSA +# 3) yields a function accepting derivative terms, state terms, and parameter terms, whose body is both the lines, and derivatives. +vf = vfield(heat_equation, op_lookup) + +# Let's initialize the +U = first.(d_rect[:point]); + +# TODO component arrays +constants_and_parameters = () + +# We will run this for 500 timesteps. +t0 = 500.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +## 1-D HEAT EQUATION WITH DIFFUSIVITY + +function heat_equation_with_constants(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + ℓ = fresh!(roe, Scalar(), :ℓ) + + ∂ₜ(u) ≐ k * Δ(u) + + ([u], [k]) +end + +# we can reuse the mesh and operator lookup +vf = vfield(heat_equation_with_constants, operator_lookup) + +# we can reuse the initial condition U but are specifying diffusivity constants. +constants_and_parameters = ComponentArray(k=0.25,); +t0 = 500 + +@info("Precompiling solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + + diff --git a/DEC/docs/literate/tutorial.jl b/DEC/docs/literate/tutorial.jl new file mode 100644 index 00000000..31d1885c --- /dev/null +++ b/DEC/docs/literate/tutorial.jl @@ -0,0 +1,52 @@ +# This tutorial is a slower-paced introduction into the design. Here, we will construct a simple exponential model. +using DEC +using Test +using Metatheory.EGraphs +using ComponentArrays +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} + +using CairoMakie + +# We define our model of exponential growth. This model is a function which accepts a Roe and returns a tuple of State and Parameter variables. Let's break it down: +# +# 1. Function adds root variables (::RootVar) to the Roe. The root variables have no child nodes. +# 2. Our model makes claims about what terms equal one another. The "≐" operator is an infix of "equate!" which claims unites the ids of the left and right VecExprs. +# 3. The State and Parameter variables are returned. Each variable points to the same parent Roe. +# +# +# Each variable points to the same Roe. +function exp_growth(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + + ∂ₜ(u) ≐ k * u + + ([u], [k]) +end + +# We now need to initialize the primal and dual meshes we'll need to compute with. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# For the theory of the DEC, we will need to associate each operator to the precomputed matrix specific to our dual mesh. +operator_lookup = ThDEC.create_dynamic_model(d_rect, DiagonalHodge()) + +# We now need to convert our model to an ODEProblem. In our case, ``vfield`` produces +vf = vfield(exp_growth, operator_lookup) + +U = first.(d_rect[:point]); + +constants_and_parameters = ComponentArray(k=-0.5,) + +t0 = 50.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +save_dynamics(soln, "decay.gif") + diff --git a/DEC/docs/make.jl b/DEC/docs/make.jl index 6cd2cd1e..7fa424c3 100644 --- a/DEC/docs/make.jl +++ b/DEC/docs/make.jl @@ -37,7 +37,7 @@ push!(pages, "Library Reference" => "api.md") @info "Building Documenter.jl docs" makedocs( - modules = [Decapodes], + modules = [DEC], format = Documenter.HTML( assets = ["assets/analytics.js"], ), diff --git a/DEC/src/DEC.jl b/DEC/src/DEC.jl index f913060c..f9a50f9e 100644 --- a/DEC/src/DEC.jl +++ b/DEC/src/DEC.jl @@ -6,17 +6,102 @@ using MLStyle using Reexport using StructEquality import Metatheory -using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH +using Metatheory: EGraph, EGraphs, Id, astsize +using Metatheory: VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH import Metatheory: extract! import Base: +, -, * include("util/module.jl") # Pretty-printing -include("roe/module.jl") # Checking signature for DEC operations -include("models/module.jl") # manipulating SSAs +include("roe.jl") # Checking signature for DEC operations +include("SSAs.jl") # manipulating SSAs +include("vfield.jl") # producing a vector field function + +# currently this only holds the DEC +include("theories/module.jl") @reexport using .Util -@reexport using .SSAExtract -@reexport using .Models +@reexport using .SSAs +@reexport using .Theories + +# function vfield(model, operator_lookup::Dict{TA, Any}) +# roe = Roe(DEC.ThDEC.Sort) + +# (state_vars, param_vars) = model(roe) +# length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") +# state_rootvars = map(state_vars) do x +# rv = extract!(x) +# rv isa RootVar ? rv : error("all state variables must be RootVars") +# end +# param_rootvars = map(param_vars) do p +# rv = extract!(p) +# rv isa RootVar ? rv : error("all param variables must be RootVars") +# end + +# u = :u +# p = :p +# du = :du + +# rootvar_lookup = +# Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}}( +# [ +# [rv => (:($(u)), false) for rv in state_rootvars]; +# [rv => (:($(p)), true) for rv in param_rootvars] +# ] +# ) + +# cost = derivative_cost(Set([state_rootvars; param_rootvars])) + +# extractor = EGraphs.Extractor(roe.graph, cost, Float64) + +# function term_select(id) +# EGraphs.find_best_node(extractor, id) +# end + +# ssa = SSA() + +# # TODO overload extract! to index by graph +# derivative_vars = map(state_vars) do v +# extract!(roe.graph, ssa, (∂ₜ(v)).id, term_select) +# end + +# toexpr(v::DEC.SSAs.Var) = Symbol("tmp%$(v.idx)") + +# function toexpr(expr::Term) +# @match expr.head begin +# ::RootVar => @match rootvar_lookup[expr.head] begin +# (v, false) => v +# # evaluates in DEC.k, and this gets the index +# (v, true) => Expr(:ref, v, expr.head.name) +# end +# ::Number => expr.head +# _ => begin +# op = get(operator_lookup, TA(expr.head, first.(expr.args))) +# # Decapode operators return a tuple of functions. We choose the first of these. +# if op isa Tuple +# op = op[1] +# end +# Expr(:call, *, op, toexpr.(last.(expr.args))...) +# end +# end +# end + +# ssalines = map(enumerate(ssa.statements)) do (i, expr) +# :($(toexpr(SSAs.Var(i))) = $(toexpr(expr))) +# end + +# set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) +# :($(du) .= $(toexpr(v))) +# end + +# # yield function +# eval(quote +# f(du, u, p, _) = begin +# $(ssalines...) +# $(set_derivative_stmts...) +# end +# end) +# end +# export vfield end diff --git a/DEC/src/SSAs.jl b/DEC/src/SSAs.jl new file mode 100644 index 00000000..5d97f643 --- /dev/null +++ b/DEC/src/SSAs.jl @@ -0,0 +1,141 @@ +module SSAs + +using ..DEC: AbstractSort, TypedApplication, TA, Roe, RootVar + +# other dependencies +using MLStyle +using StructEquality +using Metatheory: VecExpr +using Metatheory.EGraphs +import Metatheory: extract! + +""" Var + +A wrapper for the index of a Var +""" +@struct_hash_equal struct Var + idx::Int +end +export Var + +function Base.show(io::IO, v::Var) + print(io, "%", v.idx) +end + +""" Term + +A wrapper for a function (::Any) and its args (::Vector{Tuple{Sort, Var}}). + +Example: the equation +``` + a = 1 + b +``` +may have an SSA dictionary +``` + %1 => a + %2 => +(%1, %3) + %3 => b +``` +and so `+` would have +``` +Term(+, [(Scalar(), Var(1)), (Scalar(), Var(2))]) +``` +""" +@struct_hash_equal struct Term + head::Any + args::Vector{Tuple{AbstractSort, Var}} +end +export Term + +function Base.show(io::IO, e::Term) + print(io, e.head) + if length(e.args) > 0 + print(io, Expr(:tuple, (Expr(:(::), v, sort) for (sort, v) in e.args)...)) + end +end + +""" + +Struct defining Static Single-Assignment information for a given roe. + +Advantages of SSA form: + +1. We can preallocate each matrix +2. We can run a register-allocation algorithm to minimize the number of matrices that we have to preallocate +""" +struct SSA + assignment_lookup::Dict{Id, Var} + statements::Vector{Term} + function SSA() + new(Dict{Id, Var}(), Term[]) + end +end +export SSA + +# accessors +statements(ssa::SSA) = ssa.statements +export statements + +# show methods + +function Base.show(io::IO, ssa::SSA) + println(io, "SSA: ") + for (i, expr) in enumerate(statements(ssa)) + println(io, " ", Var(i), " = ", expr) + end +end + +""" add_stmt!(ssa::SSA, id::Id, expr::Term)::Var + +Low-level function which, given an SSA, adds a Term onto the assignment_lookup. Users should use `extract!` instead. + +""" +function add_stmt!(ssa::SSA, id::Id, expr::Term) + push!(ssa.statements, expr) + v = Var(length(ssa.statements)) + ssa.assignment_lookup[id] = v + v +end + +Base.contains(ssa::SSA, id::Id) = haskey(ssa.assignment_lookup, id) +export contains + +Base.getindex(ssa::SSA, id::Id) = ssa.assignment_lookup[id] +export getindex + +""" + extract!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::Var + +This function adds (recursively) the necessary lines to the SSA in order to +compute a value for `id`, and then returns the Var that the value for `id` +will be assigned to. + +The closure parameters control the behavior of this function. + + term_select(id::Id)::VecExpr + +This closure selects, given an id in an EGraph, the term that we want to use in +order to compute a value for that id + +""" +function extract!(g::EGraph, ssa::SSA, id::Id, term_select) + if contains(ssa, id) + return getindex(ssa, id) + end + term = term_select(id) + args = map(EGraphs.v_children(term)) do arg + (g[arg].data, extract!(g, ssa, arg, term_select)) + end + add_stmt!(ssa, id, Term(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) +end +export extract! + +function extract!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) + extract!(g, ssa, id, term_select) +end + +function extract!(roe::Roe{S}, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) where S + extract!(roe, ssa, id, term_select) +end + +end diff --git a/DEC/src/models/ThDEC/EGraph.jl b/DEC/src/models/ThDEC/EGraph.jl deleted file mode 100644 index 837b52bd..00000000 --- a/DEC/src/models/ThDEC/EGraph.jl +++ /dev/null @@ -1,87 +0,0 @@ -using ...DEC: Var, addcall! - -import Base: +, -, * - -# These operations create calls on a common egraph. We validate the signature by dispatching the operation on the types using methods we defined in Signature. - -@nospecialize -function +(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot add variables from different graphs") - s = s1 + s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, +, (v1.id, v2.id))) -end -export + - -@nospecialize -+(v::Var, x::Number) = +(v, inject_number!(v.roe, x)) - -@nospecialize -+(x::Number, v::Var) = +(inject_number!(v.roe, x), v) - -@nospecialize -function -(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe == v2.roe || error("Cannot subtract variables from different graphs") - s = s1 - s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, -, (v1.id, v2.id))) -end -export - - -@nospecialize --(v::Var{s}) where {s} = Var{s}(v.roe, addcall!(v.roe.graph, -, (v.id,))) - -@nospecialize --(v::Var, x::Number) = -(v, inject_number!(v.roe, x)) - -@nospecialize --(x::Number, v::Var) = -(inject_number!(v.roe, x), v) - -@nospecialize -function *(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot multiply variables from different graphs") - s = s1 * s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, *, (v1.id, v2.id))) -end -export * - -@nospecialize -*(v::Var, x::Number) = *(v, inject_number!(v.roe, x)) - -@nospecialize -*(x::Number, v::Var) = *(inject_number!(v.roe, x), v) - -@nospecialize -function ∧(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.roe === v2.roe || error("Cannot wedge variables from different graphs") - s = s1 ∧ s2 - Var{s}(v1.roe, addcall!(v1.roe.graph, ∧, (v1.id, v2.id))) -end -export ∧ - -@nospecialize -function ∂ₜ(v::Var{s}) where {s} - Var{s}(v.roe, addcall!(v.roe.graph, ∂ₜ, (v.id,))) -end -export ∂ₜ - -@nospecialize -function d(v::Var{s}) where {s} - s′ = d(s) - Var{s′}(v.roe, addcall!(v.roe.graph, d, (v.id,))) -end -export d - -@nospecialize -function ★(v::Var{s}) where {s} - s′ = ★(s) - Var{s′}(v.roe, addcall!(v.roe.graph, ★, (v.id,))) -end -export ★ - -Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) -export Δ - -# ι -# ♯ -# ♭ - -# end diff --git a/DEC/src/models/ThDEC/Luke.jl b/DEC/src/models/ThDEC/Luke.jl deleted file mode 100644 index 22affa0b..00000000 --- a/DEC/src/models/ThDEC/Luke.jl +++ /dev/null @@ -1,72 +0,0 @@ -import Decapodes -using StructEquality - - -""" precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} - -Given a matrix and a hodge star (DiagonalHodge() or GeometricHodge()), this returns a lookup dictionary between operators (as TypedApplications) and their corresponding matrices. - -""" -function precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} - Dict{TypedApplication, Any}( - # Regular Hodge Stars - TA(★, Sort[PrimalForm(0)]) => Decapodes.dec_mat_hodge(0, sd, hodge), - TA(★, Sort[PrimalForm(1)]) => Decapodes.dec_mat_hodge(1, sd, hodge), - TA(★, Sort[PrimalForm(2)]) => Decapodes.dec_mat_hodge(2, sd, hodge), - - # Inverse Hodge Stars - TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(1, sd, hodge), - # why is this 1??? - TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), - # Special since Geo is a solver - TA(★, Sort[DualForm(2)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), - - # Differentials - TA(d, Sort[PrimalForm(0)]) => Decapodes.dec_mat_differential(0, sd), - TA(d, Sort[PrimalForm(1)]) => Decapodes.dec_mat_differential(1, sd), - - # Dual Differentials - TA(d, Sort[DualForm(0)]) => Decapodes.dec_mat_dual_differential(0, sd), - TA(d, Sort[DualForm(1)]) => Decapodes.dec_mat_dual_differential(1, sd), - - # Wedge Products - TA(∧, Sort[PrimalForm(0), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{0,1}, sd), - TA(∧, Sort[PrimalForm(1), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{1,0}, sd), - TA(∧, Sort[PrimalForm(0), PrimalForm(2)]) => Decapodes.dec_pair_wedge_product(Tuple{0,2}, sd), - TA(∧, Sort[PrimalForm(2), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{2,0}, sd), - TA(∧, Sort[PrimalForm(1), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{1,1}, sd), - - # Primal-Dual Wedge Products - TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{1,1}, sd), - TA(∧, Sort[PrimalForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{0,1}, sd), - TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dp(Tuple{1,1}, sd), - TA(∧, Sort[PrimalForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dp(Tuple{1,0}, sd), - - # Dual-Dual Wedge Products - # TA(∧, Sort[DualForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{1,1}, sd), - TA(∧, Sort[DualForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dd(Tuple{1,0}, sd), - TA(∧, Sort[DualForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{0,1}, sd), - - # # Dual-Dual Interior Products - # :ι₁₁ => interior_product_dd(Tuple{1,1}, sd) - # :ι₁₂ => interior_product_dd(Tuple{1,2}, sd) - - # # Dual-Dual Lie Derivatives - # :ℒ₁ => ℒ_dd(Tuple{1,1}, sd) - - # # Dual Laplacians - # :Δᵈ₀ => Δᵈ(Val{0},sd) - # :Δᵈ₁ => Δᵈ(Val{1},sd) - - # # Musical Isomorphisms - # :♯ => Decapodes.dec_♯_p(sd) - # :♯ᵈ => Decapodes.dec_♯_d(sd) - - # :♭ => Decapodes.dec_♭(sd) - - # # Averaging Operator - # :avg₀₁ => Decapodes.dec_avg₀₁(sd) - - ) -end - diff --git a/DEC/src/models/ThDEC/Signature.jl b/DEC/src/models/ThDEC/Signature.jl deleted file mode 100644 index f0949ee1..00000000 --- a/DEC/src/models/ThDEC/Signature.jl +++ /dev/null @@ -1,116 +0,0 @@ -using ...DEC: AbstractSort, SortError - -using MLStyle - -import Base: +, -, * - -# Define the sorts in your theory. -# For the DEC, we work with Scalars and Forms, graded objects which can also be primal or dual. -@data Sort <: AbstractSort begin - Scalar() - Form(dim::Int, isdual::Bool) -end -export Scalar, Form - -dim(f::Form) = f.dim -duality(f::Form) = f.isdual ? "dual" : "primal" - -PrimalForm(i::Int) = Form(i, false) -export PrimalForm - -DualForm(i::Int) = Form(i, true) -export DualForm - -function Base.show(io::IO, ω::Form) - print(io, ω.isdual ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") -end - -## OPERATIONS - -@nospecialize -function +(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || - (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(i1, isdual1), Form(i2, isdual2)) => - if (i1 == i2) && (isdual1 == isdual2) - Form(i1, isdual1) - else - throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) - end - end -end - -# Type-checking inverse of addition follows addition --(s1::Sort, s2::Sort) = +(s1, s2) - -# Negation is valid --(s::Sort) = s - -@nospecialize -function *(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || - (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) - end -end - -@nospecialize -function ∧(s1::Sort, s2::Sort) - @match (s1, s2) begin - (_, Scalar()) || (Scalar(), _) => throw(SortError("Cannot take a wedge product with a scalar")) - (Form(i1, isdual1), Form(i2, isdual2)) => - if isdual1 == isdual2 - if i1 + i2 <= 2 - Form(i1 + i2, isdual1) - else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) - end - else - throw(SortError("Cannot wedge two forms of different dualities: attempted to wedge $(duality(s1)) and $(duality(s2))")) - end - end -end - -@nospecialize -∂ₜ(s::Sort) = s - -@nospecialize -function d(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) - Form(i, isdual) => - if i <= 1 - Form(i + 1, isdual) - else - throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) - end - end -end - -function ★(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) - Form(i, isdual) => Form(2 - i, !isdual) - end -end - -function ι(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Form(i1, isdual1), Form(i2, isdual2)) => - if i1 == 1 && i2 ∈ [1,2] && isdual1 == isdual2 - Form(i2 - 1, isdual2) - else - # TODO fix this error message - throw(SortError("Cannot take the interior product of these forms.")) - end - (Scalar(), _) || (_, Scalar()) => throw(SortError("Cannot take the interior product involving scalars")) - end -end - -function ♯(s::Sort) end - -function ♭(s::Sort) end diff --git a/DEC/src/models/ThDEC/ThDEC.jl b/DEC/src/models/ThDEC/ThDEC.jl deleted file mode 100644 index 4e1d16be..00000000 --- a/DEC/src/models/ThDEC/ThDEC.jl +++ /dev/null @@ -1,145 +0,0 @@ -module ThDEC - -using ...DEC: TypedApplication, TA, Roe, RootVar -using ...DEC.SSAExtract - -using Metatheory: VecExpr -using Metatheory.EGraphs - -include("Signature.jl") # verify the signature holds -include("EGraph.jl") # overload DEC operations to act on roe (egraphs) -include("Luke.jl") # represent operations as matrices - -@nospecialize -""" derivative_cost(allowed_roots)::Function - -Returns a function `cost(n::Metatheory.VecExpr, op, costs)` which sets the cost of operations to Inf if they are either ∂ₜ or forbidden RootVars. Otherwise it computes the astsize. - -""" -function derivative_cost(allowed_roots) - function cost(n::VecExpr, op, costs) - if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) - Inf - else - astsize(n, op, costs) - end - end -end -export derivative_cost - - -""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction - -Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. - -Example: given a diffusivity constant a, the heat equation can be written as: -``` - ∂ₜ u = a * Δ(u) -``` -would return (u, a). - -A limitation of this function can be demonstrated here: given the model - ``` - ∂ₜ a = a + b - ∂ₜ b = a + b - ``` - we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, - the extractor would bypass it completely. -""" -function vfield(model, operator_lookup::Dict{TA, Any}=Dict{TA, Any}()) - roe = Roe() - (state_vars, param_vars) = model(roe) - length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") - state_rootvars = map(state_vars) do x - rv = extract!(x) - rv isa RootVar || error("all state variables must be RootVars") - rv - end - param_rootvars = map(param_vars) do p - rv = extract!(p) - rv isa RootVar || error("all param variables must be RootVars") - rv - end - - u = :u - p = :p - du = :du - - rootvar_lookup = - Dict{RootVar, Union{Expr, Symbol}}( - [ - [rv => :($(u)) for (i, rv) in enumerate(state_rootvars)]; - [rv => :($(p)) for (i, rv) in enumerate(param_rootvars)] - ] - ) - - cost = derivative_cost(Set([state_rootvars; param_rootvars])) - - extractor = EGraphs.Extractor(roe.graph, cost, Float64) - - function term_select(id) - EGraphs.find_best_node(extractor, id) - end - - ssa = SSA() - - derivative_vars = map(state_vars) do v - extract_ssa!(roe.graph, ssa, (∂ₜ(v)).id, term_select) - end - - toexpr(v::SSAVar) = Symbol("tmp%$(v.idx)") - - function toexpr(expr::SSAExpr) - @match expr.head begin - ::RootVar => rootvar_lookup[expr.head] - ::Number => expr.head - _ => begin - op = operator_lookup[TA(expr.head, first.(expr.args))] - if op isa Tuple - op = op[1] - end - Expr(:call, *, op, toexpr.(last.(expr.args))...) - end - end - end - - function _toexpr(expr::SSAExpr) - if expr.head isa RootVar - rootvar_lookup[expr.head] - elseif expr.head isa Number - expr.head - elseif expr.head == :* - - else - @info expr.args - op = operator_lookup[TypedApplication(expr.head, first.(expr.args))] - # Decapodes dec_* functions yield a tuple of both in-place and out-of-place function. - # We choose the first. - if op isa Tuple - op = op[1] - end - Expr(:call, *, op, toexpr.(last.(expr.args))...) - end - end - - ssalines = map(enumerate(ssa.statements)) do (i, expr) - :($(toexpr(SSAVar(i))) = $(toexpr(expr))) - end - - set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) - :($(du) .= $(toexpr(v))) - end - - eval( - quote - ($du, $u, $p, _) -> begin - $(ssalines...) - $(set_derivative_stmts...) - end - end - ) -end -export vfield - - -end diff --git a/DEC/src/roe.jl b/DEC/src/roe.jl new file mode 100644 index 00000000..671207c7 --- /dev/null +++ b/DEC/src/roe.jl @@ -0,0 +1,257 @@ +using ..Util.HashColor + +using StructEquality +import Metatheory +using Metatheory: EGraph, EGraphs, Id, VecExpr, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH, astsize +using MLStyle +using Reexport + +""" +Sorts in each theory are subtypes of this abstract type. +""" +abstract type AbstractSort end +export AbstractSort + +""" TypedApplication + +Struct containing a Function and the vector of Sorts it requires. +""" +@struct_hash_equal struct TypedApplication{Sort<:AbstractSort} + head::Function + sorts::Vector{Sort} + + function TypedApplication(head::Function, sorts::Vector{Sort}) where Sort + new{Sort}(head, sorts) + end +end +export TypedApplication + +const TA = TypedApplication +export TA + +Base.show(io::IO, ta::TA) = print(io, Expr(:call, nameof(ta.head), ta.sorts...)) + +struct SortError <: Exception + message::String +end +export SortError + +Base.get(lookup::Dict{TA, Any}, key::TA) = lookup[key] +export get + +""" RootVar + +A childless node on an e-graph. + +""" +@struct_hash_equal struct RootVar{Sort<:AbstractSort} + name::Symbol + idx::Int + sort::Sort + + function RootVar(name::Symbol, idx::Int, sort::Sort) where Sort + new{Sort}(name, idx, sort) + end +end +export RootVar + +""" Roe + +Struct for storing an EGraph and its variables. + +Roe is the name for lobster eggs. "Egg" is the name of a Rust implementation of e-graphs, by which Metatheory.jl is inspired by. Lobsters are part of the family Decapodes, which is also the name of the AlgebraicJulia package which motivated this package. Hence, Roe. +""" +struct Roe{Sort<:AbstractSort} + variables::Vector{RootVar} + graph::EGraph{Expr, Sort} + + function Roe(Sort::DataType) + new{Sort}(RootVar[], EGraph{Expr, Sort}()) + end +end +export Roe + +# accessors +variables(roe::Roe{S}) where S = roe.variables +graph(roe::Roe{S}) where S = roe.graph +param(roe::Roe{S}) where S = S +export variables, graph, param + +""" + +A struct containing a Roe and the Id of a variable in that EGraph. The type parameter for this struct is the variable it represents. + +""" +struct Var{S} + roe::Roe + id::Id +end + +# accessors +roe(v::Var{S}) where S = v.roe +graph(v::Var{S}) where S = roe(v).graph +id(v::Var{S}) where S = v.id +export roe, graph, id + +# MAKE AND JOIN + +function EGraphs.make(g::EGraph{Expr, Sort}, n::VecExpr) where {Sort<:AbstractSort} + op = EGraphs.get_constant(g, Metatheory.v_head(n)) + @match op begin + ::RootVar => op.sort + ::Number => Scalar() + _ => op((g[arg].data for arg in Metatheory.v_children(n))...) + end +end + +function EGraphs.join(s1::S1, s2::S2) where {S1<:AbstractSort,S2<:AbstractSort} + s1 == s2 ? s1 : throw(JoinError(s1, s2)) +end + +# EXTRACT + +function extract!(v::Var, f=EGraphs.astsize) + extract!(v.roe.graph, f, v.id) +end + +function rootvarcrayon(v::RootVar) + lightnessrange = (50., 100.) + HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) +end + +function Base.show(io::IO, v::RootVar) + if get(io, :color, true) + crayon = rootvarcrayon(v) + print(io, crayon, "$(v.name)") + print(io, inv(crayon)) + else + print(io, "$(v.name)#$(v.idx)") + end +end + +""" fix_functions(e)::Union{Symbol, Expr} + +Used in the show method for Vars. Traverses the AST of an expression, replacing the head of :call expressions to its name, a Symbol. +""" +function fix_functions(e) + @match e begin + s::Symbol => s + Expr(:call, f::Function, args...) => Expr(:call, nameof(f), fix_functions.(args)...) + Expr(head, args...) => Expr(head, fix_functions.(args)...) + _ => e + end +end + +""" getexpr(v::Var)::Union{Symbol, Expr} + +Extracts an expression (::Var) from its Roe. + +""" +function getexpr(v::Var) + e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) + fix_functions(e) +end +export getexpr + +function Base.show(io::IO, v::Var) + print(io, getexpr(v)) +end + +""" fresh!(roe::Roe, sort::AbstractSort, name::Symbol)::Var{sort} + +Creates a new variable in a Roe. Specifically, it appends a new RootVar with a given a sort and name to the Roe, adds that RootVar to the e-graph, and returns a Var wrapper around the new e-graph Id, with type parameter given by the sort. + +Example: +``` +fresh!(roe, Form(0), :Temp) +``` +""" +function fresh!(roe::Roe, sort::Sort, name::Symbol) where {Sort<:AbstractSort} + v = RootVar(name, length(roe.variables), sort) + push!(roe.variables, v) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) + id = EGraphs.add!(roe.graph, n, false) + Var{sort}(roe, id) +end +export fresh! + + +@nospecialize +""" inject_number!(roe::Roe, x::Number)::Var{Scalar()} + +Adds a number to the Roe as a EGraph constant. + +""" +function inject_number!(roe::Roe, x::Number) + x = Float64(x) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) + Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) +end +export inject_number! + +@nospecialize +""" addcall!(g::EGraph, head, args):: + +Adds a call to an EGraph. + +""" +function addcall!(g::EGraph, head, args) + ar = length(args) + n = Metatheory.v_new(ar) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) + Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) + for i in Metatheory.v_children_range(n) + @inbounds n[i] = args[i - VECEXPR_META_LENGTH] + end + EGraphs.add!(g, n, false) +end +export addcall! + +""" equate!(v1::Var{s1}, v2::Var{s2})::EGraph + +Asserts that two variables of the same e-graph are the same. This is done by returning the union of the variable ids with the e-graph. +""" +function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + (s1 == s2) || throw(JoinError(s1, s2)) + v1.roe === v2.roe || throw(EquateError()) + union!(v1.roe.graph, v1.id, v2.id) +end +export equate! + +""" +Infix synonym for `equate!` +""" +≐(v1::Var, v2::Var) = equate!(v1, v2) +export ≐ + +@nospecialize +""" derivative_cost(allowed_roots)::Function + +Returns a function `cost(n::Metatheory.VecExpr, op, costs)` which sets the cost of operations to Inf if they are either ∂ₜ or forbidden RootVars. Otherwise it computes the astsize. + +""" +function derivative_cost(allowed_roots) + function cost(n::VecExpr, op, costs) + if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) + Inf + else + astsize(n, op, costs) + end + end +end +export derivative_cost + +# EXCEPTIONS + +struct JoinError <: Exception; s1::AbstractSort; s2::AbstractSort end +Base.showerror(io::IO, e::JoinError) = print(io, "Cannot equate two nodes with different sorts: attempted to equate $(e.s1) with $(e.s2)") + +struct EquateError <: Exception end +Base.showerror(io::IO, e::EquateError) = print(io, "Cannot equate variables from different graphs") + diff --git a/DEC/src/roe/RoeUtility.jl b/DEC/src/roe/RoeUtility.jl deleted file mode 100644 index 81e0b7ce..00000000 --- a/DEC/src/roe/RoeUtility.jl +++ /dev/null @@ -1,222 +0,0 @@ -# """ -# Defines Roe, a struct which acts as a wrapper for e-graph typed in the Sorts of a given theory, as well as functions for manipulating it. -# """ -# module RoeUtility - -# using ..SSAExtract: SSA, SSAVar, SSAExpr, extract_ssa! -using ..Util.HashColor - -using StructEquality -import Metatheory -using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH -using MLStyle -using Reexport - -""" -Sorts in each theory are subtypes of this abstract type. -""" -abstract type AbstractSort end -export AbstractSort - -""" TypedApplication - -Struct containing a Function and the vector of Sorts it requires. -""" -@struct_hash_equal struct TypedApplication - head::Function - sorts::Vector{AbstractSort} -end -export TypedApplication - -const TA = TypedApplication -export TA - -Base.show(io::IO, ta::TA) = print(io, Expr(:call, nameof(ta.head), ta.sorts...)) - -struct SortError <: Exception - message::String -end -export SortError - -""" RootVar - -A childless node on an e-graph. - -""" -@struct_hash_equal struct RootVar - name::Symbol - idx::Int - sort::AbstractSort -end -export RootVar - -""" Roe - -Struct for storing an EGraph and its variables. - -Roe is the name for lobster eggs. "Egg" is the name of a Rust implementation of e-graphs, by which Metatheory.jl is inspired by. Lobsters are part of the family Decapodes, which is also the name of the AlgebraicJulia package which motivated this package. Hence, Roe. -""" -struct Roe - variables::Vector{RootVar} - graph::EGraph{Expr, AbstractSort} - function Roe() - new(RootVar[], EGraph{Expr, AbstractSort}()) - end -end -export Roe - -""" - -A struct containing a Roe and the Id of a variable in that EGraph. The type parameter for this struct is the variable it represents. - -""" -struct Var{S} - roe::Roe - id::Id -end - -function EGraphs.make(g::EGraph{Expr, AbstractSort}, n::Metatheory.VecExpr) - op = EGraphs.get_constant(g,Metatheory.v_head(n)) - if op isa RootVar - op.sort - elseif op isa Number - Scalar() - else - op((g[arg].data for arg in Metatheory.v_children(n))...) - end -end -export make - -function EGraphs.join(s1::AbstractSort, s2::AbstractSort) - if s1 == s2 - s1 - else - error("Cannot equate two nodes with different sorts") - end -end -export join - -function extract!(v::Var, f=EGraphs.astsize) - extract!(v.roe.graph, f, v.id) -end -export extract! - -function rootvarcrayon(v::RootVar) - lightnessrange = (50., 100.) - HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) -end - -function Base.show(io::IO, v::RootVar) - if get(io, :color, true) - crayon = rootvarcrayon(v) - print(io, crayon, "$(v.name)") - print(io, inv(crayon)) - else - print(io, "$(v.name)#$(v.idx)") - end -end - -""" fix_functions(e)::Union{Symbol, Expr} - -Traverses the AST of an expression, replacing the head of :call expressions to its name, a Symbol. -""" -function fix_functions(e) - @match e begin - s::Symbol => s - Expr(:call, f::Function, args...) => - Expr(:call, nameof(f), fix_functions.(args)...) - Expr(head, args...) => - Expr(head, fix_functions.(args)...) - _ => e - end -end - -""" getexpr(v::Var)::Union{Symbol, Expr} - -Extracts an expression (::Var) from its Roe. - -""" -function getexpr(v::Var) - e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) - fix_functions(e) -end -export getexpr - -function Base.show(io::IO, v::Var) - print(io, getexpr(v)) -end - -""" fresh!(roe::Roe, sort::AbstractSort, name::Symbol)::Var{sort} - -Creates a new ("fresh") variable in a Roe given a sort and a name. - -Example: -``` -fresh!(roe, Form(0), :Temp) -``` - -""" -function fresh!(roe::Roe, sort::AbstractSort, name::Symbol) - v = RootVar(name, length(roe.variables), sort) - push!(roe.variables, v) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) - Var{sort}(roe, EGraphs.add!(roe.graph, n, false)) -end -export fresh! - - -@nospecialize -""" inject_number!(roe::Roe, x::Number)::Var{Scalar()} - -Adds a number to the Roe as a EGraph constant. - -""" -function inject_number!(roe::Roe, x::Number) - x = Float64(x) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) - Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) -end -export inject_number! - -@nospecialize -""" addcall!(g::EGraph, head, args):: - -Adds a call to an EGraph. - -""" -function addcall!(g::EGraph, head, args) - ar = length(args) - n = Metatheory.v_new(ar) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) - Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) - for i in Metatheory.v_children_range(n) - @inbounds n[i] = args[i - VECEXPR_META_LENGTH] - end - EGraphs.add!(g, n, false) -end -export addcall! - -""" equate!(v1::Var{s1}, v2::Var{s2})::EGraph - -Asserts that two variables of the same e-graph are the same. This is done by returning the union of the variable ids with the e-graph. -""" -function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) - v1.roe === v2.roe || error("Cannot equate variables from different graphs") - union!(v1.roe.graph, v1.id, v2.id) -end -export equate! - -""" -Infix synonym for `equate!` -""" -≐(v1::Var, v2::Var) = equate!(v1, v2) -export ≐ - -# end diff --git a/DEC/src/roe/SSAExtract.jl b/DEC/src/roe/SSAExtract.jl deleted file mode 100644 index f7eeadd2..00000000 --- a/DEC/src/roe/SSAExtract.jl +++ /dev/null @@ -1,137 +0,0 @@ -module SSAExtract - -# -using ..DEC: AbstractSort, TypedApplication, TA, Roe, RootVar - -# other dependencies -using MLStyle -using Metatheory: VecExpr -using Metatheory.EGraphs -using StructEquality - -""" SSAVar - -A wrapper for the index of a SSAVar -""" -@struct_hash_equal struct SSAVar - idx::Int -end -export SSAVar - -function Base.show(io::IO, v::SSAVar) - print(io, "%", v.idx) -end - -""" SSAExpr - -A wrapper for a function (::Any) and its args (::Vector{Tuple{Sort, SSAVar}}). - -Example: the equation -``` - a = 1 + b -``` -may have an SSA dictionary -``` - %1 => a - %2 => +(%1, %3) - %3 => b -``` -and so `+` would have -``` -SSAExpr(+, [(Scalar(), SSAVar(1)), (Scalar(), SSAVar(2))]) -``` -""" -@struct_hash_equal struct SSAExpr - head::Any - args::Vector{Tuple{AbstractSort, SSAVar}} -end -export SSAExpr - -function Base.show(io::IO, e::SSAExpr) - print(io, e.head) - if length(e.args) > 0 - print(io, Expr(:tuple, (Expr(:(::), v, sort) for (sort, v) in e.args)...)) - end -end - -""" - -Struct defining Static Single-Assignment information for a given roe. - -Advantages of SSA form: - -1. We can preallocate each matrix -2. We can run a register-allocation algorithm to minimize the number of matrices that we have to preallocate -""" -struct SSA - assignment_lookup::Dict{Id, SSAVar} - statements::Vector{SSAExpr} - function SSA() - new(Dict{Id, SSAVar}(), SSAExpr[]) - end -end -export SSA - -function Base.show(io::IO, ssa::SSA) - println(io, "SSA: ") - for (i, expr) in enumerate(ssa.statements) - println(io, " ", SSAVar(i), " = ", expr) - end -end - -""" add_stmt!(ssa::SSA, id::Id, expr::SSAExpr)::SSAVar - -Given an SSA, add onto the assignment_lookup an SSAExpr. - -""" -function add_stmt!(ssa::SSA, id::Id, expr::SSAExpr) - push!(ssa.statements, expr) - v = SSAVar(length(ssa.statements)) - ssa.assignment_lookup[id] = v - v -end -export add_stmt! -# TODO is this idempotent? - -function hasid(ssa::SSA, id::Id) - haskey(ssa.assignment_lookup, id) -end -export hasid - -function getvar(ssa::SSA, id::Id) - ssa.assignment_lookup[id] -end -export getvar - -""" - extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar - -This function adds (recursively) the necessary lines to the SSA in order to -compute a value for `id`, and then returns the SSAVar that the value for `id` -will be assigned to. - -The closure parameters control the behavior of this function. - - term_select(id::Id)::VecExpr - -This closure selects, given an id in an EGraph, the term that we want to use in -order to compute a value for that id - -""" -function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select)::SSAVar - if hasid(ssa, id) - return getvar(ssa, id) - end - term = term_select(id) - args = map(EGraphs.v_children(term)) do arg - (g[arg].data, extract_ssa!(g, ssa, arg, term_select)) - end - add_stmt!(ssa, id, SSAExpr(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) -end -export extract_ssa! - -function extract_ssa!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) - extract_ssa!(g, ssa, id, term_select) -end - -end diff --git a/DEC/src/roe/module.jl b/DEC/src/roe/module.jl deleted file mode 100644 index 150ad779..00000000 --- a/DEC/src/roe/module.jl +++ /dev/null @@ -1,11 +0,0 @@ -# module RoeUtility - -using Reexport - -include("RoeUtility.jl") # vfield depends on SSAExtract -include("SSAExtract.jl") - -# @reexport using .RoeUtility -@reexport using .SSAExtract - -# end diff --git a/DEC/src/theories/ThDEC/ThDEC.jl b/DEC/src/theories/ThDEC/ThDEC.jl new file mode 100644 index 00000000..f56807db --- /dev/null +++ b/DEC/src/theories/ThDEC/ThDEC.jl @@ -0,0 +1,15 @@ +module ThDEC + +using ...DEC: TypedApplication, TA, Roe, RootVar +using ...DEC.SSAs + +using Metatheory: VecExpr +using Metatheory.EGraphs + +include("signature.jl") # verify operations type-check +include("roe_overloads.jl") # overload DEC operations to act on roe (egraphs) +include("semantics.jl") # represent operations as matrices + +# include("rewriting.jl") + +end diff --git a/DEC/src/theories/ThDEC/roe_overloads.jl b/DEC/src/theories/ThDEC/roe_overloads.jl new file mode 100644 index 00000000..b5cea146 --- /dev/null +++ b/DEC/src/theories/ThDEC/roe_overloads.jl @@ -0,0 +1,56 @@ +using ...DEC: Var, addcall!, inject_number! +using ...DEC: roe, graph, id # Var{S} accessors + +import Base: +, -, * + +# These operations create calls on a common egraph. We validate the signature by dispatching the operation on the types using methods we defined in Signature. + +## UNARY OPERATIONS + +unop_dec = [:∂ₜ, :d, :★, :-, :♯, :♭] +for unop in unop_dec + @eval begin + @nospecialize + function $unop(v::Var{s}) where s + s′ = $unop(s) + Var{s′}(roe(v), addcall!(graph(v), $unop, (id(v),))) + end + + export $unop + end +end + +# Δ is a composite of Hodge star and d +Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) +export Δ + +♭♯(v::Var{DualVF()}) = ♯(♭(v)) +export ♭♯ + +## BINARY OPERATIONS + +binop_dec = [:+, :-, :*, :∧] +for binop in binop_dec + @eval begin + @nospecialize + function $binop(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + roe(v1) === roe(v2) || throw(BinopError($binop)) + s = $binop(s1, s2) + Var{s}(v1.roe, addcall!(graph(v1), $binop, (id(v1), id(v2)))) + end + + @nospecialize + $binop(v::Var, x::Number) = $binop(v, inject_number!(roe(v), x)) + + @nospecialize + $binop(x::Number, v::Var) = $binop(inject_number!(roe(v)), x) + + export $binop + end +end + +struct BinopError <: Exception + binop::Symbol +end + +Base.showerror(io::IO, e::BinopError) = print(io, "Cannot use '$binop' on variables from different graphs.") diff --git a/DEC/src/theories/ThDEC/semantics.jl b/DEC/src/theories/ThDEC/semantics.jl new file mode 100644 index 00000000..f0d031f8 --- /dev/null +++ b/DEC/src/theories/ThDEC/semantics.jl @@ -0,0 +1,73 @@ +import Decapodes +using StructEquality + + +""" create_dynamic_model(sd, hodge)::Dict{TypedApplication, Any} + +Given a matrix and a hodge star (DiagonalHodge() or GeometricHodge()), this returns a lookup dictionary between operators (as TypedApplications) and their corresponding matrices. + +""" +function create_dynamic_model(sd, hodge)::Dict{TypedApplication, Any} + Dict{TypedApplication, Any}( + TA(*, Sort[Scalar(), Scalar()]) => 1, + TA(*, Sort[Scalar(), PrimalForm(0)]) => 1, + # Regular Hodge Stars + TA(★, Sort[PrimalForm(0)]) => Decapodes.dec_mat_hodge(0, sd, hodge), + TA(★, Sort[PrimalForm(1)]) => Decapodes.dec_mat_hodge(1, sd, hodge), + TA(★, Sort[PrimalForm(2)]) => Decapodes.dec_mat_hodge(2, sd, hodge), + + # Inverse Hodge Stars + TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), + # TODO verify ^ why is this 1??? + TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), + # Special since Geo is a solver + TA(★, Sort[DualForm(2)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), + + # Differentials + TA(d, Sort[PrimalForm(0)]) => Decapodes.dec_mat_differential(0, sd), + TA(d, Sort[PrimalForm(1)]) => Decapodes.dec_mat_differential(1, sd), + + # Dual Differentials + TA(d, Sort[DualForm(0)]) => Decapodes.dec_mat_dual_differential(0, sd), + TA(d, Sort[DualForm(1)]) => Decapodes.dec_mat_dual_differential(1, sd), + + # Wedge Products + TA(∧, Sort[PrimalForm(0), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{1,0}, sd), + TA(∧, Sort[PrimalForm(0), PrimalForm(2)]) => Decapodes.dec_pair_wedge_product(Tuple{0,2}, sd), + TA(∧, Sort[PrimalForm(2), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{2,0}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{1,1}, sd), + + # Primal-Dual Wedge Products + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dp(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dp(Tuple{1,0}, sd), + + # Dual-Dual Wedge Products + # TA(∧, Sort[DualForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{1,1}, sd), + TA(∧, Sort[DualForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dd(Tuple{1,0}, sd), + TA(∧, Sort[DualForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{0,1}, sd), + + # # Dual-Dual Interior Products + TA(ι, Sort[DualForm(1), DualForm(1)]) => Decapodes.interior_product_dd(Tuple{1,1}, sd), + TA(ι, Sort[DualForm(1), DualForm(2)]) => Decapodes.interior_product_dd(Tuple{1,2}, sd), + + # # Dual-Dual Lie Derivatives + # :ℒ₁ => ℒ_dd(Tuple{1,1}, sd) + + # # Dual Laplacians + # :Δᵈ₀ => Δᵈ(Val{0},sd) + # :Δᵈ₁ => Δᵈ(Val{1},sd) + + # # Musical Isomorphisms + TA(♯, Sort[PrimalForm(1)]) => Decapodes.dec_♯_p(sd), # Primal(1) -> PVField + TA(♯, Sort[DualForm(1)]) => Decapodes.dec_♯_d(sd), # Dual(1) -> DVField + + TA(♭, Sort[DualVF()]) => Decapodes.dec_♭(sd), # DVField -> Primal(1) + + # # Averaging Operator + # :avg₀₁ => Decapodes.dec_avg₀₁(sd) + ) +end +# TODO can we use OrderedDict to retain our nice presentation? diff --git a/DEC/src/theories/ThDEC/signature.jl b/DEC/src/theories/ThDEC/signature.jl new file mode 100644 index 00000000..ef68938a --- /dev/null +++ b/DEC/src/theories/ThDEC/signature.jl @@ -0,0 +1,155 @@ +using ...DEC: AbstractSort, SortError + +using MLStyle + +import Base: +, -, * + +# Define the sorts in your theory. +# For the DEC, we work with Scalars and Forms, graded objects which can also be primal or dual. +@data Sort <: AbstractSort begin + Scalar() + Form(dim::Int, isdual::Bool) + VF(isdual::Bool) +end +export Scalar, Form + +# accessors +dim(ω::Form) = ω.dim +isdual(ω::Form) = ω.isdual + +# convenience functions +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm + +PrimalVF() = VF(false) +export PrimalVF + +DualVF() = VF(true) +export DualVF + +# show methods +show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" + +function Base.show(io::IO, ω::Form) + print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +end + +## OPERATIONS + +@nospecialize +function +(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end + end +end + +# Type-checking inverse of addition follows addition +-(s1::Sort, s2::Sort) = +(s1, s2) + +# TODO error for Forms + +# Negation is valid +-(s::Sort) = s + +@nospecialize +function *(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + end +end + +@nospecialize +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if i1 + i2 <= 2 + Form(i1 + i2, isdual1) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + _ => throw(SortError("Can only take a wedge product of two forms")) + end +end + +@nospecialize +∂ₜ(s::Sort) = s + +@nospecialize +function d(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) + else + throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + end + end +end + +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end + +function ι(s1::Sort, s2::Sort) + @match (s1, s2) begin + (VF(true), Form(i, true)) => PrimalForm() # wrong + (VF(true), Form(i, false)) => DualForm() + _ => throw(SortError("Can only define the discrete interior product on: + PrimalVF, DualForm(i) + DualVF(), PrimalForm(i) + .")) + end +end + +# in practice, a scalar may be treated as a constant 0-form. +function ♯(s::Sort) + @match s begin + Scalar() => PrimalVF() + Form(1, isdual) => VF(isdual) + _ => throw(SortError("Can only take ♯ to 1-forms")) + end +end +# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. + +function ♭(s::Sort) + @match s begin + VF(true) => PrimalForm(1) + _ => throw(SortError("Can only apply ♭ to dual vector fields")) + end +end + +# OTHER + +function ♭♯(s::Sort) + @match s begin + Form(i, isdual) => Form(i, !isdual) + _ => throw(SortError("♭♯ is only defined on forms.")) + end +end + +# Δ = ★d⋆d, but we check signature here to throw a more helpful error +function Δ(s::Sort) + @match s begin + Form(0, isdual) => Form(0, isdual) + _ => throw(SortError("Δ is not defined for $s")) + end +end diff --git a/DEC/src/models/module.jl b/DEC/src/theories/module.jl similarity index 85% rename from DEC/src/models/module.jl rename to DEC/src/theories/module.jl index 117b5f05..31e01a97 100644 --- a/DEC/src/models/module.jl +++ b/DEC/src/theories/module.jl @@ -1,4 +1,4 @@ -module Models +module Theories using Reexport diff --git a/DEC/src/util/Plotting.jl b/DEC/src/util/Plotting.jl new file mode 100644 index 00000000..82ea3695 --- /dev/null +++ b/DEC/src/util/Plotting.jl @@ -0,0 +1,20 @@ +module Plotting + +using CairoMakie + +function save_dynamics(soln, timespan, save_file_name) + time = Observable(0.0) + h = @lift(soln($time)) + f = Figure() + ax = CairoMakie.Axis(f[1,1], title = @lift("Heat at time $($time)")) + gmsh = mesh!(ax, rect, color=h, colormap=:jet, + colorrange=extrema(soln(timespan))) + Colorbar(f[1,2], gmsh) + timestamps = range(0, timespan, step=5.0) + record(f, save_file_name, timestamps; framerate = 15) do t + time[] = t + end +end +export save_dynamics + +end diff --git a/DEC/src/util/module.jl b/DEC/src/util/module.jl index 8c8d51c5..bde14dbe 100644 --- a/DEC/src/util/module.jl +++ b/DEC/src/util/module.jl @@ -3,7 +3,9 @@ module Util using Reexport include("HashColor.jl") +include("Plotting.jl") @reexport using .HashColor +@reexport using .Plotting end diff --git a/DEC/src/vfield.jl b/DEC/src/vfield.jl new file mode 100644 index 00000000..eb10877a --- /dev/null +++ b/DEC/src/vfield.jl @@ -0,0 +1,156 @@ +using .SSAs +using MLStyle + +""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction + +Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. + +Example: given a diffusivity constant a, the heat equation can be written as: +``` + ∂ₜ u = a * Δ(u) +``` +would return (u, a). + +A limitation of this function can be demonstrated here: given the model + ``` + ∂ₜ a = a + b + ∂ₜ b = a + b + ``` + we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, the extractor would bypass it completely. +""" +function vfield(model, op_lookup::Dict{TA, Any}) + + # ::Roe + # inttialize the Roe (e-graph) + roe = Roe(DEC.ThDEC.Sort) + + # ::Tuple{Vector{Var}, Vector{Var}} + # Pass the roe into the model function, which contributes the variables (via `fresh!`) and equations (via `equate!`). Retrieve the state and parameter variables in the model. + (state_vars, param_vars) = model(roe) + + # A model is inadmissible if there is no state variables. + length(state_vars) >= 1 || throw(VFieldError()) + + # ::Vector{RootVar} + # iterate `extract!` through the state and parameter variables. + state_rootvars = extract_rootvars!(state_vars); + param_rootvars = extract_rootvars!(param_vars); + + # TODO This is currently fixed + u = :u; p = :p; du = :du; + + # ::Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}} + rv_lookup = make_rv_lookup(state_rootvars, param_rootvars, u, p); + + # ::Function + # Return a cost function whose allowed roots are the set union of the model's rootvars. + cost = derivative_cost(Set([state_rootvars; param_rootvars])) + + # ::Extractor + # Pass the Roe's E-Graph into a Metatheory Extractor. + extractor = EGraphs.Extractor(roe.graph, cost, Float64) + + # ::SSA + ssa = SSA() + + # ::Function + term_select(id) = EGraphs.find_best_node(extractor, id); + + # ::Vector{Var} + d_vars = extract_derivative_vars!(roe, ssa, state_vars, term_select); + + # ::Tuple{Vector{Expr}, Vector{Expr}} + # convert the SSA statements and derivative variables into Julia Exprs + (ssalines, derivative_stmts) = build_result_body(ssa, d_vars, du, op_lookup, rv_lookup) + + # yield the function that will be passed to a solver + eval(quote + f(du, u, p, _) = begin + $(ssalines...) + $(derivative_stmts...) + end + end) +end +export vfield + +# Build the body of the function by returning the lines of the ssas and the derivative statments. +function build_result_body(ssa, derivative_vars, du, op_lookup, rv_lookup) + + _toexpr(term) = toexpr(term, op_lookup, rv_lookup) + + ssalines = map(enumerate(ssa.statements)) do (i, stmt) + :($(_toexpr(SSAs.Var(i))) = $(_toexpr(stmt))) + end + + derivative_stmts = map(enumerate(derivative_vars)) do (i, stmt) + :($(du) .= $(_toexpr(stmt))) + end + + return (ssalines, derivative_stmts) +end + +# For normalization purposes, I factored `toexpr` out of `vfield`. However, this means the two lookup variables were no longer in scope for `toexpr`. +# +# It is possible to thread the lookups into the arguments of the `toexpr`s, +# +# ``` +# :($(toexpr(SSAs.Var(i), lookup1, lookup2)) = $(toexpr(stmt, lookup1, lookup2))) +# ``` +# but you would also need to pass the lookup arguments for the `::Var` dispatch for `toexpr`, where the variables would not be used. +# +# Then, you could simplify this but uniting the two functions and using a conditional or @match expression. Since we are traversing a Term, we could just call the function recusively, or define one @λ. +# +# but I felt this was visually too noisy in `build_result_body`. +function toexpr(expr::Union{Term, DEC.SSAs.Var}, op_lookup, rv_lookup) + λtoexpr = @λ begin + var::DEC.SSAs.Var => Symbol("tmp%$(var.idx)") + term::Term && if term.head isa Number end => term.head + # if the head of a term is a RootVar, we'll need to ensure that we can retrieve the value from a named tuple. + # if the boolean value is false, the rootvar is a state_var, otherwise it is a parameter and assumed to be + # accessed by a named tuple. + term::Term && if term.head isa RootVar end => @match rv_lookup[term.head] begin + (rv, false) => rv + (rv, true) => Expr(:ref, rv, term.head.name) + end + # This default case is Decapodes-specific. Decapode operators return a tuple of functions, so we choose the first. + term => begin + op = get(op_lookup, TA(term.head, first.(term.args))) + if op isa Tuple; op = op[1] end + Expr(:call, *, op, λtoexpr.(last.(term.args))...) + end + end + λtoexpr(expr) +end + +# map over the state_vars to apply `extract!` +function extract_derivative_vars!(roe::Roe, ssa::SSA, state_vars, term_select::Function) + map(state_vars) do v + extract!(roe.graph, ssa, (∂ₜ(v)).id, term_select) + end +end + +# given root variables, and produce a dictionary +function make_rv_lookup(state_rvs, param_rvs, state, param) + Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}}( + [ + [rv => (:($(state)), false) for rv in state_rvs]; + [rv => (:($(param)), true) for rv in param_rvs] + ] + ) +end + +# map over vars +function extract_rootvars!(vars) + map(vars) do x + rv = extract!(x) + rv isa RootVar ? rv : throw(RootVarError("All variables must be RootVars")) + end +end + +struct VFieldError <: Exception end + +Base.showerror(io::IO, e::VFieldError) = println(io, "Need at least one state variable in order to create a vector field") + +struct RootVarError <: Exception; msg::String end + +Base.showerror(io::IO, e::RootVarError) = println(io, e.msg) diff --git a/DEC/tests/SSAExtract.jl b/DEC/test/SSAExtract.jl similarity index 96% rename from DEC/tests/SSAExtract.jl rename to DEC/test/SSAExtract.jl index 0bd1d87c..12b437b1 100644 --- a/DEC/tests/SSAExtract.jl +++ b/DEC/test/SSAExtract.jl @@ -63,7 +63,7 @@ function term_select(g::EGraph, id::Id) g[id].nodes[1] end -extract_ssa!(roe.graph, ssa, (a + b).id, term_select) +extract!(roe.graph, ssa, (a + b).id, term_select) ssa diff --git a/DEC/test/ThDEC/ThDEC.jl b/DEC/test/ThDEC/ThDEC.jl new file mode 100644 index 00000000..aa69e10f --- /dev/null +++ b/DEC/test/ThDEC/ThDEC.jl @@ -0,0 +1,19 @@ +# AlgebraicJulia dependencies +using DEC +import DEC.ThDEC: ∧, Δ # conflicts with CombinatorialSpaces + +# preliminary dependencies for testing +using Test +using Metatheory.EGraphs + +# test the signature +include("signature.jl") + +# test the roe_overloads +include("roe_overloads.jl") + +# test the semantics +include("semantics.jl") + +# test modeling +include("model.jl") diff --git a/DEC/test/ThDEC/model.jl b/DEC/test/ThDEC/model.jl new file mode 100644 index 00000000..1cd2b75a --- /dev/null +++ b/DEC/test/ThDEC/model.jl @@ -0,0 +1,73 @@ +# load other dependencies +using ComponentArrays +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} + +# plotting +using CairoMakie + +## 1-D HEAT EQUATION + +# initialize the model +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# initialize primal and dual meshes. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# precompute matrices from operators in the DEC theory. +op_lookup = ThDEC.create_dynamic_model(d_rect, DiagonalHodge()) + +# produce a vector field. +vf = vfield(heat_equation, op_lookup) + +U = first.(d_rect[:point]); +constants_and_parameters = () +t0 = 50.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +save_dynamics(soln, t0, "heat-1D.gif") + +## 1-D HEAT EQUATION WITH DIFFUSIVITY + +function new_heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + ℓ = fresh!(roe, Scalar(), :ℓ) + + ∂ₜ(u) ≐ ℓ * k * Δ(u) + + ([u], [k, ℓ]) +end + +# we can reuse the mesh and operator lookup +vf = vfield(new_heat_equation, op_lookup) + +# we can reuse the initial condition U. However we need to specify the diffusivity constant +constants_and_parameters = ComponentArray(k=0.25,ℓ=2,); + +# this is a shim +DEC.k = :k; DEC.ℓ = :ℓ; + +# Let's set the time +t0 = 500 + +@info("Precompiling solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); + +soln = solve(prob, Tsit5()); + +save_dynamics(soln, t0, "heat-1D-scalar.gif") diff --git a/DEC/tests/Roe.jl b/DEC/test/ThDEC/roe.jl similarity index 100% rename from DEC/tests/Roe.jl rename to DEC/test/ThDEC/roe.jl diff --git a/DEC/test/ThDEC/semantics.jl b/DEC/test/ThDEC/semantics.jl new file mode 100644 index 00000000..e69de29b diff --git a/DEC/test/ThDEC/signature.jl b/DEC/test/ThDEC/signature.jl new file mode 100644 index 00000000..8cb91835 --- /dev/null +++ b/DEC/test/ThDEC/signature.jl @@ -0,0 +1,40 @@ +# ## SIGNATURE TESTS + +# Addition +@test Scalar() + Scalar() == Scalar() +@test Scalar() + PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) + Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(1) + PrimalForm(2) + +# Negation and Subtraction +@test -Scalar() == Scalar() +@test Scalar() - Scalar() == Scalar() + +# Scalar Multiplication +@test Scalar() * Scalar() == Scalar() +@test Scalar() * PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) * Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(2) * PrimalForm(1) + +# Exterior Product +@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) +@test PrimalForm(1) ∧ Scalar() == PrimalForm(1) + +@test_throws SortError PrimalForm(1) ∧ DualForm(1) +@test_throws SortError PrimalForm(2) ∧ PrimalForm(1) + +# Time derivative +@test ∂ₜ(Scalar()) == Scalar() +@test ∂ₜ(PrimalForm(1)) == PrimalForm(1) +@test ∂ₜ(DualForm(0)) == DualForm(0) + +# Derivative +@test_throws SortError d(Scalar()) +@test d(PrimalForm(1)) == PrimalForm(2) +@test d(DualForm(0)) == DualForm(1) + +# Hodge star +@test_throws SortError ★(Scalar()) +@test ★(PrimalForm(1)) == DualForm(1) +@test ★(DualForm(0)) == PrimalForm(2) + diff --git a/DEC/tests/runtests.jl b/DEC/test/runtests.jl similarity index 88% rename from DEC/tests/runtests.jl rename to DEC/test/runtests.jl index 7b00e4da..d5ad61dc 100644 --- a/DEC/tests/runtests.jl +++ b/DEC/test/runtests.jl @@ -13,5 +13,5 @@ end end @testset "ThDEC" begin - include("DEC.jl") + include("ThDEC/tests.jl") end diff --git a/DEC/tests/DEC.jl b/DEC/tests/DEC.jl deleted file mode 100644 index aeea45bf..00000000 --- a/DEC/tests/DEC.jl +++ /dev/null @@ -1,92 +0,0 @@ -module TestDEC - -# AlgebraicJulia dependencies -using DEC -import DEC.ThDEC: Δ # conflicts with CombinatorialSpaces - -# other dependencies -using Test -using Metatheory.EGraphs -using ComponentArrays -using CombinatorialSpaces -using GeometryBasics -using OrdinaryDiffEq -Point2D = Point2{Float64} -Point3D = Point3{Float64} - -# plotting -using CairoMakie - -## 1-D HEAT EQUATION - -function heat_equation(pode) - u = fresh!(pode, PrimalForm(0), :u) - - ∂ₜ(u) ≐ Δ(u) - - ([u], []) -end - -# initialize primal and dual meshes. -rect = triangulated_grid(100, 100, 1, 1, Point3D); -d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); -subdivide_duals!(d_rect, Circumcenter()); - -# precompule matrices from operators in the DEC theory. -operator_lookup = ThDEC.precompute_matrices(d_rect, DiagonalHodge()) - -# produce a vector field -vf = vfield(heat_equation, operator_lookup) - -# -U = first.(d_rect[:point]); - -# TODO component arrays -constants_and_parameters = () - -tₑ = 500.0 - -@info("Precompiling Solver") -prob = ODEProblem(vf, U, (0, tₑ), constants_and_parameters); -soln = solve(prob, Tsit5()); - -function save_dynamics(save_file_name) - time = Observable(0.0) - h = @lift(soln($time)) - f = Figure() - ax = CairoMakie.Axis(f[1,1], title = @lift("Heat at time $($time)")) - gmsh = mesh!(ax, rect, color=h, colormap=:jet, - colorrange=extrema(soln(tₑ))) - Colorbar(f[1,2], gmsh) - timestamps = range(0, tₑ, step=5.0) - record(f, save_file_name, timestamps; framerate = 15) do t - time[] = t - end -end - -## 1-D HEAT EQUATION WITH DIFFUSIVITY - -function new_heat_equation(roe) - u = fresh!(roe, PrimalForm(0), :u) - k = fresh!(roe, Scalar(), :k) - - ∂ₜ(u) ≐ k * Δ(u) - - ([u], [k]) -end - -# we can reuse the mesh and operator lookup -_vf = vfield(new_heat_equation, operator_lookup) - -# we can reuse the initial condition U - -# -constants_and_parameters = ComponentArray(k=0.5,); - -t0 = 50 - -@info("Precompiling solver") -prob = ODEProblem(_vf, U, (0, t0), constants_and_parameters); -soln = solve(prob, Tsit5()); - - diff --git a/DEC/tests/Signature.jl b/DEC/tests/Signature.jl deleted file mode 100644 index 015b4860..00000000 --- a/DEC/tests/Signature.jl +++ /dev/null @@ -1,25 +0,0 @@ -module TestSignature - -using DEC - -using Test - -# ## SIGNATURE TESTS - -# Addition -@test Scalar() + Scalar() == Scalar() -@test Scalar() + PrimalForm(1) == PrimalForm(1) -@test PrimalForm(2) + Scalar() == PrimalForm(2) -@test_throws SortError PrimalForm(1) + PrimalForm(2) - -# Scalar Multiplication -@test Scalar() * Scalar() == Scalar() -@test Scalar() * PrimalForm(1) == PrimalForm(1) -@test PrimalForm(2) * Scalar() == PrimalForm(2) -@test_throws SortError PrimalForm(2) * PrimalForm(1) - -# Exterior Product -@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) -@test_throws SortError PrimalForm(1) ∧ Scalar() - -end From 8358b745b969c9e6bd4b9c7ff903ce03eda647c2 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 1 Aug 2024 14:42:49 -0400 Subject: [PATCH 4/5] begun rewriting work --- DEC/src/theories/ThDEC/rewriting.jl | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 DEC/src/theories/ThDEC/rewriting.jl diff --git a/DEC/src/theories/ThDEC/rewriting.jl b/DEC/src/theories/ThDEC/rewriting.jl new file mode 100644 index 00000000..cb9bec74 --- /dev/null +++ b/DEC/src/theories/ThDEC/rewriting.jl @@ -0,0 +1,36 @@ +using Metatheory +using Metatheory.Library + +ThMultiplicativeMonoid = @commutative_monoid (*) 1 +ThAdditiveGroup = @commutative_group (+) 0 (-) +Distributivity = @distrib (*) (+) +ThRing = ThMultiplicativeMonoid ∪ ThAdditiveGroup ∪ Distributivity + +Derivative = @theory (f, g)::Function, a::Scalar begin + f * d(g) + d(f) * g --> d(f * g) + d(f) + d(g) --> d(f + g) + d(a * f) --> a * d(f) +end + + +# e = :(f * d(g) + d(f) * g) +# g = EGraph(e) +# saturate!(g, product) +# extract!(g, astsize) + +zero = @theory f begin + f * 0 --> 0 + f + 0 --> f + 0 + f --> f +end + +square_zero = @theory ω begin + d(d(ω)) --> 0 +end + +linearity = @theory f g a begin + Δ(f + g) == Δ(f) + Δ(g) + Δ(a * f) == a * Δ(f) +end +export linearity + From 175a9e6dae6b67a88d2cf9312a39d9339cc403df Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 2 Aug 2024 17:12:28 -0400 Subject: [PATCH 5/5] prototype of typed rewriting for theories and continued with little progress on docs --- DEC/Project.toml | 1 + DEC/docs/literate/egraphs.jl | 22 +++ DEC/scratch/tc.jl | 240 ++++++++++++++++++++++++++++ DEC/src/theories/ThDEC/rewriting.jl | 20 +++ DEC/src/theories/ThDEC/signature.jl | 24 +++ DEC/test/ThDEC/model.jl | 3 + 6 files changed, 310 insertions(+) create mode 100644 DEC/docs/literate/egraphs.jl create mode 100644 DEC/scratch/tc.jl diff --git a/DEC/Project.toml b/DEC/Project.toml index 8c492bb8..263a2a02 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -17,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [compat] CairoMakie = "0.12.5" diff --git a/DEC/docs/literate/egraphs.jl b/DEC/docs/literate/egraphs.jl new file mode 100644 index 00000000..b3b11cf3 --- /dev/null +++ b/DEC/docs/literate/egraphs.jl @@ -0,0 +1,22 @@ +# This lesson covers the internals of Metatheory.jl-style E-Graphs. Let's reuse the heat_equation model on a new roe. +roe = Roe(DEC.ThDEC.Sort); +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# We apply the model to the roe and collect its state variables. +(state, _) = heat_equation(roe) + +# Recall from the Introduction that an E-Graph is a bipartite graph of ENodes and EClasses. Let's look at the EClasses: +classes = roe.graph.classes +# The keys are Metatheory Id types which store an Id. The values are EClasses, which are implementations of equivalence classes. Nodes which share the same EClass are considered equivalent. + + + +# The constants in Roe are a dictionary of hashes of functions and constants. Let's extract just the values again: +vals = collect(values(e.graph.constants)) +# The `u` is ::RootVar{Form} and ∂ₜ, ★, d are all functions defined in ThDEC/signature.jl file. diff --git a/DEC/scratch/tc.jl b/DEC/scratch/tc.jl new file mode 100644 index 00000000..ee490cec --- /dev/null +++ b/DEC/scratch/tc.jl @@ -0,0 +1,240 @@ +using Metatheory +using Metatheory: OptBuffer +using Metatheory.Library +using Metatheory.Rewriters +using MLStyle +using Test +using Metatheory.Plotting + +b = OptBuffer{UInt128}(10) + +@testset "Predicate Assertions" begin + r = @rule ~a::iseven --> true + Base.iseven(g, ec::EClass) = + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Number && iseven(c) + end + false + end + # + g = EGraph(:(f(2, 1))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + g = EGraph(:2) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + # + g = EGraph(:3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + new_id = addexpr!(g, :f) + union!(g, g.root, new_id) + # + new_id = addexpr!(g, 4) + union!(g, g.root, new_id) + # + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + +### + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +@testset "Check form" begin + + function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c))" + return c isa Form + end + false + end + end + + r = @rule ~a::isform --> true + + t = @theory a begin + ~a::isform + ~b::isform --> 0 + end + + ## initialize and sanity-check + a1=Form(1); a2=Form(2) + a3=Form(1); a4=Form(1) + @assert a1 isa Form; @assert a3 isa Form + @assert a3 isa Form; @assert a4 isa Form + + g = EGraph(:($a1 + $a2)) + saturate!(g, t) + extract!(g, astsize) + + g = EGraph(:a1) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + g = EGraph(:a3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + + g = EGraph(:(a + b )) + saturate!(g, t) + extract!(g, astsize) + + +end + + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +t = @theory a b begin + a::Form(1) ∧ b::Form(2) --> 0 +end +# breaks! makevar @ MT/src/Syntax.jl:57 (<- makepattern, ibid:151) +# expects Symbol, not Expr + +function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c)), $(c isa Form)" + return c isa Form + end + end +end + +r = @rule ~a::isform --> true + +g = EGraph(:(f(2, 1))) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +g = EGraph(:2) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + +g = EGraph(:3) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +new_id = addexpr!(g, :f) +union!(g, g.root, new_id) + +new_id = addexpr!(g, 4) +union!(g, g.root, new_id) + +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + + + + +a=Form(1) +b=Form(2) +c=Form(1) +d=Form(1) +@assert a isa Form +@assert b isa Form +@assert c isa Form +@assert d isa Form + +ex = :(a + b + c + d) +g = EGraph(ex) +saturate!(g, _T) +extract!(g, astsize) + + +@data Sort1 <: AbstractSort begin + AnyForm +end + +_t = @theory a b begin + a::AnyForm ∧ b::AnyForm --> 0 +end +# PatVar error + +__t = @theory a b begin + a::var"AnyForm's constructor" + 0 --> a +end + +d0 = AnyForm +d1 = AnyForm +d2 = AnyForm +ex = :(d0 + 0 + (d1 + 0) + d2) + +g = EGraph(ex) +saturate!(g, __t) +extract!(g, astsize) + + +rwth = Fixpoint(Prewalk(Chain(__t))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, _t) +extract!(g, astsize) +# returns (a ∧ b), as + +rwth = Fixpoint(Prewalk(Chain(_t))) +rwth(ex) + +a = Form(0) +ex = :(a ∧ b) + + +Derivative = @theory f g a begin + f::Function * d(g::Function) + d(f::Function) * g::Function --> d(f * g) + d(f::Function) + d(g::Function) --> d(f + g) + d(a::Number * f::Function) --> a * d(f) +end + +_Derivative = @theory f g a begin + f * d(g) + d(f) * g --> d(f * g) + d(f) + d(g) --> d(f + g) + d(a * f) --> a * d(f) +end + +ex = :(f * d(g) + d(f) * g) + +rwth = Fixpoint(Prewalk(Chain(_Derivative))) +rwth(ex) + +foo(x) = x + 1 +goo(x) = x + 3 + +rwth = Fixpoint(Prewalk(Chain(Derivative))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, Derivative); +extract!(g, astsize) + + +rwth(ex) + + +types = (U=Form(0), k=Scalar(),); + +tc(x) = @match x begin + s::Symbol => types[s] + ::Expr => true +end + +cond = x -> begin + @info "$x, $(type(x))" + tc(x) +end + +orw = Fixpoint(Prewalk(If(cond, Chain(rewrite_theory)))) + +orw(expr) diff --git a/DEC/src/theories/ThDEC/rewriting.jl b/DEC/src/theories/ThDEC/rewriting.jl index cb9bec74..81182e50 100644 --- a/DEC/src/theories/ThDEC/rewriting.jl +++ b/DEC/src/theories/ThDEC/rewriting.jl @@ -1,5 +1,25 @@ using Metatheory using Metatheory.Library +using Metatheory.Rewriters +using MLStyle + +buf = OptBuffer{UInt128}(10) + +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + +t = @theory a b begin + ~a::isForm + ~b::isForm --> 0 +end + ThMultiplicativeMonoid = @commutative_monoid (*) 1 ThAdditiveGroup = @commutative_group (+) 0 (-) diff --git a/DEC/src/theories/ThDEC/signature.jl b/DEC/src/theories/ThDEC/signature.jl index ef68938a..ed201983 100644 --- a/DEC/src/theories/ThDEC/signature.jl +++ b/DEC/src/theories/ThDEC/signature.jl @@ -37,6 +37,30 @@ function Base.show(io::IO, ω::Form) print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") end +## Predicates +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + + +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + ## OPERATIONS @nospecialize diff --git a/DEC/test/ThDEC/model.jl b/DEC/test/ThDEC/model.jl index 1cd2b75a..08e998a1 100644 --- a/DEC/test/ThDEC/model.jl +++ b/DEC/test/ThDEC/model.jl @@ -1,3 +1,6 @@ +using DEC +import DEC: Δ, ∧ + # load other dependencies using ComponentArrays using CombinatorialSpaces