Skip to content

Commit

Permalink
Merge pull request #984 from JuliaSymbolics/myb/md
Browse files Browse the repository at this point in the history
Add `set_scalar_metadata` to set metadata scalar-wise
  • Loading branch information
YingboMa authored Oct 4, 2023
2 parents 0c91be9 + 1a74a87 commit 6e33cfc
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 54 deletions.
16 changes: 7 additions & 9 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ function recurse_and_apply(f, x)
end
end

function setdefaultval(x, val)
function set_scalar_metadata(x, V, val)
if symtype(x) <: AbstractArray
if val isa AbstractArray
x = if val isa AbstractArray
getindex_posthook(x) do r,x,i...
setdefaultval(r, val[i...])
set_scalar_metadata(r, V, val[i...])
end
else
getindex_posthook(x) do r,x,i...
setdefaultval(r, val)
set_scalar_metadata(r, V, val)
end
end
else
setmetadata(x,
VariableDefaultValue,
val)
end
setmetadata(x, V, val)
end
setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val)

struct GetindexParent end

Expand Down Expand Up @@ -226,7 +224,7 @@ function setprops_expr(expr, props, macroname, varname)
lhs, rhs = opt.args

@assert lhs isa Symbol "the lhs of an option must be a symbol"
expr = :($setmetadata($expr,
expr = :($set_scalar_metadata($expr,
$(option_to_metadata_type(Val{lhs}())),
$rhs))
end
Expand Down
6 changes: 5 additions & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,19 @@ end

# https://github.com/JuliaSymbolics/Symbolics.jl/issues/842
# getindex should keep metadata
@variables tv v(tv)[1:2] [test_meta = 4]
@variables tv v(tv)[1:2] [test_meta = 4] v2(tv)[1:3] [test_meta=[1, 2, 3]]
@test !isnothing(metadata(unwrap(v)))
@test hasmetadata(unwrap(v), TestMetaT)
@test getmetadata(unwrap(v), TestMetaT) == 4
@test getmetadata(unwrap(v2), TestMetaT) == [1, 2, 3]
vs = scalarize(v)
vsw = unwrap.(vs)
vs2 = scalarize(v2)
vsw2 = unwrap.(vs2)
@test !isnothing(metadata(vsw[1]))
@test hasmetadata(vsw[1], TestMetaT)
@test getmetadata(vsw[1], TestMetaT) == 4
@test getmetadata.(vsw2, TestMetaT) == [1, 2, 3]
@test !isnothing(metadata(unwrap(v[1])))
@test hasmetadata(unwrap(v[1]), TestMetaT)
@test getmetadata(unwrap(v[1]), TestMetaT) == 4
Expand Down
2 changes: 1 addition & 1 deletion test/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5
@test all(t->getsource(t)[1] === :variables, many_vars)
@test getdefaultval(t) == 0
@test getdefaultval(a) == 1
@test_throws ErrorException getdefaultval(x)
@test getdefaultval(x) == 2
@test getdefaultval(x[1]) == 2
@test getdefaultval(y[2]) == 3
@test getdefaultval(w[2]) == 2
Expand Down
6 changes: 3 additions & 3 deletions test/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Symbolics, Test
ex = [:(y ~ x)
:(y ~ -2x + 3 / z)
:(z ~ 2)]
eqs = parse_expr_to_symbolic.(ex, (Main,))
eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,))

@variables x y z
ex = [y ~ x
Expand All @@ -14,9 +14,9 @@ ex = [y ~ x
ex = [:(b(t) ~ a(t))
:(b(t) ~ -2a(t) + 3 / c(t))
:(c(t) ~ 2)]
eqs = parse_expr_to_symbolic.(ex, (Main,))
eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,))
@variables t a(t) b(t) c(t)
ex = [b ~ a
b ~ -2a + 3 / c
c ~ 2]
@test_broken all(isequal.(eqs,ex))
@test_broken all(isequal.(eqs,ex))
77 changes: 38 additions & 39 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,43 @@ using LambertW
#Testing

@testset "solving tests" begin

function hasFloat(expr)#make sure answer does not contain any strange floats
if expr isa Float64
return !isinteger(expr) && expr != float(pi) && expr != exp(1.0)
elseif expr isa Equation
return hasFloat(expr.lhs) || hasFloat(expr.rhs)
elseif istree(expr)
elements = arguments(expr)
for element in elements
if hasFloat(element)
return true
end
end
end
return false
end
correctAns(p,a) = isequal(sort(Symbolics.convert_solutions_to_floats(p)),a) && !hasFloat(p)

@syms x y z a b c

#quadratics
@test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0])
@test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)])
@test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)])
@test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)])
#lambert w
@test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))])
@test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)])
#more challenging quadratics
@test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0])
@test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0])
#functions inverses
@test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))])
@test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0])
#strange
@test_broken correctAns(solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x),[acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)])
#product
@test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0])
end
function hasFloat(expr)#make sure answer does not contain any strange floats
if expr isa Float64
return !isinteger(expr) && expr != float(pi) && expr != exp(1.0)
elseif expr isa Equation
return hasFloat(expr.lhs) || hasFloat(expr.rhs)
elseif istree(expr)
elements = arguments(expr)
for element in elements
if hasFloat(element)
return true
end
end
end
return false
end
correctAns(p,a) = isapprox(sort(Symbolics.convert_solutions_to_floats(p)), a) && !hasFloat(p)

@syms x y z a b c

#quadratics
@test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0])
@test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)])
@test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)])
@test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)])
#lambert w
@test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))])
@test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)])
#more challenging quadratics
@test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0])
@test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0])
#functions inverses
@test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))])
@test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0])
r = solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x)
if r !== nothing
@test correctAns(r, [acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)])
end
#product
@test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0])
end
2 changes: 1 addition & 1 deletion test/stencils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
end

@test iszero(scalarize(y[1,1]))
test_funcs("stencil-extents", y, x, broken=broken)
test_funcs("stencil-extents", y, x)

@variables u[1:5, 1:5]
n = 5
Expand Down

0 comments on commit 6e33cfc

Please sign in to comment.