diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index 4dde6531..a4692edb 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -1502,6 +1502,32 @@ class Relational(Boolean): def is_Relational(self): return True + def __bool__(self): + # We will narrow down the boolean value of our relational with some simple checks + # Get the Left- and Right-hand-sides of the relation, since two expressions are equal if their difference + # is equal to 0. + # If the expand method will not cancel out free symbols in the given expression, then this + # will throw a TypeError. + lhs, rhs = self.args + difference = (lhs - rhs).expand() + + if len(difference.free_symbols): + # If there are any free symbols, then boolean evaluation is ambiguous in most cases. Throw a Type Error + raise TypeError(f'Relational with free symbols cannot be cast as bool: {self}') + else: + # Instantiating relationals that are obviously True or False (according to symengine) will automatically + # simplify to BooleanTrue or BooleanFalse + relational_type = type(self) + simplified = relational_type(difference, S.Zero) + if isinstance(simplified, BooleanAtom): + return bool(simplified) + # If we still cannot determine whether or not the relational is true, then we can either outsource the + # evaluation to sympy (if available) or raise a ValueError expressing that the evaluation is unclear. + try: + return bool(self.simplify()) + except ImportError: + raise ValueError(f'Boolean evaluation is unclear for relational: {self}') + Rel = Relational diff --git a/symengine/tests/CMakeLists.txt b/symengine/tests/CMakeLists.txt index ebd4dfaa..ee04a9be 100644 --- a/symengine/tests/CMakeLists.txt +++ b/symengine/tests/CMakeLists.txt @@ -9,6 +9,7 @@ install(FILES __init__.py test_matrices.py test_ntheory.py test_printing.py + test_relationals.py test_sage.py test_series_expansion.py test_sets.py diff --git a/symengine/tests/test_eval.py b/symengine/tests/test_eval.py index 1ea2b51f..c1cca41f 100644 --- a/symengine/tests/test_eval.py +++ b/symengine/tests/test_eval.py @@ -16,7 +16,7 @@ def test_eval_double2(): x = Symbol("x") e = sin(x)**2 + sqrt(2) raises(RuntimeError, lambda: e.n(real=True)) - assert abs(e.n() - x**2 - 1.414) < 1e-3 + assert abs(e.n() - sin(x)**2.0 - 1.414) < 1e-3 def test_n(): x = Symbol("x") diff --git a/symengine/tests/test_relationals.py b/symengine/tests/test_relationals.py new file mode 100644 index 00000000..81b2e483 --- /dev/null +++ b/symengine/tests/test_relationals.py @@ -0,0 +1,138 @@ +from symengine.utilities import raises +from symengine import (Symbol, sympify, Eq, Ne, Lt, Le, Ge, Gt, sqrt, pi) + +from unittest.case import SkipTest + +try: + import sympy + HAVE_SYMPY = True +except ImportError: + HAVE_SYMPY = False + + +def assert_equal(x, y): + """Asserts that x and y are equal. This will test Equality, Unequality, LE, and GE classes.""" + assert bool(Eq(x, y)) + assert not bool(Ne(x, y)) + assert bool(Ge(x, y)) + assert bool(Le(x, y)) + + +def assert_not_equal(x, y): + """Asserts that x and y are not equal. This will test Equality and Unequality""" + assert not bool(Eq(x, y)) + assert bool(Ne(x, y)) + + +def assert_less_than(x, y): + """Asserts that x is less than y. This will test Le, Lt, Ge, Gt classes.""" + assert bool(Le(x, y)) + assert bool(Lt(x, y)) + assert not bool(Ge(x, y)) + assert not bool(Gt(x, y)) + + +def assert_greater_than(x, y): + """Asserts that x is greater than y. This will test Le, Lt, Ge, Gt classes.""" + assert not bool(Le(x, y)) + assert not bool(Lt(x, y)) + assert bool(Ge(x, y)) + assert bool(Gt(x, y)) + + +def test_equals_constants_easy(): + assert_equal(3, 3) + assert_equal(4, 2 ** 2) + + +def test_equals_constants_hard(): + # Short and long are symbolically equivalent, but sufficiently different in form that expand() does not + # catch it. Ideally, our equality should still catch these, but until symengine supports as robust simplification as + # sympy, we can forgive failing, as long as it raises a ValueError + short = sympify('(3/2)*sqrt(11 + sqrt(21))') + long = sympify('sqrt((33/8 + (1/24)*sqrt(27)*sqrt(63))**2 + ((3/8)*sqrt(27) + (-1/8)*sqrt(63))**2)') + assert_equal(short, short) + assert_equal(long, long) + if HAVE_SYMPY: + assert_equal(short, long) + else: + raises(ValueError, lambda: bool(Eq(short, long))) + + +def test_not_equals_constants(): + assert_not_equal(3, 4) + assert_not_equal(4, 4 - .000000001) + + +def test_equals_symbols(): + x = Symbol("x") + y = Symbol("y") + assert_equal(x, x) + assert_equal(x ** 2, x * x) + assert_equal(x * y, y * x) + + +def test_not_equals_symbols(): + x = Symbol("x") + y = Symbol("y") + assert_not_equal(x, x + 1) + assert_not_equal(x ** 2, x ** 2 + 1) + assert_not_equal(x * y, y * x + 1) + + +def test_not_equals_symbols_raise_typeerror(): + x = Symbol("x") + y = Symbol("y") + raises(TypeError, lambda: bool(Eq(x, 1))) + raises(TypeError, lambda: bool(Eq(x, y))) + raises(TypeError, lambda: bool(Eq(x ** 2, x))) + + +def test_less_than_constants_easy(): + assert_less_than(1, 2) + assert_less_than(-1, 1) + + +def test_less_than_constants_hard(): + # Each of the below pairs are distinct numbers, with the one on the left less than the one on the right. + # Ideally, Less-than will catch this when evaluated, but until symengine has a more robust simplification, + # we can forgive a failure to evaluate as long as it raises a ValueError. + if HAVE_SYMPY: + assert_less_than(sqrt(2), 2) + assert_less_than(3.14, pi) + else: + raises(ValueError, lambda: bool(Lt(sqrt(2), 2))) + raises(ValueError, lambda: bool(Lt(3.14, pi))) + + +def test_greater_than_constants(): + assert_greater_than(2, 1) + assert_greater_than(1, -1) + + +def test_greater_than_constants_hard(): + # Each of the below pairs are distinct numbers, with the one on the left less than the one on the right. + # Ideally, Greater-than will catch this when evaluated, but until symengine has a more robust simplification, + # we can forgive a failure to evaluate as long as it raises a ValueError. + if HAVE_SYMPY: + assert_greater_than(2, sqrt(2)) + assert_greater_than(pi, 3.14) + else: + raises(ValueError, lambda: bool(Gt(2, sqrt(2)))) + raises(ValueError, lambda: bool(Gt(pi, 3.14))) + + +def test_less_than_raises_typeerror(): + x = Symbol("x") + y = Symbol("y") + raises(TypeError, lambda: bool(Lt(x, 1))) + raises(TypeError, lambda: bool(Lt(x, y))) + raises(TypeError, lambda: bool(Lt(x ** 2, x))) + + +def test_greater_than_raises_typeerror(): + x = Symbol("x") + y = Symbol("y") + raises(TypeError, lambda: bool(Gt(x, 1))) + raises(TypeError, lambda: bool(Gt(x, y))) + raises(TypeError, lambda: bool(Gt(x ** 2, x)))