diff --git a/clorm/_clingo.py b/clorm/_clingo.py index 6774f17..6ddaded 100644 --- a/clorm/_clingo.py +++ b/clorm/_clingo.py @@ -24,6 +24,7 @@ import clingo as oclingo from .orm import FactBase, Predicate, Symbol, SymbolPredicateUnifier, control_add_facts +from .util.oset import OrderedSet from .util.wrapper import init_wrapper, make_class_wrapper __all__ = ["ClormControl", "ClormModel", "ClormSolveHandle", "_expand_assumptions"] @@ -246,16 +247,11 @@ def _expand_assumptions( Tuple[Union[Iterable[Union[Predicate, Symbol]], Predicate, Symbol], bool] ] ) -> List[Tuple[Symbol, bool]]: - pos_assump = set() - neg_assump = set() + clingo_assump = [] def _add_fact(fact: Union[Predicate, Symbol], bval: bool) -> None: - nonlocal pos_assump, neg_assump raw = fact.raw if isinstance(fact, Predicate) else fact - if bval: - pos_assump.add(raw) - else: - neg_assump.add(raw) + clingo_assump.append((raw, bool(bval))) try: for (arg, bval) in assumptions: @@ -274,11 +270,7 @@ def _add_fact(fact: Union[Predicate, Symbol], bval: bool) -> None: "of raw-symbols/predicates). Got: {}" ).format(assumptions) ) - - # Now returned a list of raw assumptions combining pos and neg - pos = [(raw, True) for raw in pos_assump] - neg = [(raw, False) for raw in neg_assump] - return list(itertools.chain(pos, neg)) + return clingo_assump # ------------------------------------------------------------------------------ diff --git a/tests/test_clingo.py b/tests/test_clingo.py index 07fc483..8354c85 100644 --- a/tests/test_clingo.py +++ b/tests/test_clingo.py @@ -782,9 +782,10 @@ class Meta: num_models += 1 self.assertEqual(num_models, 3) - # -------------------------------------------------------------------------- - # Test the solvehandle - # -------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------- + # Test expanding the assumptions - note: the order matters so for an input list of + # predicates or symbols the output list is the corresponding symbols in the same order. + # ---------------------------------------------------------------------------------------- def test_expand_assumptions(self): class F(Predicate): num1 = IntegerField() @@ -796,20 +797,20 @@ class G(Predicate): f2 = F(2) g1 = G(1) - r = set(_expand_assumptions([(f1, True), (g1, False)])) - self.assertEqual(r, set([(f1.raw, True), (g1.raw, False)])) + r = _expand_assumptions([(f1, True), (g1, False)]) + self.assertEqual(r, [(f1.raw, True), (g1.raw, False)]) - r = set(_expand_assumptions([(FactBase([f1, f2]), True), (set([g1]), False)])) - self.assertEqual(r, set([(f1.raw, True), (f2.raw, True), (g1.raw, False)])) + r = _expand_assumptions([(FactBase([f1, f2]), True), (set([g1]), False)]) + self.assertEqual(r, [(f1.raw, True), (f2.raw, True), (g1.raw, False)]) with self.assertRaises(TypeError) as ctx: _expand_assumptions([g1]) with self.assertRaises(TypeError) as ctx: _expand_assumptions(g1) - # -------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------- # Test the solvehandle - # -------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------------- def test_solve_with_assumptions_simple(self): spu = SymbolPredicateUnifier()