From 86dc9567b9acceb6989a27b9f12f007f4a47cc99 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 13 Sep 2024 16:50:31 +0530 Subject: [PATCH 1/6] feat: add support for typed function symbolics --- src/variable.jl | 100 +++++++++++++++++++++++++++++++------ test/macro.jl | 123 ++++++++++++++++++++++++++++++++++++++++++++-- test/overloads.jl | 2 +- 3 files changed, 206 insertions(+), 19 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index a8845171b..c73c423cb 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -101,6 +101,11 @@ function _parse_vars(macroname, type, x, transform=identity) # x = 1, [connect = flow; unit = u"m^3/s"] if Meta.isexpr(v, :(=)) v, val = v.args + # defaults with metadata for function variables + if Meta.isexpr(val, :block) + Base.remove_linenums!(val) + val = only(val.args) + end if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2]) options = val.args[2].args val = val.args[1] @@ -124,7 +129,7 @@ function _parse_vars(macroname, type, x, transform=identity) isruntime, v = unwrap_runtime_var(v) iscall = Meta.isexpr(v, :call) isarray = Meta.isexpr(v, :ref) - if iscall && Meta.isexpr(v.args[1], :ref) + if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(last∘unwrap_runtime_var, @view v.args[2:end])) @warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead. The former creates an array of functions, while the latter creates an array valued function. The deprecated syntax will cause an error in the next major release of Symbolics. @@ -155,35 +160,54 @@ function _parse_vars(macroname, type, x, transform=identity) return ex end +call_args_are_function(_) = false +function call_args_are_function(call_args::AbstractArray) + !isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args)) +end + function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime) ndim = :($length(($(indices...),))) - vname = !isruntime ? Meta.quot(lhs) : lhs - if call_args[1] == :.. - ex = :($CallWithMetadata($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname))) + if call_args_are_function(call_args) + vname, fntype = function_name_and_type(lhs) + isruntime, vname = unwrap_runtime_var(vname) + if isruntime + _vname = vname + else + _vname = Meta.quot(vname) + end + argtypes = arg_types_from_call_args(call_args) + ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $fntype}}($_vname))) else - ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)(map($unwrap, ($(call_args...),))...)) + vname = lhs + if isruntime + _vname = vname + else + _vname = Meta.quot(vname) + end + ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}, Nothing}}($_vname)(map($unwrap, ($(call_args...),))...)) end ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) if val !== nothing ex = :($setdefaultval($ex, $val)) end - ex = setprops_expr(ex, prop, macroname, Meta.quot(lhs)) + ex = setprops_expr(ex, prop, macroname, Meta.quot(vname)) #ex = :($scalarize_getindex($ex)) ex = :($wrap($ex)) ex = :($transform($ex)) if isruntime - lhs = gensym(lhs) + vname = gensym(vname) end - lhs, :($lhs = $ex) + vname, :($vname = $ex) end function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime) issym = v isa Symbol - isarray = isa(v, Expr) && v.head == :ref + isarray = Meta.isexpr(v, :ref) if isarray + # this can't be an array of functions, since that was handled by `construct_dep_array_vars` var_name = v.args[1] if Meta.isexpr(var_name, :(::)) var_name, type′ = var_name.args @@ -192,6 +216,15 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr isruntime, var_name = unwrap_runtime_var(var_name) indices = v.args[2:end] expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...) + elseif call_args_are_function(call_args) + var_name, fntype = function_name_and_type(v) + isruntime, var_name = unwrap_runtime_var(var_name) + if isruntime + vname = var_name + else + vname = Meta.quot(var_name) + end + expr = construct_var(macroname, fntype == Nothing ? vname : Expr(:(::), vname, fntype), type, call_args, val, prop) else var_name = v if Meta.isexpr(v, :(::)) @@ -253,13 +286,48 @@ function (f::CallWithMetadata)(args...) metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)) end +function arg_types_from_call_args(call_args) + if length(call_args) == 1 && only(call_args) == :.. + return Tuple + end + Ts = map(call_args) do arg + if arg == :.. + Vararg + elseif arg isa Expr && arg.head == :(::) + if length(arg.args) == 1 + arg.args[1] + elseif arg.args[1] == :.. + :(Vararg{$(arg.args[2])}) + else + arg.args[2] + end + else + error("Invalid call argument $arg") + end + end + return :(Tuple{$(Ts...)}) +end + +function function_name_and_type(var_name) + if var_name isa Expr && Meta.isexpr(var_name, :(::), 2) + var_name.args + else + var_name, Nothing + end +end + function construct_var(macroname, var_name, type, call_args, val, prop) expr = if call_args === nothing :($Sym{$type}($var_name)) - elseif !isempty(call_args) && call_args[end] == :.. - :($CallWithMetadata($Sym{$FnType{Tuple, $type}}($var_name))) + elseif call_args_are_function(call_args) + # function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet + # .. is Vararg + # (..)::ArgT is Vararg{ArgT} + var_name, fntype = function_name_and_type(var_name) + argtypes = arg_types_from_call_args(call_args) + :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $fntype}}($var_name))) else - :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...))) + :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type, Nothing}}($var_name)($(map(x->:($value($x)), call_args)...))) end if val !== nothing @@ -283,15 +351,17 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop, expr = if call_args === nothing ex = :($Sym{Array{$type, $ndim}}($var_name)) :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - elseif !isempty(call_args) && call_args[end] == :.. + elseif call_args_are_function(call_args) need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name)) + var_name, fntype = function_name_and_type(var_name) + argtypes = arg_types_from_call_args(call_args) + ex = :($Sym{Array{$FnType{$argtypes, $type, $fntype}, $ndim}}($var_name)) ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) :($map($CallWithMetadata, $ex)) else # [(R -> R)(R) ....] need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name)) + ex = :($Sym{Array{$FnType{Tuple, $type, Nothing}, $ndim}}($var_name)) ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) :($map($CallWith(($(call_args...),)), $ex)) end diff --git a/test/macro.jl b/test/macro.jl index 20b96f00e..faae311e1 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -1,5 +1,5 @@ using Symbolics -import Symbolics: getsource, getdefaultval, wrap, unwrap, getname +import Symbolics: CallWithMetadata, getsource, getdefaultval, wrap, unwrap, getname import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype using LinearAlgebra using Test @@ -221,12 +221,12 @@ let end @variables t y(t) -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real, Nothing}) yyy = yy(t) @test isequal(yyy, y) @test yyy isa Num @test y isa Num -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real, Nothing}) yyy = yy(t) @test !isequal(yyy, y) @variables y(..) @@ -238,3 +238,120 @@ spam(x) = 2x sym = spam([a, 2a]) @test sym isa Num @test unwrap(sym) isa BasicSymbolic{Real} + +fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)] + +struct VariableFoo end +Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo + +function test_all_functions(fns) + f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns + @variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String + @test f1 isa CallWithMetadata{FnType{Tuple, Real, Nothing}} + @test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)]) + @test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int, Nothing}} + @test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)]) + @test_throws ErrorException f2() + @test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}} + @test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)]) + @test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real, Nothing}} + @test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)]) + @test_throws ErrorException f4(x) + @test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real, Nothing}} + @test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)]) + @test_throws ErrorException f5(x) + @test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, Nothing}} + @test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)]) + @test_throws ErrorException f6() + @test_throws ErrorException f6(1) + @test_throws ErrorException f6(x, y) + @test_throws ErrorException f6(y) + @test f7 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, typeof(max)}} + # call behavior tested by f6 + @test f8 isa CallWithMetadata{FnType{Tuple{Function, Vararg}, Real, typeof(sum)}} + @test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)]) + @test_throws ErrorException f8(x) + @test_throws ErrorException f8(1) + @test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}, Nothing}} + @test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)]) + @test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}} + @test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)]) + @test_throws ErrorException f10(w, v) +end + +function test_functions_defaults(fns) + for (fn, def) in zip(fns, fn_defaults) + @test Symbolics.getdefaultval(fn, nothing) == def + end +end + +function test_functions_metadata(fns) + for (i, fn) in enumerate(fns) + @test Symbolics.getmetadata(fn, VariableFoo, nothing) == i + end +end + +fns = @test_nowarn @variables begin + f1(..) + f2(::Any, ..)::Int + (f3::typeof(max))(..) + f4(::Int) + f5(::Int, (..)::Int) + f6(::Int, ::Int)::Int + (f7::typeof(max))(::Int, ::Int)::Int + (f8::typeof(sum))(::Function, ..) + f9(..)[1:3] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] + # f11[1:3](::Int)::Int +end + +test_all_functions(fns) + +fns = @test_nowarn @variables begin + f1(..) = fn_defaults[1] + f2(::Any, ..)::Int = fn_defaults[2] + (f3::typeof(max))(..) = fn_defaults[3] + f4(::Int) = fn_defaults[4] + f5(::Int, (..)::Int) = fn_defaults[5] + f6(::Int, ::Int)::Int = fn_defaults[6] + (f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7] + (f8::typeof(sum))(::Function, ..) = fn_defaults[8] + f9(..)[1:3] = fn_defaults[9] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10] +end + +test_all_functions(fns) +test_functions_defaults(fns) + +fns = @variables begin + f1(..) = fn_defaults[1], [foo = 1] + f2(::Any, ..)::Int = fn_defaults[2], [foo = 2;] + (f3::typeof(max))(..) = fn_defaults[3], [foo = 3;] + f4(::Int) = fn_defaults[4], [foo = 4;] + f5(::Int, (..)::Int) = fn_defaults[5], [foo = 5;] + f6(::Int, ::Int)::Int = fn_defaults[6], [foo = 6;] + (f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7], [foo = 7;] + (f8::typeof(sum))(::Function, ..) = fn_defaults[8], [foo = 8;] + f9(..)[1:3] = fn_defaults[9], [foo = 9;] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10], [foo = 10;] +end + +test_all_functions(fns) +test_functions_defaults(fns) +test_functions_metadata(fns) + +fns = @test_nowarn @variables begin + f1(..), [foo = 1,] + f2(::Any, ..)::Int, [foo = 2,] + (f3::typeof(max))(..), [foo = 3,] + f4(::Int), [foo = 4,] + f5(::Int, (..)::Int), [foo = 5,] + f6(::Int, ::Int)::Int, [foo = 6,] + (f7::typeof(max))(::Int, ::Int)::Int, [foo = 7,] + (f8::typeof(sum))(::Function, ..), [foo = 8,] + f9(..)[1:3], [foo = 9,] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3], [foo = 10,] +end + +test_all_functions(fns) +test_functions_metadata(fns) diff --git a/test/overloads.jl b/test/overloads.jl index be893876e..bea817d85 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -12,7 +12,7 @@ vars = @variables t $a $b(t) $c(t)[1:3] @test c === :value_c @test isequal(vars[1], t) @test isequal(vars[2], Num(Sym{Real}(a))) -@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real}}(b)(value(t)))) +@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real,Nothing}}(b)(value(t)))) vars = @variables a,b,c,d,e,f,g,h,i @test isequal(vars, [a,b,c,d,e,f,g,h,i]) From 89a53a9fcfb200bd7bbb9047fc9b25faa7bb4405 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 13 Sep 2024 18:04:16 +0530 Subject: [PATCH 2/6] feat: add `CallWithParent` to keep track of callable variables --- src/variable.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/variable.jl b/src/variable.jl index c73c423cb..8abb3b3ee 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -282,8 +282,10 @@ function Base.show(io::IO, c::CallWithMetadata) print(io, "⋆") end +struct CallWithParent end + function (f::CallWithMetadata)(args...) - metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)) + setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f) end function arg_types_from_call_args(call_args) From 8e1a3086702191fb9d2aeaa908881932933e161a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 13 Sep 2024 18:21:20 +0530 Subject: [PATCH 3/6] feat: support `Base.isequal` for `CallWithMetadata` --- src/variable.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/variable.jl b/src/variable.jl index 8abb3b3ee..40d6c5fb4 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -288,6 +288,8 @@ function (f::CallWithMetadata)(args...) setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f) end +Base.isequal(a::CallWithMetadata, b::CallWithMetadata) = isequal(a.f, b.f) + function arg_types_from_call_args(call_args) if length(call_args) == 1 && only(call_args) == :.. return Tuple From 13f31e3076ded902761108e343ec2db4d2223453 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 17 Sep 2024 11:39:17 +0530 Subject: [PATCH 4/6] build: bump SymbolicUtils compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b4829c368..9f7f78347 100644 --- a/Project.toml +++ b/Project.toml @@ -93,7 +93,7 @@ StaticArraysCore = "1.4" SymPy = "2.2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.2" -SymbolicUtils = "2, 3" +SymbolicUtils = "3.7" TermInterface = "2" julia = "1.10" From ee99dea042ee965b0fcd5cc7a464b3e7d277bb23 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 18 Sep 2024 13:24:35 +0530 Subject: [PATCH 5/6] refactor: do not use `Nothing` as a sentinel value for `FnType` --- src/variable.jl | 18 +++++++++--------- test/macro.jl | 16 ++++++++-------- test/overloads.jl | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 40d6c5fb4..a85d8755c 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -176,7 +176,7 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, _vname = Meta.quot(vname) end argtypes = arg_types_from_call_args(call_args) - ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $fntype}}($_vname))) + ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname))) else vname = lhs if isruntime @@ -184,7 +184,7 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, else _vname = Meta.quot(vname) end - ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}, Nothing}}($_vname)(map($unwrap, ($(call_args...),))...)) + ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($_vname)(map($unwrap, ($(call_args...),))...)) end ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) @@ -224,7 +224,7 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr else vname = Meta.quot(var_name) end - expr = construct_var(macroname, fntype == Nothing ? vname : Expr(:(::), vname, fntype), type, call_args, val, prop) + expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop) else var_name = v if Meta.isexpr(v, :(::)) @@ -314,9 +314,9 @@ end function function_name_and_type(var_name) if var_name isa Expr && Meta.isexpr(var_name, :(::), 2) - var_name.args + var_name.args[1], (var_name.args[2],) else - var_name, Nothing + var_name, () end end @@ -329,9 +329,9 @@ function construct_var(macroname, var_name, type, call_args, val, prop) # (..)::ArgT is Vararg{ArgT} var_name, fntype = function_name_and_type(var_name) argtypes = arg_types_from_call_args(call_args) - :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $fntype}}($var_name))) + :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name))) else - :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type, Nothing}}($var_name)($(map(x->:($value($x)), call_args)...))) + :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...))) end if val !== nothing @@ -359,13 +359,13 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop, need_scalarize = true var_name, fntype = function_name_and_type(var_name) argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{Array{$FnType{$argtypes, $type, $fntype}, $ndim}}($var_name)) + ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name)) ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) :($map($CallWithMetadata, $ex)) else # [(R -> R)(R) ....] need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type, Nothing}, $ndim}}($var_name)) + ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name)) ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) :($map($CallWith(($(call_args...),)), $ex)) end diff --git a/test/macro.jl b/test/macro.jl index faae311e1..8624af78f 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -221,12 +221,12 @@ let end @variables t y(t) -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real, Nothing}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real}) yyy = yy(t) @test isequal(yyy, y) @test yyy isa Num @test y isa Num -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real, Nothing}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real}) yyy = yy(t) @test !isequal(yyy, y) @variables y(..) @@ -247,20 +247,20 @@ Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo function test_all_functions(fns) f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns @variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String - @test f1 isa CallWithMetadata{FnType{Tuple, Real, Nothing}} + @test f1 isa CallWithMetadata{FnType{Tuple, Real}} @test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)]) - @test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int, Nothing}} + @test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int}} @test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)]) @test_throws ErrorException f2() @test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}} @test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)]) - @test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real, Nothing}} + @test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real}} @test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)]) @test_throws ErrorException f4(x) - @test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real, Nothing}} + @test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real}} @test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)]) @test_throws ErrorException f5(x) - @test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, Nothing}} + @test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int}} @test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)]) @test_throws ErrorException f6() @test_throws ErrorException f6(1) @@ -272,7 +272,7 @@ function test_all_functions(fns) @test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)]) @test_throws ErrorException f8(x) @test_throws ErrorException f8(1) - @test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}, Nothing}} + @test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}}} @test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)]) @test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}} @test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)]) diff --git a/test/overloads.jl b/test/overloads.jl index bea817d85..be893876e 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -12,7 +12,7 @@ vars = @variables t $a $b(t) $c(t)[1:3] @test c === :value_c @test isequal(vars[1], t) @test isequal(vars[2], Num(Sym{Real}(a))) -@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real,Nothing}}(b)(value(t)))) +@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real}}(b)(value(t)))) vars = @variables a,b,c,d,e,f,g,h,i @test isequal(vars, [a,b,c,d,e,f,g,h,i]) From 79023776dc440ef17bc662e83c7148325bd52586 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 18 Sep 2024 13:24:46 +0530 Subject: [PATCH 6/6] feat: support interpolated names in callable symbolics --- src/variable.jl | 28 +++++++++++++++++++++------- test/macro.jl | 16 ++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index a85d8755c..b5230f07b 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -169,11 +169,18 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, ndim = :($length(($(indices...),))) if call_args_are_function(call_args) vname, fntype = function_name_and_type(lhs) - isruntime, vname = unwrap_runtime_var(vname) + # name was already unwrapped before calling this function and is of the form $x if isruntime _vname = vname else - _vname = Meta.quot(vname) + # either no ::fnType or $x::fnType + vname, fntype = function_name_and_type(lhs) + isruntime, vname = unwrap_runtime_var(vname) + if isruntime + _vname = vname + else + _vname = Meta.quot(vname) + end end argtypes = arg_types_from_call_args(call_args) ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname))) @@ -198,14 +205,14 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, ex = :($transform($ex)) if isruntime - vname = gensym(vname) + vname = gensym(Symbol(vname)) end vname, :($vname = $ex) end function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime) issym = v isa Symbol - isarray = Meta.isexpr(v, :ref) + isarray = !isruntime && Meta.isexpr(v, :ref) if isarray # this can't be an array of functions, since that was handled by `construct_dep_array_vars` var_name = v.args[1] @@ -218,11 +225,18 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...) elseif call_args_are_function(call_args) var_name, fntype = function_name_and_type(v) - isruntime, var_name = unwrap_runtime_var(var_name) + # name was already unwrapped before calling this function and is of the form $x if isruntime vname = var_name else - vname = Meta.quot(var_name) + # either no ::fnType or $x::fnType + var_name, fntype = function_name_and_type(v) + isruntime, var_name = unwrap_runtime_var(var_name) + if isruntime + vname = var_name + else + vname = Meta.quot(var_name) + end end expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop) else @@ -233,7 +247,7 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr end expr = construct_var(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop) end - lhs = isruntime ? gensym(var_name) : var_name + lhs = isruntime ? gensym(Symbol(var_name)) : var_name rhs = :($transform($expr)) lhs, :($lhs = $rhs) end diff --git a/test/macro.jl b/test/macro.jl index 8624af78f..0dd82728b 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -240,6 +240,7 @@ sym = spam([a, 2a]) @test unwrap(sym) isa BasicSymbolic{Real} fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)] +fn_names = [Symbol(:f, i) for i in 1:10] struct VariableFoo end Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo @@ -355,3 +356,18 @@ end test_all_functions(fns) test_functions_metadata(fns) + +fns = @test_nowarn @variables begin + $(fn_names[1])(..) + $(fn_names[2])(::Any, ..)::Int + ($(fn_names[3])::typeof(max))(..) + $(fn_names[4])(::Int) + $(fn_names[5])(::Int, (..)::Int) + $(fn_names[6])(::Int, ::Int)::Int + ($(fn_names[7])::typeof(max))(::Int, ::Int)::Int + ($(fn_names[8])::typeof(sum))(::Function, ..) + $(fn_names[9])(..)[1:3] + ($(fn_names[10])::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] +end + +test_all_functions(fns)