Skip to content

Commit

Permalink
Bugfixes to mcarith
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Hendrix committed Sep 26, 2023
1 parent c49758c commit c017ba4
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/context/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -5360,7 +5360,7 @@ static void create_mcarith_solver(context_t *ctx) {

// row saving must be enabled unless we're in ONECHECK mode
if (ctx->mode != CTX_MODE_ONECHECK) {
simplex_enable_row_saving(solver);
mcarith_enable_row_saving(solver);
}

if (ctx->egraph != NULL) {
Expand Down
185 changes: 155 additions & 30 deletions src/solvers/mcarith/mcarith.c
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,29 @@ void mcarith_check_atom_size(mcarith_solver_t *solver) {
static
void rba_buffer_add_mono_mcarithvar(mcarith_solver_t* solver, rba_buffer_t* b, rational_t *a, thvar_t v);

/**
* This creates a term from a polynomial while taking care not to
* introduce a polynomial from a constant.
*/
static
term_t mcarith_poly(term_table_t* terms, rba_buffer_t* b) {

if (b->nterms == 0) {
rational_t q;
q_init(&q);
return arith_constant(terms, &q);
}

if (b->nterms == 1) {
mono_t* m = b->mono + b->root;
if (m->prod == empty_pp) {
return arith_constant(terms, &m->coeff);
}
}

return arith_poly(terms, b);
}

static
term_t get_term(mcarith_solver_t* solver, thvar_t v) {
arith_vartable_t* table;
Expand Down Expand Up @@ -469,7 +492,7 @@ term_t get_term(mcarith_solver_t* solver, thvar_t v) {
rba_buffer_add_mono_mcarithvar(solver, &b, r, v);
}
}
t = arith_poly(terms, &b);
t = mcarith_poly(terms, &b);
}
break;
case AVARTAG_KIND_PPROD:
Expand Down Expand Up @@ -648,8 +671,6 @@ term_t get_atom(mcarith_solver_t* solver, int32_t atom_index) {
term_t v;
rational_t* bnd;
rba_buffer_t b;
term_t polyTerm;


mcarith_check_atom_size(solver);

Expand All @@ -671,26 +692,31 @@ term_t get_atom(mcarith_solver_t* solver, int32_t atom_index) {
// Assert v-b >= 0
case GE_ATM: {
if (q_is_zero(bnd)) {
polyTerm = get_term(solver, var_of_atom(ap));
term_t polyTerm;
polyTerm = get_term(solver, v);
result = arith_geq_atom(terms, polyTerm);
} else {
term_t polyTerm;
// Create buffer b = v - bnd
init_rba_buffer(&b, mcarith_solver_pprods(solver));
rba_buffer_add_mcarithvar(solver, &b, v);
rba_buffer_sub_const(&b, bnd);
// Create term and free buffer.
polyTerm = arith_poly(terms, &b);
polyTerm = mcarith_poly(terms, &b);
delete_rba_buffer(&b);
result = arith_geq_atom(terms, polyTerm);
}
result = arith_geq_atom(terms, polyTerm);
break;
}
// Assert v <= b by asserting b-v >= 0
case LE_ATM: {
term_t polyTerm;

// Create buffer b = bnd - v
init_rba_buffer(&b, mcarith_solver_pprods(solver));
rba_buffer_sub_mcarithvar(solver, &b, v);
rba_buffer_add_const(&b, bnd);
polyTerm = arith_poly(terms, &b);
polyTerm = mcarith_poly(terms, &b);
delete_rba_buffer(&b);
result = arith_geq_atom(terms, polyTerm);
break;
Expand All @@ -716,24 +742,30 @@ thvar_t mcarith_create_pprod(void *s, pprod_t *p, thvar_t *map) {
assert(pprod_degree(p) > 1);
assert(!pp_is_empty(p));
assert(!pp_is_var(p));

// Create theory variable
thvar_t v = simplex_create_var(&solver->simplex, false);
// Remap variables in powerproduct
mcarith_check_var_size(solver);
pp_buffer_t b;
init_pp_buffer(&b, p->len);
uint32_t n = p->len;
int32_t* vars = safe_malloc(sizeof(int32_t) * n);
uint32_t* exps = safe_malloc(sizeof(uint32_t) * n);

for (uint32_t i = 0; i < p->len; ++i) {
thvar_t mv = map[i];
assert(mv < v);
term_t t = get_term(solver, mv);
pp_buffer_set_varexp(&b, t, p->prod[i].exp);
vars[i] = get_term(solver, mv);
exps[i] = p->prod[i].exp;
}
pp_buffer_normalize(&b);
pp_buffer_set_varexps(&b, n, vars, exps);
free(vars);
free(exps);

// Create term
assert(solver->var_terms_size > v);
solver->var_terms[v] = pprod_term_from_buffer(mcarith_solver_terms(solver), &b);

term_t t = pprod_term_from_buffer(mcarith_solver_terms(solver), &b);
solver->var_terms[v] = t;
// Free buffer and return
delete_pp_buffer(&b);
return v;
Expand Down Expand Up @@ -839,6 +871,23 @@ void init_rba_buffer_from_poly(mcarith_solver_t* solver,
}
}

static
bool rba_buffer_get_const(rba_buffer_t* b, rational_t* r) {
if (b->nterms == 0) {
q_init(r);
return true;
} else if (b->nterms == 1) {
mono_t* m = b->mono + b->root;
if (m->prod == empty_pp) {
q_init(r);
q_set(r, &m->coeff);
return true;
}
}
return false;
}


/*
* Check for integer feasibility
*/
Expand All @@ -847,21 +896,18 @@ fcheck_code_t mcarith_final_check(void* s) {
mcarith_solver_t *solver;
term_table_t* terms;
uint32_t acnt;

term_t t;

solver = s;
terms = mcarith_solver_terms(solver);
assert(!solver->simplex.unsat_before_search);

yices_pp_t printer;
pp_area_t area;

area.width = 72;
area.height = 8;
area.offset = 2;
area.stretch = false;
area.truncate = true;
fcheck_code_t result;

init_default_yices_pp(&printer, stdout, &area);
result = simplex_final_check(&solver->simplex);
if (result == FCHECK_CONTINUE) {
return result;
}

mcarith_free_mcsat(solver);

Expand Down Expand Up @@ -892,18 +938,40 @@ fcheck_code_t mcarith_final_check(void* s) {
p = arith_eq_atom(terms, get_term(solver, a->def.var));
break;
case VAR_GE0:
p = arith_geq_atom(terms, get_term(solver, a->def.var));
t = get_term(solver, a->def.var);
assert(term_kind(terms, t) != ARITH_CONSTANT);
p = arith_geq_atom(terms, t);
break;
case POLY_EQ0:
case POLY_EQ0: {
rational_t br;

init_rba_buffer_from_poly(solver, &b, a->def.poly);
p = arith_eq_atom(terms, arith_poly(terms, &b));
if (rba_buffer_get_const(&b, &br)) {
p = bool2term(q_is_zero(&br));
q_clear(&br);
} else {
t = arith_poly(terms, &b);
assert(term_kind(terms, t) != ARITH_CONSTANT);
p = arith_eq_atom(terms, t);
}
delete_rba_buffer(&b);
break;
case POLY_GE0:
}
case POLY_GE0: {
rational_t br;

init_rba_buffer_from_poly(solver, &b, a->def.poly);
p = arith_geq_atom(terms, arith_poly(terms, &b));
if (rba_buffer_get_const(&b, &br)) {
p = bool2term(q_is_nonneg(&br));
q_clear(&br);
} else {
t = arith_poly(terms, &b);
assert(term_kind(terms, t) != ARITH_CONSTANT);
p = arith_geq_atom(terms, t);
}
delete_rba_buffer(&b);
break;
}
case ATOM_ASSERT:
p = get_atom(solver, a->def.atom);
break;
Expand All @@ -920,10 +988,9 @@ fcheck_code_t mcarith_final_check(void* s) {
literals[literal_count] = end_clause;

int32_t r = mcsat_assert_formulas(solver->mcsat, acnt, assertions);
fcheck_code_t result;
if (r == TRIVIALLY_UNSAT) {
record_theory_conflict(solver->simplex.core, literals);
mcarith_free_mcsat(solver);
mcarith_free_mcsat(solver);
result = FCHECK_CONTINUE;
} else if (r == CTX_NO_ERROR) {
result = FCHECK_SAT;
Expand Down Expand Up @@ -990,7 +1057,10 @@ void mcarith_increase_decision_level(mcarith_solver_t *solver) {
void mcarith_backtrack(mcarith_solver_t *solver, uint32_t back_level) {
mcarith_undo_record_t* r = mcarith_undo_backtrack(&solver->undo_stack, back_level);
mcarith_backtrack_assertions(solver, r->n_assertions);
uint32_t vc = solver->simplex.vtbl.nvars;
simplex_backtrack(&solver->simplex, back_level);
uint32_t vc0 = solver->simplex.vtbl.nvars;
assert(vc == vc0);
}

/*
Expand Down Expand Up @@ -1028,7 +1098,10 @@ void mcarith_pop(mcarith_solver_t *solver) {
* Reset to the empty solver
*/
void mcarith_reset(mcarith_solver_t *solver) {
uint32_t vc = solver->simplex.vtbl.nvars;
simplex_reset(&solver->simplex);
uint32_t vc0 = solver->simplex.vtbl.nvars;
assert(vc == vc0); // FIXME

mcarith_undo_reset(&solver->undo_stack);
mcarith_backtrack_assertions(solver, 0);
Expand Down Expand Up @@ -1134,6 +1207,27 @@ thvar_t mcarith_create_poly(void* s, polynomial_t *p, thvar_t *map) {
return simplex_create_poly(&solver->simplex, p, map);
}

typedef struct{
uint32_t bstack_top;
uint32_t matrix_nrows;
} simplex_assertion_count_t;

static simplex_assertion_count_t simplex_assertion_count(simplex_solver_t* simplex) {
simplex_assertion_count_t r;
r.bstack_top = simplex->bstack.top;
r.matrix_nrows = simplex->matrix.nrows;
return r;
}

/**
* Returns true if the simplex solver concluded the assertion just added did not
* need to get recorded as it was true or false from previous assertions.
*/
static bool simplex_handled_assertion(simplex_solver_t* simplex, simplex_assertion_count_t c) {
return simplex->unsat_before_search
|| (simplex->matrix.nrows == c.matrix_nrows && simplex->bstack.top == c.bstack_top);
}

/*
* Assert a top-level equality constraint (either x == 0 or x != 0)
* - tt indicates whether the constraint or its negation must be asserted
Expand All @@ -1145,7 +1239,11 @@ static void mcarith_assert_eq_axiom(void* s, thvar_t x, bool tt) {
mcarith_assertion_t* assert;

solver = s;
// Get number of assertions
simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_eq_axiom(&solver->simplex, x, tt);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

// Record assertion for sending to mcarith solver.
assert = alloc_top_assertion(solver);
Expand All @@ -1166,7 +1264,11 @@ static void mcarith_assert_ge_axiom(void* s, thvar_t x, bool tt){
mcarith_assertion_t* assert;

solver = s;
// Get number of assertions
simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_ge_axiom(&solver->simplex, x, tt);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert = alloc_top_assertion(solver);
assert->type = VAR_GE0;
Expand All @@ -1188,10 +1290,15 @@ static void mcarith_assert_poly_eq_axiom(void * s, polynomial_t *p, thvar_t *map
solver = s;
assert(p->nterms > 0);

simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_poly_eq_axiom(&solver->simplex, p, map, tt);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert = alloc_top_assertion(solver);
assert->type = POLY_EQ0;
assert(p->nterms > 0);
assert(p->nterms > 1 || p->mono[0].var != const_idx);
assert->def.poly = alloc_polynomial_from_map(p, map, solver->simplex.vtbl.nvars);
assert->tt = tt;
assert->lit = null_literal;
Expand All @@ -1209,10 +1316,15 @@ static void mcarith_assert_poly_ge_axiom(void *s, polynomial_t *p, thvar_t *map,
solver = s;
assert(p->nterms > 0);

simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_poly_ge_axiom(&solver->simplex, p, map, tt);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert = alloc_top_assertion(solver);
assert->type = POLY_GE0;
assert(p->nterms > 0);
assert(p->nterms > 1 || p->mono[0].var != const_idx);
assert->def.poly = alloc_polynomial_from_map(p, map, solver->simplex.vtbl.nvars);
assert->tt = tt;
assert->lit = null_literal;
Expand All @@ -1228,7 +1340,10 @@ void mcarith_assert_vareq_axiom(void* s, thvar_t x, thvar_t y, bool tt) {
mcarith_assertion_t* assert;

solver = s;
simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_vareq_axiom(&solver->simplex, x, y, tt);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert = alloc_top_assertion(solver);
assert->type = VAR_EQ;
Expand All @@ -1245,7 +1360,12 @@ static
void mcarith_assert_cond_vareq_axiom(void* s, literal_t c, thvar_t x, thvar_t y) {
mcarith_solver_t *solver;
solver = s;
simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_cond_vareq_axiom(&solver->simplex, c, x, y);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert(false);
}

/*
Expand All @@ -1254,7 +1374,12 @@ void mcarith_assert_cond_vareq_axiom(void* s, literal_t c, thvar_t x, thvar_t y)
static
void mcarith_assert_clause_vareq_axiom(void* s, uint32_t n, literal_t *c, thvar_t x, thvar_t y) {
mcarith_solver_t *solver = s;
simplex_assertion_count_t cinit = simplex_assertion_count(&solver->simplex);
simplex_assert_clause_vareq_axiom(&solver->simplex, n, c, x, y);
if (simplex_handled_assertion(&solver->simplex, cinit))
return;

assert(false);
}

/*
Expand Down

0 comments on commit c017ba4

Please sign in to comment.