diff --git a/clorm/orm/core.py b/clorm/orm/core.py index 487fa7f..2e1b588 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -1,12 +1,21 @@ -# ----------------------------------------------------------------------------- -# Implementation of the core part of the Clorm ORM. In particular this provides -# the base classes and metaclasses for the definition of fields, predicates, -# predicate paths, and the specification of query conditions. Note: query -# condition specification is provided here because the predicate path comparison -# operators are overloads to return these objects. However, the rest of the -# query API is specified with the FactBase and select querying mechanisms -# (see factbase.py). -# ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------- +# Implementation of the core part of the Clorm ORM. In particular this provides the base +# classes and metaclasses for the definition of fields, predicates, predicate paths, and the +# specification of query conditions. Note: query condition specification is provided here +# because the predicate path comparison operators are overloads to return these +# objects. However, the rest of the query API is specified with the FactBase and select +# querying mechanisms (see factbase.py). +# ------------------------------------------------------------------------------------------- + +# ------------------------------------------------------------------------------------------- +# NOTE: 20242028 the semantics for the comparison operators has changed. Instead of using the +# python field representation we use the underlying clingo symbol object. The symbol object is +# well defined for any comparison between symbols, whereas tuples are only well defined if the +# types of the individual parameters are compatible. So this change leads to more natural +# behaviour for the queries. Note: users should avoid defining unintuitive fields (for example +# a swap field that changes the sign of an int) to avoid unintuitive Python behaviour. +# ------------------------------------------------------------------------------------------- + from __future__ import annotations @@ -1341,7 +1350,6 @@ class DateField(StringField): ``FactBase```. Defaults to ``False``. """ - def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None: self._index = index if index is not MISSING else False @@ -1349,33 +1357,35 @@ def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None: self._default = (False, None) return - self._default = (True, default) cmplx = self.complex - # Check and convert the default to a valid value. Note: if the default - # is a callable then we can't do this check because it could break a - # counter type procedure. - if callable(default) or (cmplx and isinstance(default, cmplx)): - return + def _process_basic_value(v): + return v - try: - if cmplx: + def _process_cmplx_value(v): + if isinstance(v, cmplx): + return v + if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple): + return cmplx(*v) + raise TypeError(f"Value {v} ({type(v)}) cannot be converted to type {cmplx}") - def _instance_from_tuple(v): - if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple): - return cmplx(*v) - raise TypeError(f"Value {v} ({type(v)}) is not a tuple") + _process_value = _process_basic_value if cmplx is None else _process_cmplx_value + + # If the default is not a factory function than make sure the value can be converted to + # clingo without error. + if not callable(default): + try: + self._default = (True, _process_value(default)) + self.pytocl(self._default[1]) + except (TypeError, ValueError): + raise TypeError( + 'Invalid default value "{}" for {}'.format(default, type(self).__name__) + ) + else: + def _process_default(): + return _process_value(default()) + self._default = (True, _process_default) - if cmplx.meta.is_tuple: - self._default = (True, _instance_from_tuple(default)) - else: - raise ValueError("Bad default") - else: - self.pytocl(default) - except (TypeError, ValueError): - raise TypeError( - 'Invalid default value "{}" for {}'.format(default, type(self).__name__) - ) @staticmethod @abc.abstractmethod @@ -1503,11 +1513,11 @@ def field( raise TypeError(f"{basefield} can just be of Type '{BaseField}' or '{Sequence}'") -# ------------------------------------------------------------------------------ -# RawField is a sub-class of BaseField for storing Symbol or NoSymbol -# objects. The behaviour of Raw with respect to using clingo.Symbol or -# noclingo.NoSymbol is modified by the symbol mode (get_symbol_mode()) -# ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------ +# RawField is a sub-class of BaseField for storing Symbol objects. The behaviour of Raw with +# respect to using clingo.Symbol or noclingo.Symbol is modified by the symbol mode +# (get_symbol_mode()) +# ------------------------------------------------------------------------------------------ class Raw(object): @@ -2441,13 +2451,9 @@ def get_field_definition(defn: Any, module: str = "") -> BaseField: def _create_complex_term(defn: Any, default_value: Any = MISSING, module: str = "") -> BaseField: - # NOTE: I was using a dict rather than OrderedDict which just happened to - # work. Apparently, in Python 3.6 this was an implmentation detail and - # Python 3.7 it is a language specification (see: - # https://stackoverflow.com/questions/1867861/how-to-keep-keys-values-in-same-order-as-declared/39537308#39537308). - # However, since Clorm is meant to be Python 3.5 compatible change this to - # use an OrderedDict. - # proto = { "arg{}".format(i+1) : get_field_definition(d) for i,d in enumerate(defn) } + # NOTE: relies on a dict preserving insertion order - this is true from Python 3.7+. Python + # 3.7 is already end-of-life so there is no longer a reason to use OrderedDict. + #proto = {f"arg{idx+1}": get_field_definition(dn) for idx, dn in enumerate(defn)} proto: Dict[str, Any] = collections.OrderedDict( [(f"arg{i+1}", get_field_definition(d, module)) for i, d in enumerate(defn)] ) @@ -2705,13 +2711,16 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N gdict = { "Predicate": Predicate, + "Symbol": Symbol, "Function": Function, "MISSING": MISSING, "AnySymbol": AnySymbol, "Type": Type, + "Any": Any, "Optional": Optional, "Sequence": Sequence, "_P": _P, + "PREDICATE_IS_TUPLE": pdefn.is_tuple, } for f in pdefn: @@ -2786,21 +2795,26 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N template = PREDICATE_TEMPLATE.format(pdefn=pdefn) predicate_functions = expand_template(template, **expansions) - # print(f"INIT:\n\n{predicate_functions}\n\n") +# print(f"INIT {class_name}:\n\n{predicate_functions}\n\n") ldict: Dict[str, Any] = {} exec(predicate_functions, gdict, ldict) - init_doc_args = f"{args_signature}*, sign=True, raw=None" - predicate_init = ldict["__init__"] - predicate_init.__name__ = "__init__" - predicate_init.__doc__ = f"{class_name}({init_doc_args})" - predicate_unify = ldict["_unify"] - predicate_unify.__name__ = "_unify" - predicate_unify.__doc__ = PREDICATE_UNIFY_DOCSTRING + def _set_fn(fname: str, docstring: str): + tmp = ldict[fname] + tmp.__name__ = fname + tmp.__doc = docstring + namespace[fname] = tmp - namespace["__init__"] = predicate_init - namespace["_unify"] = predicate_unify + # Assign the __init__, _unify, __hash__, and appropriate comparison functions + _set_fn("__init__", f"{class_name}({args_signature}*, sign=True, raw=None)") + _set_fn("_unify", PREDICATE_UNIFY_DOCSTRING) + _set_fn("__hash__", "Hash operator") + _set_fn("__eq__", "Equality operator") + _set_fn("__lt__", "Less than operator") + _set_fn("__le__", "Less than or equal operator") + _set_fn("__gt__", "Greater than operator") + _set_fn("__ge__", "Greater than operator") # ------------------------------------------------------------------------------ @@ -3160,6 +3174,7 @@ def _cltopy(v): # ------------------------------------------------------------------------------ # A Metaclass for the Predicate base class # ------------------------------------------------------------------------------ + @__dataclass_transform__(field_descriptors=(field,)) class _PredicateMeta(type): if TYPE_CHECKING: @@ -3244,8 +3259,12 @@ def __iter__(self) -> Iterator[PredicatePath]: # underlying Symbol object. # ------------------------------------------------------------------------------ +# Mixin class to be able to use both MetaClasses +class _AbstractPredicateMeta(abc.ABCMeta, _PredicateMeta): + pass + -class Predicate(object, metaclass=_PredicateMeta): +class Predicate(object, metaclass=_AbstractPredicateMeta): """Abstract base class to encapsulate an ASP predicate or complex term. This is the heart of the ORM model for defining the mapping of a predicate @@ -3324,7 +3343,7 @@ def _unify( def symbol(self): """Returns the Symbol object corresponding to the fact. - The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.NoSymbol``. + The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.Symbol``. """ return self._raw @@ -3413,74 +3432,35 @@ def __neg__(self): # -------------------------------------------------------------------------- # Overloaded operators # -------------------------------------------------------------------------- + @abc.abstractmethod def __eq__(self, other): """Overloaded boolean operator.""" - if isinstance(other, self.__class__): - return self._field_values == other._field_values and self._sign == other._sign - if self.meta.is_tuple: - return self._field_values == other - elif isinstance(other, Predicate): - return False - return NotImplemented + raise NotImplementedError("Predicate.__eq__() must be overriden") + @abc.abstractmethod def __lt__(self, other): """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__lt__() must be overriden") - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - - # Negative literals are less than positive literals - if self.sign != other.sign: - return self.sign < other.sign - - return self._field_values < other._field_values - - # If different predicates then compare the raw value - elif isinstance(other, Predicate): - return self.raw < other.raw - - # Else an error - return NotImplemented + def __le__(self, other): + """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__le__() must be overriden") + @abc.abstractmethod def __ge__(self, other): """Overloaded boolean operator.""" - result = self.__lt__(other) - if result is NotImplemented: - return NotImplemented - return not result + raise NotImplementedError("Predicate.__ge__() must be overriden") + @abc.abstractmethod def __gt__(self, other): """Overloaded boolean operator.""" + raise NotImplementedError("Predicate.__gt__() must be overriden") - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - # Positive literals are greater than negative literals - if self.sign != other.sign: - return self.sign > other.sign - - return self._field_values > other._field_values - - # If different predicates then compare the raw value - if not isinstance(other, Predicate): - return self.raw > other.raw - - # Else an error - return NotImplemented - - def __le__(self, other): - """Overloaded boolean operator.""" - result = self.__gt__(other) - if result is NotImplemented: - return NotImplemented - return not result + @abc.abstractmethod def __hash__(self): - if self._hash is None: - if self.meta.is_tuple: - self._hash = hash(self._field_values) - else: - self._hash = hash((self.meta.name, self._field_values)) - return self._hash + """Overload the hash function.""" + raise NotImplementedError("Predicate.__hash__() must be overriden") def __str__(self): """Returns the Predicate as the string representation of an ASP fact.""" diff --git a/clorm/orm/templating.py b/clorm/orm/templating.py index 8c4c4c3..287c8f0 100644 --- a/clorm/orm/templating.py +++ b/clorm/orm/templating.py @@ -37,16 +37,16 @@ def add_spaces(num, text): lines = template.expandtabs(4).splitlines() outlines = [] for line in lines: - start = line.find("{%") + start = line.find(r"{%") if start == -1: outlines.append(line) continue - end = line.find("%}", start) + end = line.find(r"%}", start) if end == -1: raise ValueError("Bad template expansion in {line}") - keyword = line[start + 2 : end] + keyword = line[start + 2:end] text = add_spaces(start, kwargs[keyword]) - line = line[0:start] + text + line[end + 2 :] + line = line[0:start] + text + line[end + 2:] outlines.append(line) return "\n".join(outlines) @@ -73,6 +73,7 @@ def __init__(self, ({{%args_raw%}}), self._sign) + @classmethod def _unify(cls: Type[_P], raw: AnySymbol, raw_args: Optional[Sequence[AnySymbol]]=None, raw_name: Optional[str]=None) -> Optional[_P]: try: @@ -98,6 +99,92 @@ def _unify(cls: Type[_P], raw: AnySymbol, raw_args: Optional[Sequence[AnySymbol] raise ValueError((f"Cannot unify with object {{raw}} ({{type(raw)}}) as " "it is not a clingo Symbol Function object")) + +def nontuple__eq__(self, other: Any) -> bool: + # Deal with a non-tuple predicate + if isinstance(other, Predicate): + return self._raw == other._raw + if isinstance(other, Symbol): + return self._raw == other + return NotImplemented + + +def tuple__eq__(self, other: Any) -> bool: + # Deal with a predicate that is a tuple + if isinstance(other, Predicate): + return self._raw == other._raw + if isinstance(other, Symbol): + return self._raw == other +# if isinstance(other, tuple): +# return self._field_values == other + return NotImplemented + + +def nontuple__lt__(self, other): + # If it is the same predicate class then compare the underlying clingo symbol + if isinstance(other, Predicate): + return self._raw < other._raw + if isinstance(other, Symbol): + return self._raw < other + return NotImplemented + + +def tuple__lt__(self, other): + # self is always less than a non-tuple predicate + if isinstance(other, Predicate): + return self._raw < other._raw + if isinstance(other, Symbol): + return self._raw < other +# if isinstance(other, tuple): +# return self._field_values < other + return NotImplemented + + +def nontuple__gt__(self, other): + if isinstance(other, Predicate): + return self._raw > other._raw + if isinstance(other, Symbol): + return self._raw > other + return NotImplemented + + +def tuple__gt__(self, other): + # If it is the same predicate class then compare the sign and fields + if isinstance(other, Predicate): + return self._raw > other._raw + if isinstance(other, Symbol): + return self._raw > other +# if isinstance(other, tuple): +# return self._field_values > other + return NotImplemented + + +def __ge__(self, other): + result = self.__lt__(other) + if result is NotImplemented: + return NotImplemented + return not result + + +def __le__(self, other): + result = self.__gt__(other) + if result is NotImplemented: + return NotImplemented + return not result + + +def __hash__(self): + if self._hash is None: + self._hash = hash(self._raw) + return self._hash + + +__eq__ = tuple__eq__ if PREDICATE_IS_TUPLE else nontuple__eq__ +__lt__ = tuple__lt__ if PREDICATE_IS_TUPLE else nontuple__lt__ +__gt__ = tuple__gt__ if PREDICATE_IS_TUPLE else nontuple__gt__ + + + """ CHECK_SIGN_TEMPLATE = r""" diff --git a/tests/support.py b/tests/support.py index d3fd022..2ce3be6 100644 --- a/tests/support.py +++ b/tests/support.py @@ -1,3 +1,5 @@ +from clorm import Predicate + # ------------------------------------------------------------------------------ # Support functions for the unit tests # ------------------------------------------------------------------------------ @@ -19,6 +21,17 @@ def check_errmsg_contains(contmsg, ctx): raise AssertionError(msg) +# ------------------------------------------------------------------------------ +# +# ------------------------------------------------------------------------------ + +def to_tuple(value): + """Recursively convert a predicate/normal tuple into a Python tuple""" + if isinstance(value, tuple) or (isinstance(value, Predicate) and value.meta.is_tuple): + return tuple(to_tuple(x) for x in value) + return value + + # ------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------ diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index aa24145..469a5f7 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -8,6 +8,12 @@ # to be completed. # ------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------------- +# NOTE: 20242028 See orm/core.py for changes to the semantics of comparison operators for +# Predicate objects. +# ------------------------------------------------------------------------------------------- + + import collections.abc as cabc import datetime import enum @@ -71,7 +77,7 @@ trueall, ) -from .support import check_errmsg, check_errmsg_contains +from .support import check_errmsg, check_errmsg_contains, to_tuple # Error messages for CPython and PyPy vary PYPY = sys.implementation.name == "pypy" @@ -280,10 +286,12 @@ def test_api_field_function(self): self.assertEqual(t, (StringField, IntegerField)) t = field((StringField, IntegerField), default=("3", 4)) + self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, IntegerField) - self.assertEqual(t.default, ("3", 4)) + self.assertEqual(t.default, t.complex("3", 4)) + self.assertEqual(to_tuple(t.default), ("3", 4)) with self.subTest("with custom field"): INLField = define_flat_list_field(IntegerField, name="INLField") @@ -302,8 +310,8 @@ def factory(): return ("3", x) t = field((StringField, IntegerField), default_factory=factory) - self.assertEqual(t.default, ("3", 1)) - self.assertEqual(t.default, ("3", 2)) + self.assertEqual(to_tuple(t.default), ("3", 1)) + self.assertEqual(to_tuple(t.default), ("3", 2)) with self.subTest("with nested tuple and default"): t = field((StringField, (StringField, IntegerField))) @@ -313,7 +321,7 @@ def factory(): self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, BaseField) - self.assertEqual(t.default, ("3", ("1", 4))) + self.assertEqual(to_tuple(t.default), ("3", ("1", 4))) def test_api_field_function_illegal_arguments(self): with self.subTest("illegal basefield type"): @@ -783,7 +791,7 @@ def test_api_nested_list_field_complex_element_field(self): value2 = (2, ("b", "B")) nlist = (value1, value2) - self.assertEqual(XField.cltopy(symnlist), nlist) + self.assertEqual(to_tuple(XField.cltopy(symnlist)), nlist) self.assertEqual(XField.pytocl(nlist), symnlist) # -------------------------------------------------------------------------- @@ -835,7 +843,7 @@ def test_api_flat_list_field_complex_element_field(self): value2 = (2, ("b", "B")) nlist = (value1, value2) - self.assertEqual(XField.cltopy(symnlist), nlist) + self.assertEqual(to_tuple(XField.cltopy(symnlist)), nlist) self.assertEqual(XField.pytocl(nlist), symnlist) # -------------------------------------------------------------------------- @@ -1014,6 +1022,25 @@ class Blah(object): clresult = td.pytocl((1, "blah")) self.assertEqual(clresult, clob) + + # -------------------------------------------------------------------------- + # Test that we define the new comparison operators in the template + # -------------------------------------------------------------------------- + def test_predicate_comparison_operator_creation(self): + + class P(Predicate, name="p"): + a = IntegerField + b = ConstantField + + p1 = P(a=1, b="x") + p2 = P(a=1, b="x") + p3 = P(a=2, b="x") + + tmp = {} + tmp[p1] = "P" + self.assertEqual(p1, p2) + self.assertNotEqual(p1, p3) + # -------------------------------------------------------------------------- # Test that we can define predicates using the class syntax and test that # the getters and setters are connected properly to the predicate classes. @@ -1177,21 +1204,25 @@ class T(Predicate): tuple1 = tuple([1, "a"]) tuple2 = tuple([2, "b"]) - # Equality works even when the types are different + # Equality works even when the predicate types are different self.assertTrue(p1.tuple_ == p1_alt.tuple_) self.assertTrue(p1.tuple_ == q1.tuple_) self.assertTrue(p1.tuple_ == t1.tuple_) - self.assertTrue(p1.tuple_ == tuple1) + + # New behaviour. Doesn't compare directly to tuple + self.assertFalse(p1.tuple_ == tuple1) + self.assertTrue(tuple(p1.tuple_) == tuple1) self.assertNotEqual(type(p1.tuple_), type(t1.tuple_)) - # self.assertNotEqual(type(p1.tuple_), type(t1)) self.assertTrue(p1.tuple_ != p2.tuple_) self.assertTrue(p1.tuple_ != q2.tuple_) self.assertTrue(p1.tuple_ != r2.tuple_) self.assertTrue(p1.tuple_ != s2.tuple_) self.assertTrue(p1.tuple_ != t2.tuple_) - self.assertTrue(p1.tuple_ != tuple2) + + self.assertTrue(p1.tuple_ != tuple1) + self.assertFalse(tuple(p1.tuple_) != tuple1) # -------------------------------------------------------------------------- # Test predicates with default fields @@ -2113,33 +2144,18 @@ def test_predicate_comparison_operator_overload_signed(self): class P(Predicate): a = IntegerField - class Q(Predicate): - a = IntegerField - p1 = P(1) neg_p1 = P(1, sign=False) p2 = P(2) neg_p2 = P(2, sign=False) - q1 = Q(1) - - self.assertTrue(neg_p1 < neg_p2) - self.assertTrue(neg_p1 < p1) - self.assertTrue(neg_p1 < p2) - self.assertTrue(neg_p2 < p1) - self.assertTrue(neg_p2 < p2) - self.assertTrue(p1 < p2) - - self.assertTrue(p2 > p1) - self.assertTrue(p2 > neg_p2) - self.assertTrue(p2 > neg_p1) - self.assertTrue(p1 > neg_p2) - self.assertTrue(p1 > neg_p1) - self.assertTrue(neg_p2 > neg_p1) - # Different predicate sub-classes are incomparable + # NOTE: 20240428 see note at top about change of semantics + self.assertTrue((neg_p1.raw < p1.raw) == (neg_p1 < p1)) + self.assertTrue((neg_p1.raw < p2.raw) == (neg_p1 < p2)) + self.assertTrue((neg_p2.raw < p1.raw) == (neg_p2 < p1)) + self.assertTrue((neg_p2.raw < p2.raw) == (neg_p2 < p2)) + self.assertTrue((p1.raw < p2.raw) == (p1 < p2)) - # with self.assertRaises(TypeError) as ctx: - # self.assertTrue(p1 < q1) # -------------------------------------------------------------------------- # Test a simple predicate with a field that has a function default @@ -2320,9 +2336,9 @@ class Fact(Predicate): with self.assertRaises(TypeError) as ctx: class Fact2(Predicate): - afun = Fun.Field(default=(1, "str")) + afun = Fun.Field(default=6) - check_errmsg("""Invalid default value "(1, 'str')" for FunField""", ctx) + check_errmsg("""Invalid default value "6" for FunField""", ctx) # -------------------------------------------------------------------------- # Test the simple_predicate function as a mechanism for defining @@ -2423,6 +2439,7 @@ class Fact(Predicate): # Test predicate equality # -------------------------------------------------------------------------- def test_predicate_comparison_operator_overloads(self): + # NOTE: 20240428 see note at top about change of semantics f1 = Function("fact", [Number(1)]) f2 = Function("fact", [Number(2)]) @@ -2452,9 +2469,9 @@ class Meta: self.assertEqual(f1, af1.raw) self.assertEqual(af1.raw, f1) self.assertEqual(af1.raw, ag1.raw) - self.assertNotEqual(af1, ag1) - self.assertNotEqual(af1, f1) - self.assertNotEqual(f1, af1) + self.assertEqual(af1, ag1) + self.assertEqual(af1, f1) + self.assertEqual(f1, af1) self.assertTrue(af1 < af2) self.assertTrue(af1 <= af2) @@ -2478,6 +2495,8 @@ class Meta: # Test predicate equality # -------------------------------------------------------------------------- def test_comparison_operator_overloads_complex(self): + # NOTE: 20240428 see note at top about change of semantics + class SwapField(IntegerField): pytocl = lambda x: 100 - x cltopy = lambda x: 100 - x @@ -2497,17 +2516,15 @@ class AComplex(ComplexTerm): for rf in [rf1, rf2, rf3]: self.assertEqual(rf.arguments[0], rf.arguments[1]) - # Test the the comparison operator for the complex term is using the - # swapped values so that the comparison is opposite to what the raw - # field says. self.assertTrue(rf1 < rf2) self.assertTrue(rf2 < rf3) - self.assertTrue(f1 > f2) - self.assertTrue(f2 > f3) - self.assertTrue(f2 < f1) - self.assertTrue(f3 < f2) + self.assertTrue(f1 < f2) + self.assertTrue(f2 < f3) + self.assertTrue(f2 > f1) + self.assertTrue(f3 > f2) self.assertEqual(f3, f4) + # -------------------------------------------------------------------------- # Test unifying a symbol with a predicate # -------------------------------------------------------------------------- diff --git a/tests/test_orm_factbase.py b/tests/test_orm_factbase.py index 15f4be6..9710e6b 100644 --- a/tests/test_orm_factbase.py +++ b/tests/test_orm_factbase.py @@ -1118,6 +1118,9 @@ class Afact(Predicate): # Test that select works with order_by for complex term # -------------------------------------------------------------------------- def test_api_factbase_select_order_by_complex_term(self): + # NOTE: behavior change 20240428 - ordering is based on the underlying clingo.Symbol + # object and not the python translation. So using SwapField won't change the ordering + # for AComplex objects. class SwapField(IntegerField): pytocl = lambda x: 100 - x cltopy = lambda x: 100 - x @@ -1145,10 +1148,13 @@ class AFact(Predicate): self.assertEqual([f1, f2, f3, f4], list(q.get())) q = fb.select(AFact).order_by(AFact.cmplx, AFact.astr) - self.assertEqual([f3, f4, f2, f1], list(q.get())) + self.assertEqual([f1, f2, f3, f4], list(q.get())) q = fb.select(AFact).where(AFact.cmplx <= ph1_).order_by(AFact.cmplx, AFact.astr) - self.assertEqual([f3, f4, f2], list(q.get(cmplx2))) + self.assertEqual([f1, f2], list(q.get(cmplx2))) + + q = fb.select(AFact).where(AFact.cmplx >= ph1_).order_by(AFact.cmplx, AFact.astr) + self.assertEqual([f2, f3, f4], list(q.get(cmplx2))) # -------------------------------------------------------------------------- # Test that select works with order_by for complex term diff --git a/tests/test_orm_query.py b/tests/test_orm_query.py index 46d0b93..a1810db 100644 --- a/tests/test_orm_query.py +++ b/tests/test_orm_query.py @@ -84,7 +84,7 @@ where_expression_to_nnf, ) -from .support import check_errmsg, check_errmsg_contains +from .support import check_errmsg, check_errmsg_contains, to_tuple ###### NOTE: The QueryOutput tests need to be turned into QueryExecutor ###### tests. We can then delete QueryOutput which is not being used for @@ -432,7 +432,7 @@ class F(Predicate): getter = make_input_alignment_functor([F], [F.acomplex]) result = getter((f1,)) tmp = ((1, 2),) - self.assertEqual(result, tmp) + self.assertEqual(to_tuple(result), tmp) # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ @@ -545,9 +545,9 @@ class F(Predicate): getter = make_input_alignment_functor([F], [F.acomplex]) result = getter((f1,)) - self.assertEqual(result, ((1, 2),)) + self.assertEqual(to_tuple(result), ((1, 2),)) - sc = SC(operator.eq, [F.acomplex, (1, 2)]) + sc = SC(operator.eq, [F.acomplex, F.meta[1].defn.complex(1, 2)]) cmp = sc.make_callable([F]) self.assertTrue(cmp((f1,))) self.assertFalse(cmp((f2,)))