Skip to content

Commit

Permalink
Merge pull request #1309 from AayushSabharwal/as/arr-register-fix
Browse files Browse the repository at this point in the history
fix: fix `promote_symtype` in `register_array_symbolic` with `ndims`
  • Loading branch information
ChrisRackauckas authored Oct 19, 2024
2 parents 4ad4fea + e036266 commit 16d4bd6
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 26 deletions.
19 changes: 0 additions & 19 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,23 +317,4 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
sol
end


# XXX: on 1.11, SymbolicsNemoExt isn't actually loaded during the pre-compilation of SymbolicsGroebnerExt,
# so we can't run a workload using its definitions yet. (The new restriction was added in 1.11 because
# of "AB" issues where an extension A and B would mutually expect the other to load "first")
#
# This can be re-enabled when it is possible for an extension A to explicitly declare that it depends on
# an extension B.

# # Helps with precompilation time
# PrecompileTools.@setup_workload begin
# @variables a b c x y z
# simple_linear_equations = [x - y, y + 2z]
# equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z]
# PrecompileTools.@compile_workload begin
# symbolic_solve(simple_linear_equations, [x, y], warns=false)
# symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
# end
# end

end # module
6 changes: 3 additions & 3 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
if define_promotion
container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...))))
etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...))))
ndims = get(defs, :ndims, nothing)
ndim = get(defs, :ndims, nothing)
is_callable_struct = f isa Expr && f.head == :(::)
fn_arg = if is_callable_struct
f
Expand All @@ -147,10 +147,10 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
container_type = $container_type
etype = $etype
$(
if ndims === nothing
if ndim === nothing
:(return container_type{etype})
else
:(ndims = $ndims; return container_type{etype, ndims})
:(ndim = $ndim; return container_type{etype, ndim})
end
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/intermediate-exprs-inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/intermediate-exprs-outplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/manual-limits-inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/manual-limits-outplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
10 changes: 10 additions & 0 deletions test/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ let
end

# ndims specified

# in terms of argument
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
ndims = ndims(x) + 1
end
@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real, 2}

@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
Expand Down

0 comments on commit 16d4bd6

Please sign in to comment.