From 8c552b13ddde57fd12dec07e16e85561a886b2fa Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 13:04:45 -0400 Subject: [PATCH 1/5] Add `set_scalar_metadata` to set metadata scalar-wise --- src/variable.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 11bee6251..4fba6cead 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -28,23 +28,22 @@ function recurse_and_apply(f, x) end end -function setdefaultval(x, val) +function set_scalar_metadata(x, V, val) if symtype(x) <: AbstractArray if val isa AbstractArray getindex_posthook(x) do r,x,i... - setdefaultval(r, val[i...]) + set_scalar_metadata(r, V, val[i...]) end else getindex_posthook(x) do r,x,i... - setdefaultval(r, val) + set_scalar_metadata(r, V, val) end end else - setmetadata(x, - VariableDefaultValue, - val) + setmetadata(x, V, val) end end +setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val) struct GetindexParent end From 1452fba51e8cad27c97d60e7cae038ceadbb536c Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 13:23:40 -0400 Subject: [PATCH 2/5] Automatic vectorization of array valued md to array variables --- src/variable.jl | 7 +++---- test/arrays.jl | 6 +++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 4fba6cead..19f101503 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -30,7 +30,7 @@ end function set_scalar_metadata(x, V, val) if symtype(x) <: AbstractArray - if val isa AbstractArray + x = if val isa AbstractArray getindex_posthook(x) do r,x,i... set_scalar_metadata(r, V, val[i...]) end @@ -39,9 +39,8 @@ function set_scalar_metadata(x, V, val) set_scalar_metadata(r, V, val) end end - else - setmetadata(x, V, val) end + setmetadata(x, V, val) end setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val) @@ -225,7 +224,7 @@ function setprops_expr(expr, props, macroname, varname) lhs, rhs = opt.args @assert lhs isa Symbol "the lhs of an option must be a symbol" - expr = :($setmetadata($expr, + expr = :($set_scalar_metadata($expr, $(option_to_metadata_type(Val{lhs}())), $rhs)) end diff --git a/test/arrays.jl b/test/arrays.jl index f2ba539b0..07157333f 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -55,15 +55,19 @@ end # https://github.com/JuliaSymbolics/Symbolics.jl/issues/842 # getindex should keep metadata - @variables tv v(tv)[1:2] [test_meta = 4] + @variables tv v(tv)[1:2] [test_meta = 4] v2(tv)[1:3] [test_meta=[1, 2, 3]] @test !isnothing(metadata(unwrap(v))) @test hasmetadata(unwrap(v), TestMetaT) @test getmetadata(unwrap(v), TestMetaT) == 4 + @test getmetadata(unwrap(v2), TestMetaT) == [1, 2, 3] vs = scalarize(v) vsw = unwrap.(vs) + vs2 = scalarize(v2) + vsw2 = unwrap.(vs2) @test !isnothing(metadata(vsw[1])) @test hasmetadata(vsw[1], TestMetaT) @test getmetadata(vsw[1], TestMetaT) == 4 + @test getmetadata.(vsw2, TestMetaT) == [1, 2, 3] @test !isnothing(metadata(unwrap(v[1]))) @test hasmetadata(unwrap(v[1]), TestMetaT) @test getmetadata(unwrap(v[1]), TestMetaT) == 4 From d72220a44f42fc6566b5d5fb3096b5bfc9af4403 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 13:28:09 -0400 Subject: [PATCH 3/5] Update tests --- test/macro.jl | 2 +- test/stencils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/macro.jl b/test/macro.jl index d102eb796..d89302ef0 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -14,7 +14,7 @@ many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 @test all(t->getsource(t)[1] === :variables, many_vars) @test getdefaultval(t) == 0 @test getdefaultval(a) == 1 -@test_throws ErrorException getdefaultval(x) +@test getdefaultval(x) == 2 @test getdefaultval(x[1]) == 2 @test getdefaultval(y[2]) == 3 @test getdefaultval(w[2]) == 2 diff --git a/test/stencils.jl b/test/stencils.jl index 6b6666b33..df5bdea12 100644 --- a/test/stencils.jl +++ b/test/stencils.jl @@ -55,7 +55,7 @@ end end @test iszero(scalarize(y[1,1])) - test_funcs("stencil-extents", y, x, broken=broken) + test_funcs("stencil-extents", y, x) @variables u[1:5, 1:5] n = 5 From c79fbf17ed061893f185bf8e817ea9cda7a13ee4 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 14:09:39 -0400 Subject: [PATCH 4/5] More robust tests --- test/parsing.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/parsing.jl b/test/parsing.jl index 792483631..7e37e4141 100644 --- a/test/parsing.jl +++ b/test/parsing.jl @@ -3,7 +3,7 @@ using Symbolics, Test ex = [:(y ~ x) :(y ~ -2x + 3 / z) :(z ~ 2)] -eqs = parse_expr_to_symbolic.(ex, (Main,)) +eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,)) @variables x y z ex = [y ~ x @@ -14,9 +14,9 @@ ex = [y ~ x ex = [:(b(t) ~ a(t)) :(b(t) ~ -2a(t) + 3 / c(t)) :(c(t) ~ 2)] -eqs = parse_expr_to_symbolic.(ex, (Main,)) +eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,)) @variables t a(t) b(t) c(t) ex = [b ~ a b ~ -2a + 3 / c c ~ 2] -@test_broken all(isequal.(eqs,ex)) \ No newline at end of file +@test_broken all(isequal.(eqs,ex)) From 1a74a87861d3b3d9829038571038020624eebe0f Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 4 Oct 2023 14:28:21 -0400 Subject: [PATCH 5/5] Robust tests --- test/solver.jl | 77 +++++++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/test/solver.jl b/test/solver.jl index a29187739..1a510c7f7 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -6,44 +6,43 @@ using LambertW #Testing @testset "solving tests" begin - - function hasFloat(expr)#make sure answer does not contain any strange floats - if expr isa Float64 - return !isinteger(expr) && expr != float(pi) && expr != exp(1.0) - elseif expr isa Equation - return hasFloat(expr.lhs) || hasFloat(expr.rhs) - elseif istree(expr) - elements = arguments(expr) - for element in elements - if hasFloat(element) - return true - end - end - end - return false - end - correctAns(p,a) = isequal(sort(Symbolics.convert_solutions_to_floats(p)),a) && !hasFloat(p) - - @syms x y z a b c - - #quadratics - @test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0]) - @test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)]) - @test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)]) - @test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)]) - #lambert w - @test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))]) - @test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)]) - #more challenging quadratics - @test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0]) - @test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0]) - #functions inverses - @test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))]) - @test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0]) - #strange - @test_broken correctAns(solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x),[acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)]) - #product - @test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0]) -end + function hasFloat(expr)#make sure answer does not contain any strange floats + if expr isa Float64 + return !isinteger(expr) && expr != float(pi) && expr != exp(1.0) + elseif expr isa Equation + return hasFloat(expr.lhs) || hasFloat(expr.rhs) + elseif istree(expr) + elements = arguments(expr) + for element in elements + if hasFloat(element) + return true + end + end + end + return false + end + correctAns(p,a) = isapprox(sort(Symbolics.convert_solutions_to_floats(p)), a) && !hasFloat(p) + @syms x y z a b c + #quadratics + @test correctAns(solve_single_eq(x^2~4,x),[-2.0,2.0]) + @test correctAns(solve_single_eq(x^2~2,x),[-sqrt(2.0),sqrt(2.0)]) + @test correctAns(solve_single_eq(x^2~32,x),[-sqrt(32.0),sqrt(32.0)]) + @test correctAns(solve_single_eq(x^3~32,x),[32.0^(1.0/3.0)]) + #lambert w + @test correctAns(solve_single_eq(x^x~2,x),[log(2.0)/lambertw(log(2.0))]) + @test correctAns(solve_single_eq(2*x*exp(x)~3,x),[LambertW.lambertw(3.0/2.0)]) + #more challenging quadratics + @test correctAns(solve_single_eq(x+sqrt(1+x)~5,x),[3.0]) + @test correctAns(solve_single_eq(2*x^2-6*x-7~0,x),[(3.0/2.0)-sqrt(23.0)/2.0,(3.0/2.0)+sqrt(23.0)/2.0]) + #functions inverses + @test correctAns(solve_single_eq(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))]) + @test correctAns(solve_single_eq(sin(x+3)~1//3,x),[asin(1.0/3.0)-3.0]) + r = solve_single_eq(sin(x+2//5)+cos(x+2//5)~1//2,x) + if r !== nothing + @test correctAns(r, [acos(0.5/sqrt(2.0))+3.141592653589793/4.0-(2.0/5.0)]) + end + #product + @test correctAns(solve_single_eq((x^2-4)*(x+1)~0,x),[-2.0,-1.0,2.0]) +end