Skip to content

Commit

Permalink
feat: add support for typed function symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 13, 2024
1 parent 6366280 commit e1ba803
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 19 deletions.
100 changes: 85 additions & 15 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ function _parse_vars(macroname, type, x, transform=identity)
# x = 1, [connect = flow; unit = u"m^3/s"]
if Meta.isexpr(v, :(=))
v, val = v.args
# defaults with metadata for function variables
if Meta.isexpr(val, :block)
Base.remove_linenums!(val)
val = only(val.args)
end
if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2])
options = val.args[2].args
val = val.args[1]
Expand All @@ -124,7 +129,7 @@ function _parse_vars(macroname, type, x, transform=identity)
isruntime, v = unwrap_runtime_var(v)
iscall = Meta.isexpr(v, :call)
isarray = Meta.isexpr(v, :ref)
if iscall && Meta.isexpr(v.args[1], :ref)
if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(lastunwrap_runtime_var, @view v.args[2:end]))
@warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead.
The former creates an array of functions, while the latter creates an array valued function.
The deprecated syntax will cause an error in the next major release of Symbolics.
Expand Down Expand Up @@ -155,35 +160,54 @@ function _parse_vars(macroname, type, x, transform=identity)
return ex
end

call_args_are_function(_) = false
function call_args_are_function(call_args::AbstractArray)
!isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args))
end

function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime)
ndim = :($length(($(indices...),)))
vname = !isruntime ? Meta.quot(lhs) : lhs
if call_args[1] == :..
ex = :($CallWithMetadata($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)))
if call_args_are_function(call_args)
vname, fntype = function_name_and_type(lhs)
isruntime, vname = unwrap_runtime_var(vname)
if isruntime
_vname = vname
else
_vname = Meta.quot(vname)
end
argtypes = arg_types_from_call_args(call_args)
ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $fntype}}($_vname)))
else
ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)(map($unwrap, ($(call_args...),))...))
vname = lhs
if isruntime
_vname = vname
else
_vname = Meta.quot(vname)
end
ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}, Nothing}}($_vname)(map($unwrap, ($(call_args...),))...))
end
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))

if val !== nothing
ex = :($setdefaultval($ex, $val))
end
ex = setprops_expr(ex, prop, macroname, Meta.quot(lhs))
ex = setprops_expr(ex, prop, macroname, Meta.quot(vname))
#ex = :($scalarize_getindex($ex))

ex = :($wrap($ex))

ex = :($transform($ex))
if isruntime
lhs = gensym(lhs)
vname = gensym(vname)
end
lhs, :($lhs = $ex)
vname, :($vname = $ex)
end

function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime)
issym = v isa Symbol
isarray = isa(v, Expr) && v.head == :ref
isarray = Meta.isexpr(v, :ref)
if isarray
# this can't be an array of functions, since that was handled by `construct_dep_array_vars`
var_name = v.args[1]
if Meta.isexpr(var_name, :(::))
var_name, type′ = var_name.args
Expand All @@ -192,6 +216,15 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr
isruntime, var_name = unwrap_runtime_var(var_name)
indices = v.args[2:end]
expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...)
elseif call_args_are_function(call_args)
var_name, fntype = function_name_and_type(v)
isruntime, var_name = unwrap_runtime_var(var_name)
if isruntime
vname = var_name
else
vname = Meta.quot(var_name)
end
expr = construct_var(macroname, fntype == Nothing ? vname : Expr(:(::), vname, fntype), type, call_args, val, prop)
else
var_name = v
if Meta.isexpr(v, :(::))
Expand Down Expand Up @@ -253,13 +286,48 @@ function (f::CallWithMetadata)(args...)
metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f))
end

function arg_types_from_call_args(call_args)
if length(call_args) == 1 && only(call_args) == :..
return Tuple
end
Ts = map(call_args) do arg
if arg == :..
Vararg
elseif arg isa Expr && arg.head == :(::)
if length(arg.args) == 1
arg.args[1]
elseif arg.args[1] == :..
:(Vararg{$(arg.args[2])})
else
arg.args[2]
end
else
error("Invalid call argument $arg")
end
end
return :(Tuple{$(Ts...)})
end

function function_name_and_type(var_name)
if var_name isa Expr && Meta.isexpr(var_name, :(::), 2)
var_name.args
else
var_name, Nothing
end
end

function construct_var(macroname, var_name, type, call_args, val, prop)
expr = if call_args === nothing
:($Sym{$type}($var_name))
elseif !isempty(call_args) && call_args[end] == :..
:($CallWithMetadata($Sym{$FnType{Tuple, $type}}($var_name)))
elseif call_args_are_function(call_args)
# function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet
# .. is Vararg
# (..)::ArgT is Vararg{ArgT}
var_name, fntype = function_name_and_type(var_name)
argtypes = arg_types_from_call_args(call_args)
:($CallWithMetadata($Sym{$FnType{$argtypes, $type, $fntype}}($var_name)))
else
:($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...)))
:($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type, Nothing}}($var_name)($(map(x->:($value($x)), call_args)...)))
end

if val !== nothing
Expand All @@ -283,15 +351,17 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop,
expr = if call_args === nothing
ex = :($Sym{Array{$type, $ndim}}($var_name))
:($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
elseif !isempty(call_args) && call_args[end] == :..
elseif call_args_are_function(call_args)
need_scalarize = true
ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name))
var_name, fntype = function_name_and_type(var_name)
argtypes = arg_types_from_call_args(call_args)
ex = :($Sym{Array{$FnType{$argtypes, $type, $fntype}, $ndim}}($var_name))
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
:($map($CallWithMetadata, $ex))
else
# [(R -> R)(R) ....]
need_scalarize = true
ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name))
ex = :($Sym{Array{$FnType{Tuple, $type, Nothing}, $ndim}}($var_name))
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
:($map($CallWith(($(call_args...),)), $ex))
end
Expand Down
123 changes: 120 additions & 3 deletions test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import Symbolics: CallWithMetadata, getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype
using LinearAlgebra
using Test
Expand Down Expand Up @@ -221,12 +221,12 @@ let
end

@variables t y(t)
yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real})
yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real, Nothing})
yyy = yy(t)
@test isequal(yyy, y)
@test yyy isa Num
@test y isa Num
yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real})
yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real, Nothing})
yyy = yy(t)
@test !isequal(yyy, y)
@variables y(..)
Expand All @@ -238,3 +238,120 @@ spam(x) = 2x
sym = spam([a, 2a])
@test sym isa Num
@test unwrap(sym) isa BasicSymbolic{Real}

fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)]

struct VariableFoo end
Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo

function test_all_functions(fns)
f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns
@variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String
@test f1 isa CallWithMetadata{FnType{Tuple, Real, Nothing}}
@test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)])
@test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int, Nothing}}
@test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)])
@test_throws ErrorException f2()
@test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}}
@test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)])
@test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real, Nothing}}
@test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)])
@test_throws ErrorException f4(x)
@test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real, Nothing}}
@test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)])
@test_throws ErrorException f5(x)
@test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, Nothing}}
@test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)])
@test_throws ErrorException f6()
@test_throws ErrorException f6(1)
@test_throws ErrorException f6(x, y)
@test_throws ErrorException f6(y)
@test f7 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, typeof(max)}}
# call behavior tested by f6
@test f8 isa CallWithMetadata{FnType{Tuple{Function, Vararg}, Real, typeof(sum)}}
@test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)])
@test_throws ErrorException f8(x)
@test_throws ErrorException f8(1)
@test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}, Nothing}}
@test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)])
@test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}}
@test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)])
@test_throws ErrorException f10(w, v)
end

function test_functions_defaults(fns)
for (fn, def) in zip(fns, fn_defaults)
@test Symbolics.getdefaultval(fn, nothing) == def
end
end

function test_functions_metadata(fns)
for (i, fn) in enumerate(fns)
@test Symbolics.getmetadata(fn, VariableFoo, nothing) == i
end
end

fns = @test_nowarn @variables begin
f1(..)
f2(::Any, ..)::Int
(f3::typeof(max))(..)
f4(::Int)
f5(::Int, (..)::Int)
f6(::Int, ::Int)::Int
(f7::typeof(max))(::Int, ::Int)::Int
(f8::typeof(sum))(::Function, ..)
f9(..)[1:3]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3]
# f11[1:3](::Int)::Int
end

test_all_functions(fns)

fns = @test_nowarn @variables begin
f1(..) = fn_defaults[1]
f2(::Any, ..)::Int = fn_defaults[2]
(f3::typeof(max))(..) = fn_defaults[3]
f4(::Int) = fn_defaults[4]
f5(::Int, (..)::Int) = fn_defaults[5]
f6(::Int, ::Int)::Int = fn_defaults[6]
(f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7]
(f8::typeof(sum))(::Function, ..) = fn_defaults[8]
f9(..)[1:3] = fn_defaults[9]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10]
end

test_all_functions(fns)
test_functions_defaults(fns)

fns = @variables begin
f1(..) = fn_defaults[1], [foo = 1]
f2(::Any, ..)::Int = fn_defaults[2], [foo = 2;]
(f3::typeof(max))(..) = fn_defaults[3], [foo = 3;]
f4(::Int) = fn_defaults[4], [foo = 4;]
f5(::Int, (..)::Int) = fn_defaults[5], [foo = 5;]
f6(::Int, ::Int)::Int = fn_defaults[6], [foo = 6;]
(f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7], [foo = 7;]
(f8::typeof(sum))(::Function, ..) = fn_defaults[8], [foo = 8;]
f9(..)[1:3] = fn_defaults[9], [foo = 9;]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10], [foo = 10;]
end

test_all_functions(fns)
test_functions_defaults(fns)
test_functions_metadata(fns)

fns = @test_nowarn @variables begin
f1(..), [foo = 1,]
f2(::Any, ..)::Int, [foo = 2,]
(f3::typeof(max))(..), [foo = 3,]
f4(::Int), [foo = 4,]
f5(::Int, (..)::Int), [foo = 5,]
f6(::Int, ::Int)::Int, [foo = 6,]
(f7::typeof(max))(::Int, ::Int)::Int, [foo = 7,]
(f8::typeof(sum))(::Function, ..), [foo = 8,]
f9(..)[1:3], [foo = 9,]
(f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3], [foo = 10,]
end

test_all_functions(fns)
test_functions_metadata(fns)
2 changes: 1 addition & 1 deletion test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ vars = @variables t $a $b(t) $c(t)[1:3]
@test c === :value_c
@test isequal(vars[1], t)
@test isequal(vars[2], Num(Sym{Real}(a)))
@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real}}(b)(value(t))))
@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real,Nothing}}(b)(value(t))))

vars = @variables a,b,c,d,e,f,g,h,i
@test isequal(vars, [a,b,c,d,e,f,g,h,i])
Expand Down

0 comments on commit e1ba803

Please sign in to comment.