@@ -107,12 +107,12 @@ template <class NodeNames> struct SmtPrintVisitor {
107
107
108
108
std::string slice (Node, Node a, int , int offset, int out_width)
109
109
{
110
- return format (" ((_ extract %2 %1) %0)" , np (a), offset, offset + out_width - 1 );
110
+ return format (" (extract %2 %1 %0)" , np (a), offset, offset + out_width - 1 );
111
111
}
112
112
113
- std::string zero_extend (Node, Node a, int , int out_width) { return format (" ((_ zero_extend %1) %0 )" , np (a), out_width - a. width () ); }
113
+ std::string zero_extend (Node, Node a, int , int out_width) { return format (" (zero-extend %0 (bitvector %1))" , np (a), out_width); }
114
114
115
- std::string sign_extend (Node, Node a, int , int out_width) { return format (" ((_ sign_extend %1) %0 )" , np (a), out_width - a. width () ); }
115
+ std::string sign_extend (Node, Node a, int , int out_width) { return format (" (sign-extend %0 (bitvector %1))" , np (a), out_width); }
116
116
117
117
std::string concat (Node, Node a, int , Node b, int ) { return format (" (concat %0 %1)" , np (a), np (b)); }
118
118
@@ -136,137 +136,64 @@ template <class NodeNames> struct SmtPrintVisitor {
136
136
137
137
std::string unary_minus (Node, Node a, int ) { return format (" (bvneg %0)" , np (a)); }
138
138
139
- std::string reduce_and (Node, Node a, int ) {
140
- std::stringstream ss;
141
- // We use ite to set the result to bit vector, to ensure appropriate type
142
- ss << " (ite (= " << np (a) << " #b" << std::string (a.width (), ' 1' ) << " ) #b1 #b0)" ;
143
- return ss.str ();
144
- }
139
+ std::string reduce_and (Node, Node a, int ) { return format (" (apply bvand (bitvector->bits %0))" , np (a)); }
145
140
146
- std::string reduce_or (Node, Node a, int )
147
- {
148
- std::stringstream ss;
149
- // We use ite to set the result to bit vector, to ensure appropriate type
150
- ss << " (ite (= " << np (a) << " #b" << std::string (a.width (), ' 0' ) << " ) #b0 #b1)" ;
151
- return ss.str ();
152
- }
141
+ std::string reduce_or (Node, Node a, int ) { return format (" (apply bvor (bitvector->bits %0))" , np (a)); }
153
142
154
- std::string reduce_xor (Node, Node a, int ) {
155
- std::stringstream ss;
156
- ss << " (bvxor " ;
157
- for (int i = 0 ; i < a.width (); ++i) {
158
- if (i > 0 ) ss << " " ;
159
- ss << " ((_ extract " << i << " " << i << " ) " << np (a) << " )" ;
160
- }
161
- ss << " )" ;
162
- return ss.str ();
163
- }
143
+ std::string reduce_xor (Node, Node a, int ) { return format (" (apply bvxor (bitvector->bits %0))" , np (a)); }
164
144
165
145
std::string equal (Node, Node a, Node b, int ) {
166
- return format (" (ite (= %0 %1) #b1 #b0 )" , np (a), np (b));
146
+ return format (" (bool->bitvector (bveq %0 %1))" , np (a), np (b));
167
147
}
168
148
169
149
std::string not_equal (Node, Node a, Node b, int ) {
170
- return format (" (ite (distinct %0 %1) #b1 #b0 )" , np (a), np (b));
150
+ return format (" (bool->bitvector (not (bveq %0 %1)) )" , np (a), np (b));
171
151
}
172
152
173
153
std::string signed_greater_than (Node, Node a, Node b, int ) {
174
- return format (" (ite (bvsgt %0 %1) #b1 #b0 )" , np (a), np (b));
154
+ return format (" (bool->bitvector (bvsgt %0 %1))" , np (a), np (b));
175
155
}
176
156
177
157
std::string signed_greater_equal (Node, Node a, Node b, int ) {
178
- return format (" (ite (bvsge %0 %1) #b1 #b0 )" , np (a), np (b));
158
+ return format (" (bool->bitvector (bvsge %0 %1))" , np (a), np (b));
179
159
}
180
160
181
161
std::string unsigned_greater_than (Node, Node a, Node b, int ) {
182
- return format (" (ite (bvugt %0 %1) #b1 #b0 )" , np (a), np (b));
162
+ return format (" (bool->bitvector (bvugt %0 %1))" , np (a), np (b));
183
163
}
184
164
185
165
std::string unsigned_greater_equal (Node, Node a, Node b, int ) {
186
- return format (" (ite (bvuge %0 %1) #b1 #b0)" , np (a), np (b));
187
- }
188
-
189
- std::string logical_shift_left (Node, Node a, Node b, int , int ) {
190
- // Get the bit-widths of a and b
191
- int bit_width_a = a.width ();
192
- int bit_width_b = b.width ();
193
-
194
- // Extend b to match the bit-width of a if necessary
195
- std::ostringstream oss;
196
- if (bit_width_a > bit_width_b) {
197
- oss << " ((_ zero_extend " << (bit_width_a - bit_width_b) << " ) " << np (b) << " )" ;
198
- } else {
199
- oss << np (b); // No extension needed if b's width is already sufficient
200
- }
201
- std::string b_extended = oss.str ();
202
-
203
- // Format the bvshl operation with the extended b
204
- oss.str (" " ); // Clear the stringstream
205
- oss << " (bvshl " << np (a) << " " << b_extended << " )" ;
206
- return oss.str ();
207
- }
208
-
209
- std::string logical_shift_right (Node, Node a, Node b, int , int ) {
210
- // Get the bit-widths of a and b
211
- int bit_width_a = a.width ();
212
- int bit_width_b = b.width ();
213
-
214
- // Extend b to match the bit-width of a if necessary
215
- std::ostringstream oss;
216
- if (bit_width_a > bit_width_b) {
217
- oss << " ((_ zero_extend " << (bit_width_a - bit_width_b) << " ) " << np (b) << " )" ;
218
- } else {
219
- oss << np (b); // No extension needed if b's width is already sufficient
220
- }
221
- std::string b_extended = oss.str ();
222
-
223
- // Format the bvlshr operation with the extended b
224
- oss.str (" " ); // Clear the stringstream
225
- oss << " (bvlshr " << np (a) << " " << b_extended << " )" ;
226
- return oss.str ();
166
+ return format (" (bool->bitvector (bvuge %0 %1))" , np (a), np (b));
227
167
}
228
168
229
- std::string arithmetic_shift_right (Node, Node a, Node b, int , int ) {
230
- // Get the bit-widths of a and b
231
- int bit_width_a = a.width ();
232
- int bit_width_b = b.width ();
169
+ std::string logical_shift_left (Node, Node a, Node b, int , int ) { return format (" (bvshl %0 %1)" , np (a), np (b)); }
233
170
234
- // Extend b to match the bit-width of a if necessary
235
- std::ostringstream oss;
236
- if (bit_width_a > bit_width_b) {
237
- oss << " ((_ zero_extend " << (bit_width_a - bit_width_b) << " ) " << np (b) << " )" ;
238
- } else {
239
- oss << np (b); // No extension needed if b's width is already sufficient
240
- }
241
- std::string b_extended = oss.str ();
171
+ std::string logical_shift_right (Node, Node a, Node b, int , int ) { return format (" (bvlshr %0 %1)" , np (a), np (b)); }
242
172
243
- // Format the bvashr operation with the extended b
244
- oss.str (" " ); // Clear the stringstream
245
- oss << " (bvashr " << np (a) << " " << b_extended << " )" ;
246
- return oss.str ();
247
- }
173
+ std::string arithmetic_shift_right (Node, Node a, Node b, int , int ) { return format (" (bvashr %0 %1)" , np (a), np (b)); }
248
174
249
- std::string mux (Node, Node a, Node b, Node s, int ) {
250
- return format (" (ite (= %2 #b1) %0 %1)" , np (a), np (b), np (s));
251
- }
175
+ std::string mux (Node, Node a, Node b, Node s, int ) { return format (" (if %2 %0 %1)" , np (a), np (b), np (s)); }
252
176
177
+ // How does pmux?
253
178
std::string pmux (Node, Node a, Node b, Node s, int , int )
254
179
{
255
180
// Assume s is a bit vector, combine a and b based on the selection bits
256
181
return format (" (pmux %0 %1 %2)" , np (a), np (b), np (s));
257
182
}
258
183
259
- std::string constant (Node, RTLIL::Const value) { return format (" #b%0" , value.as_string ()); }
184
+ std::string constant (Node, RTLIL::Const value) { return format (" (bv #b%0 %1) " , value.as_string (), value. size ()); }
260
185
261
186
std::string input (Node, IdString name) { return format (" %0" , scope[name]); }
262
187
188
+ // How does state?
263
189
std::string state (Node, IdString name) { return format (" (%0 current_state)" , scope[name]); }
264
190
191
+ // How does memory?
265
192
std::string memory_read (Node, Node mem, Node addr, int , int ) { return format (" (select %0 %1)" , np (mem), np (addr)); }
266
193
267
194
std::string memory_write (Node, Node mem, Node addr, Node data, int , int ) { return format (" (store %0 %1 %2)" , np (mem), np (addr), np (data)); }
268
195
269
- std::string undriven (Node, int width) { return format (" #b%0 " , std::string ( width, ' 0 ' ) ); }
196
+ std::string undriven (Node, int width) { return format (" (bv 0 %0) " , width); }
270
197
};
271
198
272
199
struct SmtModule {
@@ -281,23 +208,46 @@ struct SmtModule {
281
208
const bool stateful = ir.state ().size () != 0 ;
282
209
SmtWriter writer (out);
283
210
284
- writer.print (" (declare-fun %s () Bool)\n\n " , name.c_str ());
211
+ // Rosette lang header
212
+ writer.print (" #lang rosette\n\n " );
213
+ std::string end_part = " \n " ;
214
+ std::string indent = " \t " ;
215
+
216
+ // Not sure if this is actually necessary or not, so make it optional I guess?
217
+ bool guarded = true ;
285
218
286
- writer.print (" (declare-datatypes () ((Inputs (mk_inputs" );
219
+ // ???
220
+ // writer.print("(declare-fun %s () Bool)\n\n", name.c_str());
221
+
222
+ // Inputs
223
+ std::stringstream input_list;
224
+ std::stringstream input_values;
287
225
for (const auto &input : ir.inputs ()) {
288
- std::string input_name = scope[input.first ];
289
- writer.print (" (%s (_ BitVec %d))" , input_name.c_str (), input.second .width ());
226
+ auto input_name = scope[input.first ];
227
+ input_list << input_name << " " ;
228
+ if (guarded) {
229
+ input_values << end_part << indent << indent << indent;
230
+ auto width = input.second .width ();
231
+ input_values << " (extract " << width-1 << " 0 (concat (bv 0 " << width << " ) " << input_name << " ))" ;
232
+ }
290
233
}
291
- writer.print (" ))))\n\n " );
234
+ writer.print (" (struct Inputs (%s)" , input_list.str ().c_str ());
235
+ if (guarded) {
236
+ writer.print (" %s%s#:guard (lambda (%sname)%s" , end_part.c_str (), indent.c_str (), input_list.str ().c_str (), end_part.c_str ());
237
+ writer.print (" %s%s(values%s))" , indent.c_str (), indent.c_str (), input_values.str ().c_str ());
238
+ }
239
+ writer.print (" )\n " );
292
240
293
- writer.print (" (declare-datatypes () ((Outputs (mk_outputs" );
241
+ // Outputs
242
+ writer.print (" (struct Outputs (" );
294
243
for (const auto &output : ir.outputs ()) {
295
- std::string output_name = scope[output.first ];
296
- writer.print (" ( %s (_ BitVec %d)) " , output_name.c_str (), output. second . width ());
244
+ auto output_name = scope[output.first ];
245
+ writer.print (" %s " , output_name.c_str ());
297
246
}
298
- writer.print (" )))) \n " );
247
+ writer.print (" ))\n " );
299
248
300
249
if (stateful) {
250
+ // ?
301
251
writer.print (" (declare-datatypes () ((State (mk_state" );
302
252
for (const auto &state : ir.state ()) {
303
253
std::string state_name = scope[state.first ];
@@ -308,21 +258,21 @@ struct SmtModule {
308
258
writer.print (" (declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n " );
309
259
}
310
260
311
- if (stateful)
312
- writer.print (" (define-fun %s_step ((current_state State) (inputs Inputs)) Pair" , name.c_str ());
313
- else
314
- writer.print (" (define-fun %s_step ((inputs Inputs)) Outputs" , name.c_str ());
261
+ // Function start
262
+ writer.print (" (define (%s_step inputs)%s" , name.c_str (), end_part.c_str ());
315
263
316
- writer.print (" (let (" );
264
+ // Bind inputs
265
+ writer.print (" %s(let (" , indent.c_str ());
317
266
for (const auto &input : ir.inputs ()) {
318
- std::string input_name = scope[input.first ];
319
- writer.print (" ( %s (%s inputs)) " , input_name.c_str (), input_name.c_str ());
267
+ auto input_name = scope[input.first ];
268
+ writer.print (" [ %s (Inputs- %s inputs)] " , input_name.c_str (), input_name.c_str ());
320
269
}
321
- writer.print (" )" );
270
+ writer.print (" )" );
322
271
323
272
auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name ()]; };
324
273
SmtPrintVisitor<decltype (node_to_string)> visitor (node_to_string, scope);
325
274
275
+ // Bind operators
326
276
for (auto it = ir.begin (); it != ir.end (); ++it) {
327
277
const FunctionalIR::Node &node = *it;
328
278
@@ -332,10 +282,12 @@ struct SmtModule {
332
282
std::string node_name = scope[node.name ()];
333
283
std::string node_expr = node.visit (visitor);
334
284
335
- writer.print (" (let ( ( %s %s) )" , node_name.c_str (), node_expr.c_str ());
285
+ writer.print (" (let ([ %s %s] )" , node_name.c_str (), node_expr.c_str ());
336
286
}
337
287
288
+ // Bind next state
338
289
if (stateful) {
290
+ // ?
339
291
writer.print (" (let ( (next_state (mk_state " );
340
292
for (const auto &state : ir.state ()) {
341
293
std::string state_name = scope[state.first ];
@@ -345,7 +297,9 @@ struct SmtModule {
345
297
writer.print (" )))" );
346
298
}
347
299
300
+ // Bind outputs
348
301
if (stateful) {
302
+ // ?
349
303
writer.print (" (let ( (outputs (mk_outputs " );
350
304
for (const auto &output : ir.outputs ()) {
351
305
std::string output_name = scope[output.first ];
@@ -354,27 +308,26 @@ struct SmtModule {
354
308
writer.print (" )))" );
355
309
356
310
writer.print (" (mk-pair outputs next_state)" );
311
+ writer.print (" )" ); // Closing outputs let statement
312
+ writer.print (" )" ); // Closing next_state let statement
357
313
}
358
314
else {
359
- writer.print (" (mk_outputs " );
315
+ writer.print (" (Outputs " );
360
316
for (const auto &output : ir.outputs ()) {
361
- std::string output_name = scope[output.first ];
362
- writer.print (" %s " , output_name.c_str ());
317
+ auto output_name = scope[output.first ];
318
+ writer.print (" %s " , output_name.c_str ());
363
319
}
364
- writer.print (" )" ); // Closing mk_outputs
365
- }
366
- if (stateful) {
367
- writer.print (" )" ); // Closing outputs let statement
368
- writer.print (" )" ); // Closing next_state let statement
320
+ writer.print (" )" ); // Closing outputs
369
321
}
322
+
370
323
// Close the nested lets
371
- for (size_t i = 0 ; i < ir.size () - ir. inputs () .size (); ++i) {
372
- writer.print (" )" ); // Closing each node
324
+ for (auto i = ir.inputs (). size (); i < ir .size (); ++i) {
325
+ writer.print (" )" ); // Closing each node
373
326
}
374
327
if (ir.size () == ir.inputs ().size ())
375
- writer.print (" )" ); // Corner case
328
+ writer.print (" )" ); // Corner case
376
329
377
- writer.print (" )" ); // Closing inputs let statement
330
+ writer.print (" )" ); // Closing inputs let statement
378
331
writer.print (" )\n " ); // Closing step function
379
332
}
380
333
};
0 commit comments