Skip to content

Commit

Permalink
add tests for array registration
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed Dec 11, 2023
1 parent 500b180 commit b9220ee
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,24 +352,25 @@ end
#

"""
array_term(f, args...; arrayop=nothing)
array_term(f, args...;
container_type = propagate_atype(f, args...),
eltype = propagate_eltype(f, args...),
size = map(length, propagate_shape(f, args...)),
ndims = propagate_ndims(f, args...))
Create a term of `Term{<: AbstractArray}` which
is the representation of `f(args...)`.
- Calls `propagate_atype(f, args...)` to determine the
container type, i.e. `Array` or `StaticArray` etc.
- Calls `propagate_eltype(f, args...)` to determine the
output element type.
- Calls `propagate_ndims(f, args...)` to determine the
output dimension.
- Calls `propagate_shape(f, args...)` to determine the
output array shape.
Default arguments:
- `container_type=propagate_atype(f, args...)` - the container type,
i.e. `Array` or `StaticArray` etc.
- `eltype=propagate_eltype(f, args...)` - the output element type.
- `size=map(length, propagate_shape(f, args...))` - the
output array size. `propagate_shape` returns a tuple of index ranges.
- `ndims=propagate_ndims(f, args...)` the output dimension.
`propagate_shape`, `propagate_atype`, `propagate_eltype` may
return `Unknown()` to say that the output cannot be determined
But `propagate_ndims` must work and return a non-negative integer.
"""
function array_term(f, args...;
container_type = propagate_atype(f, args...),
Expand Down
30 changes: 30 additions & 0 deletions test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic
using LinearAlgebra
using Test

@variables t
Symbolics.@register_symbolic fff(t)
@test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)])))

const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}}
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
end

## @variables

many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]


gg = ggg(x)

@test ndims(gg) == 2
@test size(gg) == (8,8)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}

struct CanCallWithArray{T}
params::T
end

@register_array_symbolic (c::CanCallWithArray)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
end

hh = CanCallWithArray((length=10,))(gg, x)
@test size(hh) == (8,4,10)
@test eltype(hh) == Real
@test isequal(arguments(unwrap(hh)), unwrap.([gg, x]))

@test all(t->getsource(t)[1] === :variables, many_vars)
@test getdefaultval(t) == 0
@test getdefaultval(a) == 1
Expand Down

0 comments on commit b9220ee

Please sign in to comment.