From b78d724c34c3381bf150f4b616e84205d6bf8186 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 May 2024 13:52:07 +0530 Subject: [PATCH] fix: make `@register_array_symbolic` overload `promote_symtype` --- src/register.jl | 54 ++++++++++++++++++++++++++++++++---- test/macro.jl | 74 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 110 insertions(+), 18 deletions(-) diff --git a/src/register.jl b/src/register.jl index 4f62ad9a4..51c4152ac 100644 --- a/src/register.jl +++ b/src/register.jl @@ -81,7 +81,7 @@ function destructure_registration_expr(expr, Ts) end -function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :()) +function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true) def_assignments = MacroTools.rmlines(partial_defs).args defs = map(def_assignments) do ex @assert ex.head == :(=) @@ -90,7 +90,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs args′ = map((a, T) -> :($a::$T), argnames, Ts) - quote + fexpr = quote @wrapped function $f($(args′...)) args = [$(argnames...),] unwrapped_args = map($unwrap, args) @@ -109,10 +109,44 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs end end end |> esc + + if define_promotion + container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...)))) + etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...)))) + ndims = get(defs, :ndims, nothing) + is_callable_struct = f isa Expr && f.head == :(::) + fn_arg = if is_callable_struct + f + else + :(f::$ftype) + end + fn_arg_name = if is_callable_struct + f.args[1] + else + :f + end + promote_expr = quote + function (::$typeof($promote_symtype))($fn_arg, $(argnames...)) + f = $fn_arg_name + container_type = $container_type + etype = $etype + $( + if ndims === nothing + :(return container_type{etype}) + else + :(ndims = $ndims; return container_type{etype, ndims}) + end + ) + end + end |> esc + fexpr = :($fexpr; $promote_expr) + end + + return fexpr end """ - @register_array_symbolic(expr) + @register_array_symbolic(expr, define_promotion = true) Example: @@ -132,8 +166,18 @@ You can also register calls on callable structs: eltype=promote_type(eltype(x), eltype(c)) end ``` + +If `define_promotion = true` then a promotion method in the form of +```julia +SymbolicUtils.promote_symtype(::typeof(f_registered), args...) = # inferred or annotated return type +``` + +is defined for the register function. Note that when defining multiple register +overloads for one function, all the rest of the registers must set +`define_promotion` to `false` except for the first one, to avoid method +overwriting. """ -macro register_array_symbolic(expr, block) +macro register_array_symbolic(expr, block, define_promotion = true) f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([])) - return register_array_symbolic(f, ftype, argnames, Ts, ret_type, block) + register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion) end diff --git a/test/macro.jl b/test/macro.jl index 5d9f417a0..936fbbb7a 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -1,6 +1,6 @@ using Symbolics import Symbolics: getsource, getdefaultval, wrap, unwrap, getname -import SymbolicUtils: Term, symtype, FnType, BasicSymbolic +import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype using LinearAlgebra using Test @@ -9,34 +9,57 @@ Symbolics.@register_symbolic fff(t) @test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)]))) const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}} +many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4] + +let + @register_array_symbolic ggg(x::AbstractVector) begin + container_type=SymMatrix + size=(length(x) * 2, length(x) * 2) + eltype=eltype(x) + end false + + ## @variables + + gg = ggg(x) + + @test ndims(gg) == 2 + @test size(gg) == (8,8) + @test eltype(gg) == Real + @test symtype(unwrap(gg)) == SymMatrix{Real, 2} + @test promote_symtype(ggg, symtype(unwrap(x))) == Any # no promote_symtype defined +end +let + # redefine with promote_symtype + @register_array_symbolic ggg(x::AbstractVector) begin + container_type=SymMatrix + size=(length(x) * 2, length(x) * 2) + eltype=eltype(x) + end + @test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real} +end + +# ndims specified @register_array_symbolic ggg(x::AbstractVector) begin container_type=SymMatrix size=(length(x) * 2, length(x) * 2) eltype=eltype(x) + ndims = 2 end - -## @variables - -many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4] - +@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real, 2} gg = ggg(x) -@test ndims(gg) == 2 -@test size(gg) == (8,8) -@test eltype(gg) == Real -@test symtype(unwrap(gg)) == SymMatrix{Real, 2} - struct CanCallWithArray{T} params::T end +ccwa = CanCallWithArray((length=10,)) @register_array_symbolic (c::CanCallWithArray)(x::AbstractArray, b::AbstractVector) begin size=(size(x, 1), length(b), c.params.length) eltype=Real -end +end false # without promote_symtype -hh = CanCallWithArray((length=10,))(gg, x) +hh = ccwa(gg, x) @test size(hh) == (8,4,10) @test eltype(hh) == Real @test isequal(arguments(unwrap(hh)), unwrap.([gg, x])) @@ -52,9 +75,34 @@ hh = CanCallWithArray((length=10,))(gg, x) @test getdefaultval(z[3]) == 4 @test symtype(p) <: FnType{Tuple, Array{Real,1}} +@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Any @test p(t)[1] isa Symbolics.Num +struct CanCallWithArray2{T} + params::T +end + +ccwa = CanCallWithArray2((length=10,)) +@register_array_symbolic (c::CanCallWithArray2)(x::AbstractArray, b::AbstractVector) begin + size=(size(x, 1), length(b), c.params.length) + eltype=Real +end +@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real} + +struct CanCallWithArray3{T} + params::T +end + +ccwa = CanCallWithArray3((length=10,)) +# ndims specified +@register_array_symbolic (c::CanCallWithArray3)(x::AbstractArray, b::AbstractVector) begin + size=(size(x, 1), length(b), c.params.length) + eltype=Real + ndims = 3 +end +@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real, 3} + ## Wrapper types abstract type AbstractFoo{T} end