From fd309457511e78cb8c14a5ef6d01ca64c6ff59ef Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Sun, 5 May 2024 12:09:58 +1000 Subject: [PATCH] Main change to the semantics of Predicate comparison Predicate comparison now uses the underlying clingo.Symbol object. Will allow querying with ordering to work even when the types are incomparable because Symbol objects are always comparable. Still needs to update some of the querying behaviour --- clorm/orm/core.py | 198 +++++++++++++++++-------------------- clorm/orm/templating.py | 95 +++++++++++++++++- tests/support.py | 13 +++ tests/test_orm_core.py | 105 +++++++++++--------- tests/test_orm_factbase.py | 10 +- tests/test_orm_query.py | 8 +- 6 files changed, 266 insertions(+), 163 deletions(-) diff --git a/clorm/orm/core.py b/clorm/orm/core.py index 487fa7f..2e1b588 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -1,12 +1,21 @@ -# ----------------------------------------------------------------------------- -# Implementation of the core part of the Clorm ORM. In particular this provides -# the base classes and metaclasses for the definition of fields, predicates, -# predicate paths, and the specification of query conditions. Note: query -# condition specification is provided here because the predicate path comparison -# operators are overloads to return these objects. However, the rest of the -# query API is specified with the FactBase and select querying mechanisms -# (see factbase.py). -# ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------- +# Implementation of the core part of the Clorm ORM. In particular this provides the base +# classes and metaclasses for the definition of fields, predicates, predicate paths, and the +# specification of query conditions. Note: query condition specification is provided here +# because the predicate path comparison operators are overloads to return these +# objects. However, the rest of the query API is specified with the FactBase and select +# querying mechanisms (see factbase.py). +# ------------------------------------------------------------------------------------------- + +# ------------------------------------------------------------------------------------------- +# NOTE: 20242028 the semantics for the comparison operators has changed. Instead of using the +# python field representation we use the underlying clingo symbol object. The symbol object is +# well defined for any comparison between symbols, whereas tuples are only well defined if the +# types of the individual parameters are compatible. So this change leads to more natural +# behaviour for the queries. Note: users should avoid defining unintuitive fields (for example +# a swap field that changes the sign of an int) to avoid unintuitive Python behaviour. +# ------------------------------------------------------------------------------------------- + from __future__ import annotations @@ -1341,7 +1350,6 @@ class DateField(StringField): ``FactBase```. Defaults to ``False``. """ - def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None: self._index = index if index is not MISSING else False @@ -1349,33 +1357,35 @@ def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None: self._default = (False, None) return - self._default = (True, default) cmplx = self.complex - # Check and convert the default to a valid value. Note: if the default - # is a callable then we can't do this check because it could break a - # counter type procedure. - if callable(default) or (cmplx and isinstance(default, cmplx)): - return + def _process_basic_value(v): + return v - try: - if cmplx: + def _process_cmplx_value(v): + if isinstance(v, cmplx): + return v + if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple): + return cmplx(*v) + raise TypeError(f"Value {v} ({type(v)}) cannot be converted to type {cmplx}") - def _instance_from_tuple(v): - if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple): - return cmplx(*v) - raise TypeError(f"Value {v} ({type(v)}) is not a tuple") + _process_value = _process_basic_value if cmplx is None else _process_cmplx_value + + # If the default is not a factory function than make sure the value can be converted to + # clingo without error. + if not callable(default): + try: + self._default = (True, _process_value(default)) + self.pytocl(self._default[1]) + except (TypeError, ValueError): + raise TypeError( + 'Invalid default value "{}" for {}'.format(default, type(self).__name__) + ) + else: + def _process_default(): + return _process_value(default()) + self._default = (True, _process_default) - if cmplx.meta.is_tuple: - self._default = (True, _instance_from_tuple(default)) - else: - raise ValueError("Bad default") - else: - self.pytocl(default) - except (TypeError, ValueError): - raise TypeError( - 'Invalid default value "{}" for {}'.format(default, type(self).__name__) - ) @staticmethod @abc.abstractmethod @@ -1503,11 +1513,11 @@ def field( raise TypeError(f"{basefield} can just be of Type '{BaseField}' or '{Sequence}'") -# ------------------------------------------------------------------------------ -# RawField is a sub-class of BaseField for storing Symbol or NoSymbol -# objects. The behaviour of Raw with respect to using clingo.Symbol or -# noclingo.NoSymbol is modified by the symbol mode (get_symbol_mode()) -# ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------ +# RawField is a sub-class of BaseField for storing Symbol objects. The behaviour of Raw with +# respect to using clingo.Symbol or noclingo.Symbol is modified by the symbol mode +# (get_symbol_mode()) +# ------------------------------------------------------------------------------------------ class Raw(object): @@ -2441,13 +2451,9 @@ def get_field_definition(defn: Any, module: str = "") -> BaseField: def _create_complex_term(defn: Any, default_value: Any = MISSING, module: str = "") -> BaseField: - # NOTE: I was using a dict rather than OrderedDict which just happened to - # work. Apparently, in Python 3.6 this was an implmentation detail and - # Python 3.7 it is a language specification (see: - # https://stackoverflow.com/questions/1867861/how-to-keep-keys-values-in-same-order-as-declared/39537308#39537308). - # However, since Clorm is meant to be Python 3.5 compatible change this to - # use an OrderedDict. - # proto = { "arg{}".format(i+1) : get_field_definition(d) for i,d in enumerate(defn) } + # NOTE: relies on a dict preserving insertion order - this is true from Python 3.7+. Python + # 3.7 is already end-of-life so there is no longer a reason to use OrderedDict. + #proto = {f"arg{idx+1}": get_field_definition(dn) for idx, dn in enumerate(defn)} proto: Dict[str, Any] = collections.OrderedDict( [(f"arg{i+1}", get_field_definition(d, module)) for i, d in enumerate(defn)] ) @@ -2705,13 +2711,16 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N gdict = { "Predicate": Predicate, + "Symbol": Symbol, "Function": Function, "MISSING": MISSING, "AnySymbol": AnySymbol, "Type": Type, + "Any": Any, "Optional": Optional, "Sequence": Sequence, "_P": _P, + "PREDICATE_IS_TUPLE": pdefn.is_tuple, } for f in pdefn: @@ -2786,21 +2795,26 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N template = PREDICATE_TEMPLATE.format(pdefn=pdefn) predicate_functions = expand_template(template, **expansions) - # print(f"INIT:\n\n{predicate_functions}\n\n") +# print(f"INIT {class_name}:\n\n{predicate_functions}\n\n") ldict: Dict[str, Any] = {} exec(predicate_functions, gdict, ldict) - init_doc_args = f"{args_signature}*, sign=True, raw=None" - predicate_init = ldict["__init__"] - predicate_init.__name__ = "__init__" - predicate_init.__doc__ = f"{class_name}({init_doc_args})" - predicate_unify = ldict["_unify"] - predicate_unify.__name__ = "_unify" - predicate_unify.__doc__ = PREDICATE_UNIFY_DOCSTRING + def _set_fn(fname: str, docstring: str): + tmp = ldict[fname] + tmp.__name__ = fname + tmp.__doc = docstring + namespace[fname] = tmp - namespace["__init__"] = predicate_init - namespace["_unify"] = predicate_unify + # Assign the __init__, _unify, __hash__, and appropriate comparison functions + _set_fn("__init__", f"{class_name}({args_signature}*, sign=True, raw=None)") + _set_fn("_unify", PREDICATE_UNIFY_DOCSTRING) + _set_fn("__hash__", "Hash operator") + _set_fn("__eq__", "Equality operator") + _set_fn("__lt__", "Less than operator") + _set_fn("__le__", "Less than or equal operator") + _set_fn("__gt__", "Greater than operator") + _set_fn("__ge__", "Greater than operator") # ------------------------------------------------------------------------------ @@ -3160,6 +3174,7 @@ def _cltopy(v): # ------------------------------------------------------------------------------ # A Metaclass for the Predicate base class # ------------------------------------------------------------------------------ + @__dataclass_transform__(field_descriptors=(field,)) class _PredicateMeta(type): if TYPE_CHECKING: @@ -3244,8 +3259,12 @@ def __iter__(self) -> Iterator[PredicatePath]: # underlying Symbol object. # ------------------------------------------------------------------------------ +# Mixin class to be able to use both MetaClasses +class _AbstractPredicateMeta(abc.ABCMeta, _PredicateMeta): + pass + -class Predicate(object, metaclass=_PredicateMeta): +class Predicate(object, metaclass=_AbstractPredicateMeta): """Abstract base class to encapsulate an ASP predicate or complex term. This is the heart of the ORM model for defining the mapping of a predicate @@ -3324,7 +3343,7 @@ def _unify( def symbol(self): """Returns the Symbol object corresponding to the fact. - The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.NoSymbol``. + The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.Symbol``. """ return self._raw @@ -3413,74 +3432,35 @@ def __neg__(self): # -------------------------------------------------------------------------- # Overloaded operators # -------------------------------------------------------------------------- + @abc.abstractmethod def __eq__(self, other): """Overloaded boolean operator.""" - if isinstance(other, self.__class__): - return self._field_values == other._field_values and self._sign == other._sign - if self.meta.is_tuple: - return self._field_values == other - elif isinstance(other, Predicate): - return False - return NotImplemented + raise NotImplementedError("Predicate.__eq__() must be overriden") + @abc.abstractmethod def __lt__(self, other): """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__lt__() must be overriden") - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - - # Negative literals are less than positive literals - if self.sign != other.sign: - return self.sign < other.sign - - return self._field_values < other._field_values - - # If different predicates then compare the raw value - elif isinstance(other, Predicate): - return self.raw < other.raw - - # Else an error - return NotImplemented + def __le__(self, other): + """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__le__() must be overriden") + @abc.abstractmethod def __ge__(self, other): """Overloaded boolean operator.""" - result = self.__lt__(other) - if result is NotImplemented: - return NotImplemented - return not result + raise NotImplementedError("Predicate.__ge__() must be overriden") + @abc.abstractmethod def __gt__(self, other): """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__gt__() must be overriden") - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - # Positive literals are greater than negative literals - if self.sign != other.sign: - return self.sign > other.sign - - return self._field_values > other._field_values - - # If different predicates then compare the raw value - if not isinstance(other, Predicate): - return self.raw > other.raw - - # Else an error - return NotImplemented - - def __le__(self, other): - """Overloaded boolean operator.""" - result = self.__gt__(other) - if result is NotImplemented: - return NotImplemented - return not result + @abc.abstractmethod def __hash__(self): - if self._hash is None: - if self.meta.is_tuple: - self._hash = hash(self._field_values) - else: - self._hash = hash((self.meta.name, self._field_values)) - return self._hash + """Overload the hash function.""" + raise NotImplementedError("Predicate.__hash__() must be overriden") def __str__(self): """Returns the Predicate as the string representation of an ASP fact.""" diff --git a/clorm/orm/templating.py b/clorm/orm/templating.py index 8c4c4c3..287c8f0 100644 --- a/clorm/orm/templating.py +++ b/clorm/orm/templating.py @@ -37,16 +37,16 @@ def add_spaces(num, text): lines = template.expandtabs(4).splitlines() outlines = [] for line in lines: - start = line.find("{%") + start = line.find(r"{%") if start == -1: outlines.append(line) continue - end = line.find("%}", start) + end = line.find(r"%}", start) if end == -1: raise ValueError("Bad template expansion in {line}") - keyword = line[start + 2 : end] + keyword = line[start + 2:end] text = add_spaces(start, kwargs[keyword]) - line = line[0:start] + text + line[end + 2 :] + line = line[0:start] + text + line[end + 2:] outlines.append(line) return "\n".join(outlines) @@ -73,6 +73,7 @@ def __init__(self, ({{%args_raw%}}), self._sign) + @classmethod def _unify(cls: Type[_P], raw: AnySymbol, raw_args: Optional[Sequence[AnySymbol]]=None, raw_name: Optional[str]=None) -> Optional[_P]: try: @@ -98,6 +99,92 @@ def _unify(cls: Type[_P], raw: AnySymbol, raw_args: Optional[Sequence[AnySymbol] raise ValueError((f"Cannot unify with object {{raw}} ({{type(raw)}}) as " "it is not a clingo Symbol Function object")) + +def nontuple__eq__(self, other: Any) -> bool: + # Deal with a non-tuple predicate + if isinstance(other, Predicate): + return self._raw == other._raw + if isinstance(other, Symbol): + return self._raw == other + return NotImplemented + + +def tuple__eq__(self, other: Any) -> bool: + # Deal with a predicate that is a tuple + if isinstance(other, Predicate): + return self._raw == other._raw + if isinstance(other, Symbol): + return self._raw == other +# if isinstance(other, tuple): +# return self._field_values == other + return NotImplemented + + +def nontuple__lt__(self, other): + # If it is the same predicate class then compare the underlying clingo symbol + if isinstance(other, Predicate): + return self._raw < other._raw + if isinstance(other, Symbol): + return self._raw < other + return NotImplemented + + +def tuple__lt__(self, other): + # self is always less than a non-tuple predicate + if isinstance(other, Predicate): + return self._raw < other._raw + if isinstance(other, Symbol): + return self._raw < other +# if isinstance(other, tuple): +# return self._field_values < other + return NotImplemented + + +def nontuple__gt__(self, other): + if isinstance(other, Predicate): + return self._raw > other._raw + if isinstance(other, Symbol): + return self._raw > other + return NotImplemented + + +def tuple__gt__(self, other): + # If it is the same predicate class then compare the sign and fields + if isinstance(other, Predicate): + return self._raw > other._raw + if isinstance(other, Symbol): + return self._raw > other +# if isinstance(other, tuple): +# return self._field_values > other + return NotImplemented + + +def __ge__(self, other): + result = self.__lt__(other) + if result is NotImplemented: + return NotImplemented + return not result + + +def __le__(self, other): + result = self.__gt__(other) + if result is NotImplemented: + return NotImplemented + return not result + + +def __hash__(self): + if self._hash is None: + self._hash = hash(self._raw) + return self._hash + + +__eq__ = tuple__eq__ if PREDICATE_IS_TUPLE else nontuple__eq__ +__lt__ = tuple__lt__ if PREDICATE_IS_TUPLE else nontuple__lt__ +__gt__ = tuple__gt__ if PREDICATE_IS_TUPLE else nontuple__gt__ + + + """ CHECK_SIGN_TEMPLATE = r""" diff --git a/tests/support.py b/tests/support.py index d3fd022..2ce3be6 100644 --- a/tests/support.py +++ b/tests/support.py @@ -1,3 +1,5 @@ +from clorm import Predicate + # ------------------------------------------------------------------------------ # Support functions for the unit tests # ------------------------------------------------------------------------------ @@ -19,6 +21,17 @@ def check_errmsg_contains(contmsg, ctx): raise AssertionError(msg) +# ------------------------------------------------------------------------------ +# +# ------------------------------------------------------------------------------ + +def to_tuple(value): + """Recursively convert a predicate/normal tuple into a Python tuple""" + if isinstance(value, tuple) or (isinstance(value, Predicate) and value.meta.is_tuple): + return tuple(to_tuple(x) for x in value) + return value + + # ------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------ diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index aa24145..469a5f7 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -8,6 +8,12 @@ # to be completed. # ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------- +# NOTE: 20242028 See orm/core.py for changes to the semantics of comparison operators for +# Predicate objects. +# ------------------------------------------------------------------------------------------- + + import collections.abc as cabc import datetime import enum @@ -71,7 +77,7 @@ trueall, ) -from .support import check_errmsg, check_errmsg_contains +from .support import check_errmsg, check_errmsg_contains, to_tuple # Error messages for CPython and PyPy vary PYPY = sys.implementation.name == "pypy" @@ -280,10 +286,12 @@ def test_api_field_function(self): self.assertEqual(t, (StringField, IntegerField)) t = field((StringField, IntegerField), default=("3", 4)) + self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, IntegerField) - self.assertEqual(t.default, ("3", 4)) + self.assertEqual(t.default, t.complex("3", 4)) + self.assertEqual(to_tuple(t.default), ("3", 4)) with self.subTest("with custom field"): INLField = define_flat_list_field(IntegerField, name="INLField") @@ -302,8 +310,8 @@ def factory(): return ("3", x) t = field((StringField, IntegerField), default_factory=factory) - self.assertEqual(t.default, ("3", 1)) - self.assertEqual(t.default, ("3", 2)) + self.assertEqual(to_tuple(t.default), ("3", 1)) + self.assertEqual(to_tuple(t.default), ("3", 2)) with self.subTest("with nested tuple and default"): t = field((StringField, (StringField, IntegerField))) @@ -313,7 +321,7 @@ def factory(): self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, BaseField) - self.assertEqual(t.default, ("3", ("1", 4))) + self.assertEqual(to_tuple(t.default), ("3", ("1", 4))) def test_api_field_function_illegal_arguments(self): with self.subTest("illegal basefield type"): @@ -783,7 +791,7 @@ def test_api_nested_list_field_complex_element_field(self): value2 = (2, ("b", "B")) nlist = (value1, value2) - self.assertEqual(XField.cltopy(symnlist), nlist) + self.assertEqual(to_tuple(XField.cltopy(symnlist)), nlist) self.assertEqual(XField.pytocl(nlist), symnlist) # -------------------------------------------------------------------------- @@ -835,7 +843,7 @@ def test_api_flat_list_field_complex_element_field(self): value2 = (2, ("b", "B")) nlist = (value1, value2) - self.assertEqual(XField.cltopy(symnlist), nlist) + self.assertEqual(to_tuple(XField.cltopy(symnlist)), nlist) self.assertEqual(XField.pytocl(nlist), symnlist) # -------------------------------------------------------------------------- @@ -1014,6 +1022,25 @@ class Blah(object): clresult = td.pytocl((1, "blah")) self.assertEqual(clresult, clob) + + # -------------------------------------------------------------------------- + # Test that we define the new comparison operators in the template + # -------------------------------------------------------------------------- + def test_predicate_comparison_operator_creation(self): + + class P(Predicate, name="p"): + a = IntegerField + b = ConstantField + + p1 = P(a=1, b="x") + p2 = P(a=1, b="x") + p3 = P(a=2, b="x") + + tmp = {} + tmp[p1] = "P" + self.assertEqual(p1, p2) + self.assertNotEqual(p1, p3) + # -------------------------------------------------------------------------- # Test that we can define predicates using the class syntax and test that # the getters and setters are connected properly to the predicate classes. @@ -1177,21 +1204,25 @@ class T(Predicate): tuple1 = tuple([1, "a"]) tuple2 = tuple([2, "b"]) - # Equality works even when the types are different + # Equality works even when the predicate types are different self.assertTrue(p1.tuple_ == p1_alt.tuple_) self.assertTrue(p1.tuple_ == q1.tuple_) self.assertTrue(p1.tuple_ == t1.tuple_) - self.assertTrue(p1.tuple_ == tuple1) + + # New behaviour. Doesn't compare directly to tuple + self.assertFalse(p1.tuple_ == tuple1) + self.assertTrue(tuple(p1.tuple_) == tuple1) self.assertNotEqual(type(p1.tuple_), type(t1.tuple_)) - # self.assertNotEqual(type(p1.tuple_), type(t1)) self.assertTrue(p1.tuple_ != p2.tuple_) self.assertTrue(p1.tuple_ != q2.tuple_) self.assertTrue(p1.tuple_ != r2.tuple_) self.assertTrue(p1.tuple_ != s2.tuple_) self.assertTrue(p1.tuple_ != t2.tuple_) - self.assertTrue(p1.tuple_ != tuple2) + + self.assertTrue(p1.tuple_ != tuple1) + self.assertFalse(tuple(p1.tuple_) != tuple1) # -------------------------------------------------------------------------- # Test predicates with default fields @@ -2113,33 +2144,18 @@ def test_predicate_comparison_operator_overload_signed(self): class P(Predicate): a = IntegerField - class Q(Predicate): - a = IntegerField - p1 = P(1) neg_p1 = P(1, sign=False) p2 = P(2) neg_p2 = P(2, sign=False) - q1 = Q(1) - - self.assertTrue(neg_p1 < neg_p2) - self.assertTrue(neg_p1 < p1) - self.assertTrue(neg_p1 < p2) - self.assertTrue(neg_p2 < p1) - self.assertTrue(neg_p2 < p2) - self.assertTrue(p1 < p2) - - self.assertTrue(p2 > p1) - self.assertTrue(p2 > neg_p2) - self.assertTrue(p2 > neg_p1) - self.assertTrue(p1 > neg_p2) - self.assertTrue(p1 > neg_p1) - self.assertTrue(neg_p2 > neg_p1) - # Different predicate sub-classes are incomparable + # NOTE: 20240428 see note at top about change of semantics + self.assertTrue((neg_p1.raw < p1.raw) == (neg_p1 < p1)) + self.assertTrue((neg_p1.raw < p2.raw) == (neg_p1 < p2)) + self.assertTrue((neg_p2.raw < p1.raw) == (neg_p2 < p1)) + self.assertTrue((neg_p2.raw < p2.raw) == (neg_p2 < p2)) + self.assertTrue((p1.raw < p2.raw) == (p1 < p2)) - # with self.assertRaises(TypeError) as ctx: - # self.assertTrue(p1 < q1) # -------------------------------------------------------------------------- # Test a simple predicate with a field that has a function default @@ -2320,9 +2336,9 @@ class Fact(Predicate): with self.assertRaises(TypeError) as ctx: class Fact2(Predicate): - afun = Fun.Field(default=(1, "str")) + afun = Fun.Field(default=6) - check_errmsg("""Invalid default value "(1, 'str')" for FunField""", ctx) + check_errmsg("""Invalid default value "6" for FunField""", ctx) # -------------------------------------------------------------------------- # Test the simple_predicate function as a mechanism for defining @@ -2423,6 +2439,7 @@ class Fact(Predicate): # Test predicate equality # -------------------------------------------------------------------------- def test_predicate_comparison_operator_overloads(self): + # NOTE: 20240428 see note at top about change of semantics f1 = Function("fact", [Number(1)]) f2 = Function("fact", [Number(2)]) @@ -2452,9 +2469,9 @@ class Meta: self.assertEqual(f1, af1.raw) self.assertEqual(af1.raw, f1) self.assertEqual(af1.raw, ag1.raw) - self.assertNotEqual(af1, ag1) - self.assertNotEqual(af1, f1) - self.assertNotEqual(f1, af1) + self.assertEqual(af1, ag1) + self.assertEqual(af1, f1) + self.assertEqual(f1, af1) self.assertTrue(af1 < af2) self.assertTrue(af1 <= af2) @@ -2478,6 +2495,8 @@ class Meta: # Test predicate equality # -------------------------------------------------------------------------- def test_comparison_operator_overloads_complex(self): + # NOTE: 20240428 see note at top about change of semantics + class SwapField(IntegerField): pytocl = lambda x: 100 - x cltopy = lambda x: 100 - x @@ -2497,17 +2516,15 @@ class AComplex(ComplexTerm): for rf in [rf1, rf2, rf3]: self.assertEqual(rf.arguments[0], rf.arguments[1]) - # Test the the comparison operator for the complex term is using the - # swapped values so that the comparison is opposite to what the raw - # field says. self.assertTrue(rf1 < rf2) self.assertTrue(rf2 < rf3) - self.assertTrue(f1 > f2) - self.assertTrue(f2 > f3) - self.assertTrue(f2 < f1) - self.assertTrue(f3 < f2) + self.assertTrue(f1 < f2) + self.assertTrue(f2 < f3) + self.assertTrue(f2 > f1) + self.assertTrue(f3 > f2) self.assertEqual(f3, f4) + # -------------------------------------------------------------------------- # Test unifying a symbol with a predicate # -------------------------------------------------------------------------- diff --git a/tests/test_orm_factbase.py b/tests/test_orm_factbase.py index 15f4be6..9710e6b 100644 --- a/tests/test_orm_factbase.py +++ b/tests/test_orm_factbase.py @@ -1118,6 +1118,9 @@ class Afact(Predicate): # Test that select works with order_by for complex term # -------------------------------------------------------------------------- def test_api_factbase_select_order_by_complex_term(self): + # NOTE: behavior change 20240428 - ordering is based on the underlying clingo.Symbol + # object and not the python translation. So using SwapField won't change the ordering + # for AComplex objects. class SwapField(IntegerField): pytocl = lambda x: 100 - x cltopy = lambda x: 100 - x @@ -1145,10 +1148,13 @@ class AFact(Predicate): self.assertEqual([f1, f2, f3, f4], list(q.get())) q = fb.select(AFact).order_by(AFact.cmplx, AFact.astr) - self.assertEqual([f3, f4, f2, f1], list(q.get())) + self.assertEqual([f1, f2, f3, f4], list(q.get())) q = fb.select(AFact).where(AFact.cmplx <= ph1_).order_by(AFact.cmplx, AFact.astr) - self.assertEqual([f3, f4, f2], list(q.get(cmplx2))) + self.assertEqual([f1, f2], list(q.get(cmplx2))) + + q = fb.select(AFact).where(AFact.cmplx >= ph1_).order_by(AFact.cmplx, AFact.astr) + self.assertEqual([f2, f3, f4], list(q.get(cmplx2))) # -------------------------------------------------------------------------- # Test that select works with order_by for complex term diff --git a/tests/test_orm_query.py b/tests/test_orm_query.py index 46d0b93..a1810db 100644 --- a/tests/test_orm_query.py +++ b/tests/test_orm_query.py @@ -84,7 +84,7 @@ where_expression_to_nnf, ) -from .support import check_errmsg, check_errmsg_contains +from .support import check_errmsg, check_errmsg_contains, to_tuple ###### NOTE: The QueryOutput tests need to be turned into QueryExecutor ###### tests. We can then delete QueryOutput which is not being used for @@ -432,7 +432,7 @@ class F(Predicate): getter = make_input_alignment_functor([F], [F.acomplex]) result = getter((f1,)) tmp = ((1, 2),) - self.assertEqual(result, tmp) + self.assertEqual(to_tuple(result), tmp) # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ @@ -545,9 +545,9 @@ class F(Predicate): getter = make_input_alignment_functor([F], [F.acomplex]) result = getter((f1,)) - self.assertEqual(result, ((1, 2),)) + self.assertEqual(to_tuple(result), ((1, 2),)) - sc = SC(operator.eq, [F.acomplex, (1, 2)]) + sc = SC(operator.eq, [F.acomplex, F.meta[1].defn.complex(1, 2)]) cmp = sc.make_callable([F]) self.assertTrue(cmp((f1,))) self.assertFalse(cmp((f2,)))