From afa425ceb096439a4ff601e5145532d557edf0e3 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Sun, 6 Oct 2024 15:52:22 +0200 Subject: [PATCH] add error and test for #255 (#258) --- src/types.jl | 23 ++++++++++++++++++++--- test/API.jl | 11 +++++++++++ test/runtests.jl | 4 ++++ 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 test/API.jl diff --git a/src/types.jl b/src/types.jl index 94e2ae04..f94477e6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -46,11 +46,28 @@ mutable struct DifferentialEquation return DifferentialEquation(exprs .~ Int(0), vars) end - function DifferentialEquation(arg1, arg2) - return DifferentialEquation( - arg1 isa Vector ? arg1 : [arg1], arg2 isa Vector ? arg2 : [arg2] + function DifferentialEquation(eq::Equation, var::Num) + typerhs = typeof(eq.rhs) + typelhs = typeof(eq.lhs) + if eq.rhs isa AbstractVector || eq.lhs isa AbstractVector + throw( + ArgumentError( + "The equation is of the form $(typerhs)~$(typelhs) is not supported. Commenly one forgot to broadcast the equation symbol `~`.", + ), + ) + end + return DifferentialEquation([eq], [var]) + end + function DifferentialEquation(eq::Equation, vars::Vector{Num}) + typerhs = typeof(eq.rhs) + typelhs = typeof(eq.lhs) + throw( + ArgumentError( + "The variables are of type $(typeof(vars)). Whereas you gave one equation of type $(typerhs)~$(typelhs). Commenly one forgot to broadcast the equation symbol `~`.", + ), ) end + DifferentialEquation(lhs::Num, var::Num) = DifferentialEquation([lhs ~ Int(0)], [var]) end function Base.show(io::IO, diff_eq::DifferentialEquation) diff --git a/test/API.jl b/test/API.jl new file mode 100644 index 00000000..158cc14b --- /dev/null +++ b/test/API.jl @@ -0,0 +1,11 @@ +using HarmonicBalance + +# define equation of motion +@variables ω1, ω2, t, ω, F, γ, α1, α2, k, x(t), y(t); +rhs = [ + d(x, t, 2) + ω1^2 * x + γ * d(x, t) + α1 * x^3 - k * y, + d(d(y, t), t) + ω2^2 * y + γ * d(y, t) + α2 * y^3 - k * x, +] +eqs = rhs .~ [F * cos(ω * t), 0] + +@test_throws ArgumentError DifferentialEquation(rhs ~ [F * cos(ω * t), 0], [x, y]) diff --git a/test/runtests.jl b/test/runtests.jl index 3450b385..1716d5b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,6 +32,10 @@ end JET.test_package(HarmonicBalance; target_defined_modules=true) end +@testset "Symbolics customised" begin + include("API.jl") +end + @testset "Symbolics customised" begin include("powers.jl") include("harmonics.jl")