Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix promote_symtype in register_array_symbolic with ndims #1309

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,23 +317,4 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
sol
end


# XXX: on 1.11, SymbolicsNemoExt isn't actually loaded during the pre-compilation of SymbolicsGroebnerExt,
# so we can't run a workload using its definitions yet. (The new restriction was added in 1.11 because
# of "AB" issues where an extension A and B would mutually expect the other to load "first")
#
# This can be re-enabled when it is possible for an extension A to explicitly declare that it depends on
# an extension B.

# # Helps with precompilation time
# PrecompileTools.@setup_workload begin
# @variables a b c x y z
# simple_linear_equations = [x - y, y + 2z]
# equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z]
# PrecompileTools.@compile_workload begin
# symbolic_solve(simple_linear_equations, [x, y], warns=false)
# symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
# end
# end

end # module
6 changes: 3 additions & 3 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
if define_promotion
container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...))))
etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...))))
ndims = get(defs, :ndims, nothing)
ndim = get(defs, :ndims, nothing)
is_callable_struct = f isa Expr && f.head == :(::)
fn_arg = if is_callable_struct
f
Expand All @@ -147,10 +147,10 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
container_type = $container_type
etype = $etype
$(
if ndims === nothing
if ndim === nothing
:(return container_type{etype})
else
:(ndims = $ndims; return container_type{etype, ndims})
:(ndim = $ndim; return container_type{etype, ndim})
end
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/intermediate-exprs-inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/intermediate-exprs-outplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
_out[i′, j′] = (+)(_out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/manual-limits-inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/build_function_tests/manual-limits-outplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
begin
for (j, j′) = zip(1:5, reset_to_one(1:5))
for (i, i′) = zip(1:5, reset_to_one(1:5))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (limit)((+)(-1, i), 5), (limit)((+)(1, j), 5)))
ˍ₋out[i′, j′] = (+)(ˍ₋out[i′, j′], (getindex)(u, (Main.limit)((+)(-1, i), 5), (Main.limit)((+)(1, j), 5)))
end
end
end
Expand Down
10 changes: 10 additions & 0 deletions test/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ let
end

# ndims specified

# in terms of argument
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
ndims = ndims(x) + 1
end
@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real, 2}

@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
Expand Down
Loading