diff --git a/Project.toml b/Project.toml index cbd28bf33..e5cc66efa 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" +ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" @@ -44,6 +45,7 @@ DomainSets = "0.5" Groebner = "0.1, 0.2" IfElse = "0.1" LaTeXStrings = "1.3" +LambertW = "0.4.5" Latexify = "0.11, 0.12, 0.13, 0.14, 0.15" MacroTools = "0.5" NaNMath = "0.3, 1" @@ -58,7 +60,6 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "1.1" SymbolicUtils = "1.0.1" TreeViews = "0.3" -LambertW = "0.4.5" julia = "1.6" [extras] diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 5ce227b31..851c0551d 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -43,6 +43,10 @@ using MacroTools import MacroTools: splitdef, combinedef, postwalk, striplines include("wrapper-types.jl") +import ExprTools +export specialize_methods +include("specialize_methods.jl") + include("num.jl") include("complex.jl") diff --git a/src/init.jl b/src/init.jl index 73dab9a0b..ead8373ff 100644 --- a/src/init.jl +++ b/src/init.jl @@ -23,4 +23,6 @@ function __init__() end end # SymPy + + specialize_methods((LinearAlgebra,)) end diff --git a/src/specialize_methods.jl b/src/specialize_methods.jl new file mode 100644 index 000000000..29d520eb0 --- /dev/null +++ b/src/specialize_methods.jl @@ -0,0 +1,45 @@ +""" + specialize_methods(func, abstract_arg_types, inner_func, mods=nothing) + +For any method that implements `func` with signature +fitting `abstract_arg_types`, define methods for corresponding +symbolic types that pass all arguments to `inner_func`. +`mods` is an optional list of modules to look for methods in. +""" +function specialize_methods(func, abstract_arg_types, inner_func, mods=nothing) + ms = isnothing(mods) ? methods(func, abstract_arg_types) : methods(func, abstract_arg_types, mods) + for m in ms + mod = m.module + if mod != @__MODULE__ # do not overwrite method definitions from within this module itself, else: precompilation warnings + sig = ExprTools.signature(m; extra_hygiene=true) + fname = sig[:name] + args = sig[:args] + kwargs = get(sig, :kwargs, Symbol[]) + whereparams = get(sig, :whereparams, Symbol[]) + args_names = expr_argname.(args) + kwargs_names = expr_kwargname.(kwargs) + body = :($(inner_func)($(args_names...); $(kwargs_names...))) + Base.eval( + @__MODULE__, + wrap_func_expr( + mod, fname, args, kwargs, args_names, kwargs_names, whereparams, body; + abstract_arg_types + ) + ) + end#of `mod != @__MODULE__` + end#of `for m in ms` +end + +""" + specialize_methods(mods=nothing) + +Define specialized methods accepting symbolic types for the following functions and +signatures found in modules `mods` via `methods(...)`: + +* `Base.:(*)` for arguments of `(AbstractMatrix, AbstractVector)` to redirect to `_matvec`. +* `Base.:(*)` for arguments of `(AbstractMatrix, AbstractMetrax)` to redirect to `_matmul`. +""" +function specialize_methods(mods=nothing) + specialize_methods(Base.:(*), (AbstractMatrix, AbstractVector), _matvec, mods) + specialize_methods(Base.:(*), (AbstractMatrix, AbstractMatrix), _matmul, mods) +end \ No newline at end of file diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index c36fc9eca..55767773b 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -57,58 +57,82 @@ function wraps_type end has_symwrapper(::Type) = false is_wrapper_type(::Type) = false +# helper function to extract keyword argument names from expressions +function expr_kwargname(kwarg) + if kwarg isa Expr && kwarg.head == :kw + kwarg.args[1] + elseif kwarg isa Expr && kwarg.head == :(...) + kwarg.args[1] + else + kwarg + end +end + +# helper function to extract argument names from expressions +function expr_argname(arg) + if arg isa Expr && (arg.head == :(::) || arg.head == :(...)) + arg.args[1] + elseif arg isa Expr + error("$arg not supported as an argument") + else + arg + end +end + function wrap_func_expr(mod, expr) @assert expr.head == :function || (expr.head == :(=) && expr.args[1] isa Expr && expr.args[1].head == :call) def = splitdef(expr) - - sig = expr.args[1] body = def[:body] - fname = def[:name] args = get(def, :args, []) kwargs = get(def, :kwargs, []) + args_names = expr_argname.(args) + kwargs_names = expr_kwargname.(kwargs) + + wrap_func_expr(mod, fname, args, kwargs, args_names, kwargs_names, Symbol[], body) +end - impl_name = Symbol(fname,"_", hash(string(args)*string(kwargs))) - - function kwargname(kwarg) - if kwarg isa Expr && kwarg.head == :kw - kwarg.args[1] - elseif kwarg isa Expr && kwarg.head == :(...) - kwarg.args[1] - else - kwarg - end - end - - function argname(arg) - if arg isa Expr && (arg.head == :(::) || arg.head == :(...)) - arg.args[1] - elseif arg isa Expr - error("$arg not supported as an argument") - else - arg - end - end - - names = vcat(argname.(args), kwargname.(kwargs)) - - function type_options(arg) +function wrap_func_expr( + mod, fname, args, kwargs, args_names, kwargs_names, whereparams, body; + abstract_arg_types=nothing +) + names = vcat(args_names, kwargs_names) + + function type_options(wparams, arg, arg_ind) + pmod = parentmodule(mod) + atype = isnothing(abstract_arg_types) ? Any : abstract_arg_types[arg_ind] if arg isa Expr && arg.head == :(::) - T = Base.eval(mod, arg.args[2]) + T = Base.eval(mod, quote + let $(Symbol(pmod)) = $(pmod); # make name of parent module available in eval scope + #= + NOTE + `typeintersect` is important here for consecutive calls to `specialize_methods` + with conceptually different super types. + E.g.: Consider we first specialize `*(::AbstractMatrix, ::AbstractVector)` to + redirect to `_matvec`, and then `*(::AbstractMatrix, ::AbstractMatrix)` to + redirect to `_matmul`. If we encounter some existing method for `*` which accepts + an `AbstractMatrix` and an `VecOrMat` (type union), then we accidentally redirect + a matrix-vector-product to `_matmul` without `typeintersect`. + =# + typeintersect($(atype), $(arg.args[2]) where {$(wparams...)}) + end + end) has_symwrapper(T) ? (T, :(SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) : - (T,:(SymbolicUtils.Symbolic{<:$T})) + (T, :(SymbolicUtils.Symbolic{<:$T})) elseif arg isa Expr && arg.head == :(...) - Ts = type_options(arg.args[1]) + Ts = type_options(wparams, arg.args[1], arg_ind) map(x->Vararg{x},Ts) else (Any,) end end - types = map(type_options, args) + types = [type_options(whereparams, arg, arg_ind) for (arg_ind, arg)=enumerate(args)] + + impl_name = Symbol(fname,"_", hash(string(args)*string(kwargs)*string(types))) impl = :(function $impl_name($(names...)) $body @@ -139,9 +163,9 @@ function wrap_func_expr(mod, expr) quote $impl $(methods...) - end |> esc + end end macro wrapped(expr) - wrap_func_expr(__module__, expr) + esc(wrap_func_expr(__module__, expr)) end