Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add set_scalar_metadata to set metadata scalar-wise #984

Merged
merged 5 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ function recurse_and_apply(f, x)
end
end

function setdefaultval(x, val)
function set_scalar_metadata(x, V, val)
if symtype(x) <: AbstractArray
if val isa AbstractArray
x = if val isa AbstractArray
getindex_posthook(x) do r,x,i...
setdefaultval(r, val[i...])
set_scalar_metadata(r, V, val[i...])
end
else
getindex_posthook(x) do r,x,i...
setdefaultval(r, val)
set_scalar_metadata(r, V, val)
end
end
else
setmetadata(x,
VariableDefaultValue,
val)
end
setmetadata(x, V, val)
end
setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val)

struct GetindexParent end

Expand Down Expand Up @@ -226,7 +224,7 @@ function setprops_expr(expr, props, macroname, varname)
lhs, rhs = opt.args

@assert lhs isa Symbol "the lhs of an option must be a symbol"
expr = :($setmetadata($expr,
expr = :($set_scalar_metadata($expr,
$(option_to_metadata_type(Val{lhs}())),
$rhs))
end
Expand Down
6 changes: 5 additions & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,19 @@ end

# https://github.com/JuliaSymbolics/Symbolics.jl/issues/842
# getindex should keep metadata
@variables tv v(tv)[1:2] [test_meta = 4]
@variables tv v(tv)[1:2] [test_meta = 4] v2(tv)[1:3] [test_meta=[1, 2, 3]]
@test !isnothing(metadata(unwrap(v)))
@test hasmetadata(unwrap(v), TestMetaT)
@test getmetadata(unwrap(v), TestMetaT) == 4
@test getmetadata(unwrap(v2), TestMetaT) == [1, 2, 3]
vs = scalarize(v)
vsw = unwrap.(vs)
vs2 = scalarize(v2)
vsw2 = unwrap.(vs2)
@test !isnothing(metadata(vsw[1]))
@test hasmetadata(vsw[1], TestMetaT)
@test getmetadata(vsw[1], TestMetaT) == 4
@test getmetadata.(vsw2, TestMetaT) == [1, 2, 3]
@test !isnothing(metadata(unwrap(v[1])))
@test hasmetadata(unwrap(v[1]), TestMetaT)
@test getmetadata(unwrap(v[1]), TestMetaT) == 4
Expand Down
2 changes: 1 addition & 1 deletion test/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5
@test all(t->getsource(t)[1] === :variables, many_vars)
@test getdefaultval(t) == 0
@test getdefaultval(a) == 1
@test_throws ErrorException getdefaultval(x)
@test getdefaultval(x) == 2
@test getdefaultval(x[1]) == 2
@test getdefaultval(y[2]) == 3
@test getdefaultval(w[2]) == 2
Expand Down
6 changes: 3 additions & 3 deletions test/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Symbolics, Test
ex = [:(y ~ x)
:(y ~ -2x + 3 / z)
:(z ~ 2)]
eqs = parse_expr_to_symbolic.(ex, (Main,))
eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,))

@variables x y z
ex = [y ~ x
Expand All @@ -14,9 +14,9 @@ ex = [y ~ x
ex = [:(b(t) ~ a(t))
:(b(t) ~ -2a(t) + 3 / c(t))
:(c(t) ~ 2)]
eqs = parse_expr_to_symbolic.(ex, (Main,))
eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,))
@variables t a(t) b(t) c(t)
ex = [b ~ a
b ~ -2a + 3 / c
c ~ 2]
@test_broken all(isequal.(eqs,ex))
@test_broken all(isequal.(eqs,ex))
77 changes: 38 additions & 39 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,43 @@ using LambertW
#Testing

@testset "solving tests" begin

function hasFloat(expr)#make sure answer does not contain any strange floats
if expr isa Float64
return !isinteger(expr) && expr != float(pi) && expr != exp(1.0)
elseif expr isa Equation
return hasFloat(expr.lhs) || hasFloat(expr.rhs)
elseif istree(expr)
elements = arguments(expr)
for element in elements
if hasFloat(element)
return true
end
end
end
return false
end
correctAns(p,a) = isequal(sort(Symbolics.convert_solutions_to_floats(p)),a) && !hasFloat(p)

@syms x y z a b c

#quadratics
@test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0])
@test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)])
@test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)])
@test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)])
#lambert w
@test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))])
@test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)])
#more challenging quadratics
@test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0])
@test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0])
#functions inverses
@test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))])
@test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0])
#strange
@test_broken correctAns(solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x),[acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)])
#product
@test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0])
end
function hasFloat(expr)#make sure answer does not contain any strange floats
if expr isa Float64
return !isinteger(expr) && expr != float(pi) && expr != exp(1.0)
elseif expr isa Equation
return hasFloat(expr.lhs) || hasFloat(expr.rhs)
elseif istree(expr)
elements = arguments(expr)
for element in elements
if hasFloat(element)
return true
end
end
end
return false
end
correctAns(p,a) = isapprox(sort(Symbolics.convert_solutions_to_floats(p)), a) && !hasFloat(p)

@syms x y z a b c

#quadratics
@test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0])
@test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)])
@test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)])
@test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)])
#lambert w
@test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))])
@test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)])
#more challenging quadratics
@test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0])
@test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0])
#functions inverses
@test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))])
@test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0])
r = solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x)
if r !== nothing
@test correctAns(r, [acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)])
end
#product
@test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0])
end
2 changes: 1 addition & 1 deletion test/stencils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
end

@test iszero(scalarize(y[1,1]))
test_funcs("stencil-extents", y, x, broken=broken)
test_funcs("stencil-extents", y, x)

@variables u[1:5, 1:5]
n = 5
Expand Down
Loading