Skip to content

Commit 5468bdf

Browse files
smtr: Make Rosette compatible
Convert most of the operators, except pmux and memory. Convert formatting for non-stateful modules.
1 parent 42e54c3 commit 5468bdf

File tree

1 file changed

+77
-124
lines changed

1 file changed

+77
-124
lines changed

backends/functional/smtlib_rosette.cc

+77-124
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ template <class NodeNames> struct SmtPrintVisitor {
107107

108108
std::string slice(Node, Node a, int, int offset, int out_width)
109109
{
110-
return format("((_ extract %2 %1) %0)", np(a), offset, offset + out_width - 1);
110+
return format("(extract %2 %1 %0)", np(a), offset, offset + out_width - 1);
111111
}
112112

113-
std::string zero_extend(Node, Node a, int, int out_width) { return format("((_ zero_extend %1) %0)", np(a), out_width - a.width()); }
113+
std::string zero_extend(Node, Node a, int, int out_width) { return format("(zero-extend %0 (bitvector %1))", np(a), out_width); }
114114

115-
std::string sign_extend(Node, Node a, int, int out_width) { return format("((_ sign_extend %1) %0)", np(a), out_width - a.width()); }
115+
std::string sign_extend(Node, Node a, int, int out_width) { return format("(sign-extend %0 (bitvector %1))", np(a), out_width); }
116116

117117
std::string concat(Node, Node a, int, Node b, int) { return format("(concat %0 %1)", np(a), np(b)); }
118118

@@ -136,137 +136,64 @@ template <class NodeNames> struct SmtPrintVisitor {
136136

137137
std::string unary_minus(Node, Node a, int) { return format("(bvneg %0)", np(a)); }
138138

139-
std::string reduce_and(Node, Node a, int) {
140-
std::stringstream ss;
141-
// We use ite to set the result to bit vector, to ensure appropriate type
142-
ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '1') << ") #b1 #b0)";
143-
return ss.str();
144-
}
139+
std::string reduce_and(Node, Node a, int) { return format("(apply bvand (bitvector->bits %0))", np(a)); }
145140

146-
std::string reduce_or(Node, Node a, int)
147-
{
148-
std::stringstream ss;
149-
// We use ite to set the result to bit vector, to ensure appropriate type
150-
ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '0') << ") #b0 #b1)";
151-
return ss.str();
152-
}
141+
std::string reduce_or(Node, Node a, int) { return format("(apply bvor (bitvector->bits %0))", np(a)); }
153142

154-
std::string reduce_xor(Node, Node a, int) {
155-
std::stringstream ss;
156-
ss << "(bvxor ";
157-
for (int i = 0; i < a.width(); ++i) {
158-
if (i > 0) ss << " ";
159-
ss << "((_ extract " << i << " " << i << ") " << np(a) << ")";
160-
}
161-
ss << ")";
162-
return ss.str();
163-
}
143+
std::string reduce_xor(Node, Node a, int) { return format("(apply bvxor (bitvector->bits %0))", np(a)); }
164144

165145
std::string equal(Node, Node a, Node b, int) {
166-
return format("(ite (= %0 %1) #b1 #b0)", np(a), np(b));
146+
return format("(bool->bitvector (bveq %0 %1))", np(a), np(b));
167147
}
168148

169149
std::string not_equal(Node, Node a, Node b, int) {
170-
return format("(ite (distinct %0 %1) #b1 #b0)", np(a), np(b));
150+
return format("(bool->bitvector (not (bveq %0 %1)))", np(a), np(b));
171151
}
172152

173153
std::string signed_greater_than(Node, Node a, Node b, int) {
174-
return format("(ite (bvsgt %0 %1) #b1 #b0)", np(a), np(b));
154+
return format("(bool->bitvector (bvsgt %0 %1))", np(a), np(b));
175155
}
176156

177157
std::string signed_greater_equal(Node, Node a, Node b, int) {
178-
return format("(ite (bvsge %0 %1) #b1 #b0)", np(a), np(b));
158+
return format("(bool->bitvector (bvsge %0 %1))", np(a), np(b));
179159
}
180160

181161
std::string unsigned_greater_than(Node, Node a, Node b, int) {
182-
return format("(ite (bvugt %0 %1) #b1 #b0)", np(a), np(b));
162+
return format("(bool->bitvector (bvugt %0 %1))", np(a), np(b));
183163
}
184164

185165
std::string unsigned_greater_equal(Node, Node a, Node b, int) {
186-
return format("(ite (bvuge %0 %1) #b1 #b0)", np(a), np(b));
187-
}
188-
189-
std::string logical_shift_left(Node, Node a, Node b, int, int) {
190-
// Get the bit-widths of a and b
191-
int bit_width_a = a.width();
192-
int bit_width_b = b.width();
193-
194-
// Extend b to match the bit-width of a if necessary
195-
std::ostringstream oss;
196-
if (bit_width_a > bit_width_b) {
197-
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
198-
} else {
199-
oss << np(b); // No extension needed if b's width is already sufficient
200-
}
201-
std::string b_extended = oss.str();
202-
203-
// Format the bvshl operation with the extended b
204-
oss.str(""); // Clear the stringstream
205-
oss << "(bvshl " << np(a) << " " << b_extended << ")";
206-
return oss.str();
207-
}
208-
209-
std::string logical_shift_right(Node, Node a, Node b, int, int) {
210-
// Get the bit-widths of a and b
211-
int bit_width_a = a.width();
212-
int bit_width_b = b.width();
213-
214-
// Extend b to match the bit-width of a if necessary
215-
std::ostringstream oss;
216-
if (bit_width_a > bit_width_b) {
217-
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
218-
} else {
219-
oss << np(b); // No extension needed if b's width is already sufficient
220-
}
221-
std::string b_extended = oss.str();
222-
223-
// Format the bvlshr operation with the extended b
224-
oss.str(""); // Clear the stringstream
225-
oss << "(bvlshr " << np(a) << " " << b_extended << ")";
226-
return oss.str();
166+
return format("(bool->bitvector (bvuge %0 %1))", np(a), np(b));
227167
}
228168

229-
std::string arithmetic_shift_right(Node, Node a, Node b, int, int) {
230-
// Get the bit-widths of a and b
231-
int bit_width_a = a.width();
232-
int bit_width_b = b.width();
169+
std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("(bvshl %0 %1)", np(a), np(b)); }
233170

234-
// Extend b to match the bit-width of a if necessary
235-
std::ostringstream oss;
236-
if (bit_width_a > bit_width_b) {
237-
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
238-
} else {
239-
oss << np(b); // No extension needed if b's width is already sufficient
240-
}
241-
std::string b_extended = oss.str();
171+
std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("(bvlshr %0 %1)", np(a), np(b)); }
242172

243-
// Format the bvashr operation with the extended b
244-
oss.str(""); // Clear the stringstream
245-
oss << "(bvashr " << np(a) << " " << b_extended << ")";
246-
return oss.str();
247-
}
173+
std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("(bvashr %0 %1)", np(a), np(b)); }
248174

249-
std::string mux(Node, Node a, Node b, Node s, int) {
250-
return format("(ite (= %2 #b1) %0 %1)", np(a), np(b), np(s));
251-
}
175+
std::string mux(Node, Node a, Node b, Node s, int) { return format("(if %2 %0 %1)", np(a), np(b), np(s)); }
252176

177+
// How does pmux?
253178
std::string pmux(Node, Node a, Node b, Node s, int, int)
254179
{
255180
// Assume s is a bit vector, combine a and b based on the selection bits
256181
return format("(pmux %0 %1 %2)", np(a), np(b), np(s));
257182
}
258183

259-
std::string constant(Node, RTLIL::Const value) { return format("#b%0", value.as_string()); }
184+
std::string constant(Node, RTLIL::Const value) { return format("(bv #b%0 %1)", value.as_string(), value.size()); }
260185

261186
std::string input(Node, IdString name) { return format("%0", scope[name]); }
262187

188+
// How does state?
263189
std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); }
264190

191+
// How does memory?
265192
std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); }
266193

267194
std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("(store %0 %1 %2)", np(mem), np(addr), np(data)); }
268195

269-
std::string undriven(Node, int width) { return format("#b%0", std::string(width, '0')); }
196+
std::string undriven(Node, int width) { return format("(bv 0 %0)", width); }
270197
};
271198

272199
struct SmtModule {
@@ -281,23 +208,46 @@ struct SmtModule {
281208
const bool stateful = ir.state().size() != 0;
282209
SmtWriter writer(out);
283210

284-
writer.print("(declare-fun %s () Bool)\n\n", name.c_str());
211+
// Rosette lang header
212+
writer.print("#lang rosette\n\n");
213+
std::string end_part = "\n";
214+
std::string indent = "\t";
215+
216+
// Not sure if this is actually necessary or not, so make it optional I guess?
217+
bool guarded = true;
285218

286-
writer.print("(declare-datatypes () ((Inputs (mk_inputs");
219+
// ???
220+
// writer.print("(declare-fun %s () Bool)\n\n", name.c_str());
221+
222+
// Inputs
223+
std::stringstream input_list;
224+
std::stringstream input_values;
287225
for (const auto &input : ir.inputs()) {
288-
std::string input_name = scope[input.first];
289-
writer.print(" (%s (_ BitVec %d))", input_name.c_str(), input.second.width());
226+
auto input_name = scope[input.first];
227+
input_list << input_name << " ";
228+
if (guarded) {
229+
input_values << end_part << indent << indent << indent;
230+
auto width = input.second.width();
231+
input_values << "(extract " << width-1 << " 0 (concat (bv 0 " << width << ") " << input_name << "))";
232+
}
290233
}
291-
writer.print("))))\n\n");
234+
writer.print("(struct Inputs (%s)", input_list.str().c_str());
235+
if (guarded) {
236+
writer.print("%s%s#:guard (lambda (%sname)%s", end_part.c_str(), indent.c_str(), input_list.str().c_str(), end_part.c_str());
237+
writer.print("%s%s(values%s))", indent.c_str(), indent.c_str(), input_values.str().c_str());
238+
}
239+
writer.print(")\n");
292240

293-
writer.print("(declare-datatypes () ((Outputs (mk_outputs");
241+
// Outputs
242+
writer.print("(struct Outputs (");
294243
for (const auto &output : ir.outputs()) {
295-
std::string output_name = scope[output.first];
296-
writer.print(" (%s (_ BitVec %d))", output_name.c_str(), output.second.width());
244+
auto output_name = scope[output.first];
245+
writer.print("%s ", output_name.c_str());
297246
}
298-
writer.print("))))\n");
247+
writer.print("))\n");
299248

300249
if (stateful) {
250+
// ?
301251
writer.print("(declare-datatypes () ((State (mk_state");
302252
for (const auto &state : ir.state()) {
303253
std::string state_name = scope[state.first];
@@ -308,21 +258,21 @@ struct SmtModule {
308258
writer.print("(declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n");
309259
}
310260

311-
if (stateful)
312-
writer.print("(define-fun %s_step ((current_state State) (inputs Inputs)) Pair", name.c_str());
313-
else
314-
writer.print("(define-fun %s_step ((inputs Inputs)) Outputs", name.c_str());
261+
// Function start
262+
writer.print("(define (%s_step inputs)%s", name.c_str(), end_part.c_str());
315263

316-
writer.print(" (let (");
264+
// Bind inputs
265+
writer.print("%s(let (", indent.c_str());
317266
for (const auto &input : ir.inputs()) {
318-
std::string input_name = scope[input.first];
319-
writer.print(" (%s (%s inputs))", input_name.c_str(), input_name.c_str());
267+
auto input_name = scope[input.first];
268+
writer.print("[%s (Inputs-%s inputs)] ", input_name.c_str(), input_name.c_str());
320269
}
321-
writer.print(" )");
270+
writer.print(")");
322271

323272
auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name()]; };
324273
SmtPrintVisitor<decltype(node_to_string)> visitor(node_to_string, scope);
325274

275+
// Bind operators
326276
for (auto it = ir.begin(); it != ir.end(); ++it) {
327277
const FunctionalIR::Node &node = *it;
328278

@@ -332,10 +282,12 @@ struct SmtModule {
332282
std::string node_name = scope[node.name()];
333283
std::string node_expr = node.visit(visitor);
334284

335-
writer.print(" (let ( (%s %s))", node_name.c_str(), node_expr.c_str());
285+
writer.print(" (let ([%s %s])", node_name.c_str(), node_expr.c_str());
336286
}
337287

288+
// Bind next state
338289
if (stateful) {
290+
// ?
339291
writer.print(" (let ( (next_state (mk_state ");
340292
for (const auto &state : ir.state()) {
341293
std::string state_name = scope[state.first];
@@ -345,7 +297,9 @@ struct SmtModule {
345297
writer.print(" )))");
346298
}
347299

300+
// Bind outputs
348301
if (stateful) {
302+
// ?
349303
writer.print(" (let ( (outputs (mk_outputs ");
350304
for (const auto &output : ir.outputs()) {
351305
std::string output_name = scope[output.first];
@@ -354,27 +308,26 @@ struct SmtModule {
354308
writer.print(" )))");
355309

356310
writer.print("(mk-pair outputs next_state)");
311+
writer.print(" )"); // Closing outputs let statement
312+
writer.print(" )"); // Closing next_state let statement
357313
}
358314
else {
359-
writer.print(" (mk_outputs ");
315+
writer.print(" (Outputs ");
360316
for (const auto &output : ir.outputs()) {
361-
std::string output_name = scope[output.first];
362-
writer.print(" %s", output_name.c_str());
317+
auto output_name = scope[output.first];
318+
writer.print("%s ", output_name.c_str());
363319
}
364-
writer.print(" )"); // Closing mk_outputs
365-
}
366-
if (stateful) {
367-
writer.print(" )"); // Closing outputs let statement
368-
writer.print(" )"); // Closing next_state let statement
320+
writer.print(")"); // Closing outputs
369321
}
322+
370323
// Close the nested lets
371-
for (size_t i = 0; i < ir.size() - ir.inputs().size(); ++i) {
372-
writer.print(" )"); // Closing each node
324+
for (auto i = ir.inputs().size(); i < ir.size(); ++i) {
325+
writer.print(")"); // Closing each node
373326
}
374327
if (ir.size() == ir.inputs().size())
375-
writer.print(" )"); // Corner case
328+
writer.print(")"); // Corner case
376329

377-
writer.print(" )"); // Closing inputs let statement
330+
writer.print(")"); // Closing inputs let statement
378331
writer.print(")\n"); // Closing step function
379332
}
380333
};

0 commit comments

Comments
 (0)