diff --git a/src/context/context.c b/src/context/context.c index 849ae8a9c..ffac2458c 100644 --- a/src/context/context.c +++ b/src/context/context.c @@ -5360,7 +5360,7 @@ static void create_mcarith_solver(context_t *ctx) { // row saving must be enabled unless we're in ONECHECK mode if (ctx->mode != CTX_MODE_ONECHECK) { - simplex_enable_row_saving(solver); + mcarith_enable_row_saving(solver); } if (ctx->egraph != NULL) { diff --git a/src/solvers/mcarith/mcarith.c b/src/solvers/mcarith/mcarith.c index 77d5eb8d3..8e344b5eb 100644 --- a/src/solvers/mcarith/mcarith.c +++ b/src/solvers/mcarith/mcarith.c @@ -429,6 +429,29 @@ void mcarith_check_atom_size(mcarith_solver_t *solver) { static void rba_buffer_add_mono_mcarithvar(mcarith_solver_t* solver, rba_buffer_t* b, rational_t *a, thvar_t v); +/** + * This creates a term from a polynomial while taking care not to + * introduce a polynomial from a constant. + */ +static +term_t mcarith_poly(term_table_t* terms, rba_buffer_t* b) { + + if (b->nterms == 0) { + rational_t q; + q_init(&q); + return arith_constant(terms, &q); + } + + if (b->nterms == 1) { + mono_t* m = b->mono + b->root; + if (m->prod == empty_pp) { + return arith_constant(terms, &m->coeff); + } + } + + return arith_poly(terms, b); +} + static term_t get_term(mcarith_solver_t* solver, thvar_t v) { arith_vartable_t* table; @@ -469,7 +492,7 @@ term_t get_term(mcarith_solver_t* solver, thvar_t v) { rba_buffer_add_mono_mcarithvar(solver, &b, r, v); } } - t = arith_poly(terms, &b); + t = mcarith_poly(terms, &b); } break; case AVARTAG_KIND_PPROD: @@ -648,8 +671,6 @@ term_t get_atom(mcarith_solver_t* solver, int32_t atom_index) { term_t v; rational_t* bnd; rba_buffer_t b; - term_t polyTerm; - mcarith_check_atom_size(solver); @@ -671,26 +692,31 @@ term_t get_atom(mcarith_solver_t* solver, int32_t atom_index) { // Assert v-b >= 0 case GE_ATM: { if (q_is_zero(bnd)) { - polyTerm = get_term(solver, var_of_atom(ap)); + term_t polyTerm; + polyTerm = get_term(solver, v); + result = arith_geq_atom(terms, polyTerm); } else { + term_t polyTerm; // Create buffer b = v - bnd init_rba_buffer(&b, mcarith_solver_pprods(solver)); rba_buffer_add_mcarithvar(solver, &b, v); rba_buffer_sub_const(&b, bnd); // Create term and free buffer. - polyTerm = arith_poly(terms, &b); + polyTerm = mcarith_poly(terms, &b); delete_rba_buffer(&b); + result = arith_geq_atom(terms, polyTerm); } - result = arith_geq_atom(terms, polyTerm); break; } // Assert v <= b by asserting b-v >= 0 case LE_ATM: { + term_t polyTerm; + // Create buffer b = bnd - v init_rba_buffer(&b, mcarith_solver_pprods(solver)); rba_buffer_sub_mcarithvar(solver, &b, v); rba_buffer_add_const(&b, bnd); - polyTerm = arith_poly(terms, &b); + polyTerm = mcarith_poly(terms, &b); delete_rba_buffer(&b); result = arith_geq_atom(terms, polyTerm); break; @@ -716,24 +742,30 @@ thvar_t mcarith_create_pprod(void *s, pprod_t *p, thvar_t *map) { assert(pprod_degree(p) > 1); assert(!pp_is_empty(p)); assert(!pp_is_var(p)); - // Create theory variable thvar_t v = simplex_create_var(&solver->simplex, false); // Remap variables in powerproduct mcarith_check_var_size(solver); pp_buffer_t b; init_pp_buffer(&b, p->len); + uint32_t n = p->len; + int32_t* vars = safe_malloc(sizeof(int32_t) * n); + uint32_t* exps = safe_malloc(sizeof(uint32_t) * n); + for (uint32_t i = 0; i < p->len; ++i) { thvar_t mv = map[i]; assert(mv < v); - term_t t = get_term(solver, mv); - pp_buffer_set_varexp(&b, t, p->prod[i].exp); + vars[i] = get_term(solver, mv); + exps[i] = p->prod[i].exp; } - pp_buffer_normalize(&b); + pp_buffer_set_varexps(&b, n, vars, exps); + free(vars); + free(exps); - // Create term assert(solver->var_terms_size > v); - solver->var_terms[v] = pprod_term_from_buffer(mcarith_solver_terms(solver), &b); + + term_t t = pprod_term_from_buffer(mcarith_solver_terms(solver), &b); + solver->var_terms[v] = t; // Free buffer and return delete_pp_buffer(&b); return v; @@ -839,6 +871,23 @@ void init_rba_buffer_from_poly(mcarith_solver_t* solver, } } +static +bool rba_buffer_get_const(rba_buffer_t* b, rational_t* r) { + if (b->nterms == 0) { + q_init(r); + return true; + } else if (b->nterms == 1) { + mono_t* m = b->mono + b->root; + if (m->prod == empty_pp) { + q_init(r); + q_set(r, &m->coeff); + return true; + } + } + return false; +} + + /* * Check for integer feasibility */ @@ -847,21 +896,18 @@ fcheck_code_t mcarith_final_check(void* s) { mcarith_solver_t *solver; term_table_t* terms; uint32_t acnt; - + term_t t; solver = s; terms = mcarith_solver_terms(solver); + assert(!solver->simplex.unsat_before_search); - yices_pp_t printer; - pp_area_t area; - - area.width = 72; - area.height = 8; - area.offset = 2; - area.stretch = false; - area.truncate = true; + fcheck_code_t result; - init_default_yices_pp(&printer, stdout, &area); + result = simplex_final_check(&solver->simplex); + if (result == FCHECK_CONTINUE) { + return result; + } mcarith_free_mcsat(solver); @@ -892,18 +938,40 @@ fcheck_code_t mcarith_final_check(void* s) { p = arith_eq_atom(terms, get_term(solver, a->def.var)); break; case VAR_GE0: - p = arith_geq_atom(terms, get_term(solver, a->def.var)); + t = get_term(solver, a->def.var); + assert(term_kind(terms, t) != ARITH_CONSTANT); + p = arith_geq_atom(terms, t); break; - case POLY_EQ0: + case POLY_EQ0: { + rational_t br; + init_rba_buffer_from_poly(solver, &b, a->def.poly); - p = arith_eq_atom(terms, arith_poly(terms, &b)); + if (rba_buffer_get_const(&b, &br)) { + p = bool2term(q_is_zero(&br)); + q_clear(&br); + } else { + t = arith_poly(terms, &b); + assert(term_kind(terms, t) != ARITH_CONSTANT); + p = arith_eq_atom(terms, t); + } delete_rba_buffer(&b); break; - case POLY_GE0: + } + case POLY_GE0: { + rational_t br; + init_rba_buffer_from_poly(solver, &b, a->def.poly); - p = arith_geq_atom(terms, arith_poly(terms, &b)); + if (rba_buffer_get_const(&b, &br)) { + p = bool2term(q_is_nonneg(&br)); + q_clear(&br); + } else { + t = arith_poly(terms, &b); + assert(term_kind(terms, t) != ARITH_CONSTANT); + p = arith_geq_atom(terms, t); + } delete_rba_buffer(&b); break; + } case ATOM_ASSERT: p = get_atom(solver, a->def.atom); break; @@ -920,10 +988,9 @@ fcheck_code_t mcarith_final_check(void* s) { literals[literal_count] = end_clause; int32_t r = mcsat_assert_formulas(solver->mcsat, acnt, assertions); - fcheck_code_t result; if (r == TRIVIALLY_UNSAT) { record_theory_conflict(solver->simplex.core, literals); - mcarith_free_mcsat(solver); + mcarith_free_mcsat(solver); result = FCHECK_CONTINUE; } else if (r == CTX_NO_ERROR) { result = FCHECK_SAT; @@ -990,7 +1057,10 @@ void mcarith_increase_decision_level(mcarith_solver_t *solver) { void mcarith_backtrack(mcarith_solver_t *solver, uint32_t back_level) { mcarith_undo_record_t* r = mcarith_undo_backtrack(&solver->undo_stack, back_level); mcarith_backtrack_assertions(solver, r->n_assertions); + uint32_t vc = solver->simplex.vtbl.nvars; simplex_backtrack(&solver->simplex, back_level); + uint32_t vc0 = solver->simplex.vtbl.nvars; + assert(vc == vc0); } /* @@ -1028,7 +1098,10 @@ void mcarith_pop(mcarith_solver_t *solver) { * Reset to the empty solver */ void mcarith_reset(mcarith_solver_t *solver) { + uint32_t vc = solver->simplex.vtbl.nvars; simplex_reset(&solver->simplex); + uint32_t vc0 = solver->simplex.vtbl.nvars; + assert(vc == vc0); // FIXME mcarith_undo_reset(&solver->undo_stack); mcarith_backtrack_assertions(solver, 0); @@ -1134,6 +1207,27 @@ thvar_t mcarith_create_poly(void* s, polynomial_t *p, thvar_t *map) { return simplex_create_poly(&solver->simplex, p, map); } +typedef struct{ + uint32_t bstack_top; + uint32_t matrix_nrows; +} simplex_assertion_count_t; + +static simplex_assertion_count_t simplex_assertion_count(simplex_solver_t* simplex) { + simplex_assertion_count_t r; + r.bstack_top = simplex->bstack.top; + r.matrix_nrows = simplex->matrix.nrows; + return r; +} + +/** + * Returns true if the simplex solver concluded the assertion just added did not + * need to get recorded as it was true or false from previous assertions. + */ +static bool simplex_handled_assertion(simplex_solver_t* simplex, simplex_assertion_count_t c) { + return simplex->unsat_before_search + || (simplex->matrix.nrows == c.matrix_nrows && simplex->bstack.top == c.bstack_top); +} + /* * Assert a top-level equality constraint (either x == 0 or x != 0) * - tt indicates whether the constraint or its negation must be asserted @@ -1145,7 +1239,11 @@ static void mcarith_assert_eq_axiom(void* s, thvar_t x, bool tt) { mcarith_assertion_t* assert; solver = s; + // Get number of assertions + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_eq_axiom(&solver->simplex, x, tt); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; // Record assertion for sending to mcarith solver. assert = alloc_top_assertion(solver); @@ -1166,7 +1264,11 @@ static void mcarith_assert_ge_axiom(void* s, thvar_t x, bool tt){ mcarith_assertion_t* assert; solver = s; + // Get number of assertions + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_ge_axiom(&solver->simplex, x, tt); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; assert = alloc_top_assertion(solver); assert->type = VAR_GE0; @@ -1188,10 +1290,15 @@ static void mcarith_assert_poly_eq_axiom(void * s, polynomial_t *p, thvar_t *map solver = s; assert(p->nterms > 0); + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_poly_eq_axiom(&solver->simplex, p, map, tt); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; assert = alloc_top_assertion(solver); assert->type = POLY_EQ0; + assert(p->nterms > 0); + assert(p->nterms > 1 || p->mono[0].var != const_idx); assert->def.poly = alloc_polynomial_from_map(p, map, solver->simplex.vtbl.nvars); assert->tt = tt; assert->lit = null_literal; @@ -1209,10 +1316,15 @@ static void mcarith_assert_poly_ge_axiom(void *s, polynomial_t *p, thvar_t *map, solver = s; assert(p->nterms > 0); + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_poly_ge_axiom(&solver->simplex, p, map, tt); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; assert = alloc_top_assertion(solver); assert->type = POLY_GE0; + assert(p->nterms > 0); + assert(p->nterms > 1 || p->mono[0].var != const_idx); assert->def.poly = alloc_polynomial_from_map(p, map, solver->simplex.vtbl.nvars); assert->tt = tt; assert->lit = null_literal; @@ -1228,7 +1340,10 @@ void mcarith_assert_vareq_axiom(void* s, thvar_t x, thvar_t y, bool tt) { mcarith_assertion_t* assert; solver = s; + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_vareq_axiom(&solver->simplex, x, y, tt); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; assert = alloc_top_assertion(solver); assert->type = VAR_EQ; @@ -1245,7 +1360,12 @@ static void mcarith_assert_cond_vareq_axiom(void* s, literal_t c, thvar_t x, thvar_t y) { mcarith_solver_t *solver; solver = s; + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_cond_vareq_axiom(&solver->simplex, c, x, y); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; + + assert(false); } /* @@ -1254,7 +1374,12 @@ void mcarith_assert_cond_vareq_axiom(void* s, literal_t c, thvar_t x, thvar_t y) static void mcarith_assert_clause_vareq_axiom(void* s, uint32_t n, literal_t *c, thvar_t x, thvar_t y) { mcarith_solver_t *solver = s; + simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex); simplex_assert_clause_vareq_axiom(&solver->simplex, n, c, x, y); + if (simplex_handled_assertion(&solver->simplex, cinit)) + return; + + assert(false); } /*