Skip to content

Commit

Permalink
fix: make @register_array_symbolic overload promote_symtype
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 6, 2024
1 parent 5787b48 commit b78d724
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 18 deletions.
54 changes: 49 additions & 5 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function destructure_registration_expr(expr, Ts)
end


function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :())
function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true)
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
Expand All @@ -90,7 +90,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs


args′ = map((a, T) -> :($a::$T), argnames, Ts)
quote
fexpr = quote
@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
Expand All @@ -109,10 +109,44 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
end
end
end |> esc

if define_promotion
container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...))))
etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...))))
ndims = get(defs, :ndims, nothing)
is_callable_struct = f isa Expr && f.head == :(::)
fn_arg = if is_callable_struct
f
else
:(f::$ftype)
end
fn_arg_name = if is_callable_struct
f.args[1]
else
:f
end
promote_expr = quote
function (::$typeof($promote_symtype))($fn_arg, $(argnames...))
f = $fn_arg_name
container_type = $container_type
etype = $etype
$(
if ndims === nothing
:(return container_type{etype})
else
:(ndims = $ndims; return container_type{etype, ndims})
end
)
end
end |> esc
fexpr = :($fexpr; $promote_expr)
end

return fexpr
end

"""
@register_array_symbolic(expr)
@register_array_symbolic(expr, define_promotion = true)
Example:
Expand All @@ -132,8 +166,18 @@ You can also register calls on callable structs:
eltype=promote_type(eltype(x), eltype(c))
end
```
If `define_promotion = true` then a promotion method in the form of
```julia
SymbolicUtils.promote_symtype(::typeof(f_registered), args...) = # inferred or annotated return type
```
is defined for the register function. Note that when defining multiple register
overloads for one function, all the rest of the registers must set
`define_promotion` to `false` except for the first one, to avoid method
overwriting.
"""
macro register_array_symbolic(expr, block)
macro register_array_symbolic(expr, block, define_promotion = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([]))
return register_array_symbolic(f, ftype, argnames, Ts, ret_type, block)
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion)
end
74 changes: 61 additions & 13 deletions test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype
using LinearAlgebra
using Test

Expand All @@ -9,34 +9,57 @@ Symbolics.@register_symbolic fff(t)
@test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)])))

const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}}
many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]

let
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
end false

## @variables

gg = ggg(x)

@test ndims(gg) == 2
@test size(gg) == (8,8)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}
@test promote_symtype(ggg, symtype(unwrap(x))) == Any # no promote_symtype defined
end
let
# redefine with promote_symtype
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
end
@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real}
end

# ndims specified
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
ndims = 2
end

## @variables

many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]

@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real, 2}

gg = ggg(x)

@test ndims(gg) == 2
@test size(gg) == (8,8)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}

struct CanCallWithArray{T}
params::T
end

ccwa = CanCallWithArray((length=10,))
@register_array_symbolic (c::CanCallWithArray)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
end
end false # without promote_symtype

hh = CanCallWithArray((length=10,))(gg, x)
hh = ccwa(gg, x)
@test size(hh) == (8,4,10)
@test eltype(hh) == Real
@test isequal(arguments(unwrap(hh)), unwrap.([gg, x]))
Expand All @@ -52,9 +75,34 @@ hh = CanCallWithArray((length=10,))(gg, x)
@test getdefaultval(z[3]) == 4

@test symtype(p) <: FnType{Tuple, Array{Real,1}}
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Any
@test p(t)[1] isa Symbolics.Num


struct CanCallWithArray2{T}
params::T
end

ccwa = CanCallWithArray2((length=10,))
@register_array_symbolic (c::CanCallWithArray2)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
end
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real}

struct CanCallWithArray3{T}
params::T
end

ccwa = CanCallWithArray3((length=10,))
# ndims specified
@register_array_symbolic (c::CanCallWithArray3)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
ndims = 3
end
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real, 3}

## Wrapper types

abstract type AbstractFoo{T} end
Expand Down

0 comments on commit b78d724

Please sign in to comment.