From 09026cc9d540bee27e10f00eec1266dd3c9154a2 Mon Sep 17 00:00:00 2001
From: Dag Lem <dag@nimrod.no>
Date: Fri, 8 Dec 2023 20:47:43 +0100
Subject: [PATCH] Assign from rvalue via temporary register in nowrshmsk CASE

Avoid repeating complex rvalue expressions for each condition.
---
 frontends/ast/ast.cc      | 20 ++++++++++++++++++++
 frontends/ast/ast.h       |  3 +++
 frontends/ast/simplify.cc | 22 +++++++++++++++++++---
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/frontends/ast/ast.cc b/frontends/ast/ast.cc
index 5335a3992d6..9fceda5a08e 100644
--- a/frontends/ast/ast.cc
+++ b/frontends/ast/ast.cc
@@ -841,6 +841,26 @@ AstNode *AstNode::mkconst_str(const std::string &str)
 	return node;
 }
 
+// create a temporary register
+AstNode *AstNode::mktemp_logic(const std::string &name, AstNode *mod, bool nosync, int range_left, int range_right, bool is_signed)
+{
+	AstNode *wire = new AstNode(AST_WIRE, new AstNode(AST_RANGE, mkconst_int(range_left, true), mkconst_int(range_right, true)));
+	wire->str = stringf("%s%s:%d$%d", name.c_str(), RTLIL::encode_filename(filename).c_str(), location.first_line, autoidx++);
+	if (nosync)
+		wire->set_attribute(ID::nosync, AstNode::mkconst_int(1, false));
+	wire->is_signed = is_signed;
+	wire->is_logic = true;
+	mod->children.push_back(wire);
+	while (wire->simplify(true, 1, -1, false)) { }
+
+	AstNode *ident = new AstNode(AST_IDENTIFIER);
+	ident->str = wire->str;
+	ident->id2ast = wire;
+	ident->was_checked = true;
+
+	return ident;
+}
+
 bool AstNode::bits_only_01() const
 {
 	for (auto bit : bits)
diff --git a/frontends/ast/ast.h b/frontends/ast/ast.h
index f789e930b3e..97903d0a046 100644
--- a/frontends/ast/ast.h
+++ b/frontends/ast/ast.h
@@ -321,6 +321,9 @@ namespace AST
 		static AstNode *mkconst_str(const std::vector<RTLIL::State> &v);
 		static AstNode *mkconst_str(const std::string &str);
 
+		// helper function to create an AST node for a temporary register
+		AstNode *mktemp_logic(const std::string &name, AstNode *mod, bool nosync, int range_left, int range_right, bool is_signed);
+
 		// helper function for creating sign-extended const objects
 		RTLIL::Const bitsAsConst(int width, bool is_signed);
 		RTLIL::Const bitsAsConst(int width = -1);
diff --git a/frontends/ast/simplify.cc b/frontends/ast/simplify.cc
index 6cd92670201..f313e3360b8 100644
--- a/frontends/ast/simplify.cc
+++ b/frontends/ast/simplify.cc
@@ -2930,8 +2930,24 @@ bool AstNode::simplify(bool const_fold, int stage, int width_hint, bool sign_hin
 			long long max_offset = (1ll << (max_width - case_sign_hint)) - 1;
 			long long min_offset = case_sign_hint ? -(1ll << (max_width - 1)) : 0;
 
+			AstNode *caseNode = new AstNode(AST_CASE, shift_expr);
+			AstNode *rvalue;
+			if (children[1]->type == AST_CONSTANT || (children[1]->type == AST_IDENTIFIER && children[1]->id2ast->type == AST_WIRE)) {
+				rvalue = children[1];
+				newNode = caseNode;
+			} else {
+				// Temporary register holding the result of the (possibly complex) rvalue expression,
+				// avoiding repetition in each AST_COND below.
+				int rvalue_width;
+				bool rvalue_sign;
+				children[1]->detectSignWidth(rvalue_width, rvalue_sign);
+				rvalue = mktemp_logic("$bitselwrite$rvalue$", current_ast_mod, true, rvalue_width - 1, 0, rvalue_sign);
+				newNode = new AstNode(AST_BLOCK,
+						      new AstNode(AST_ASSIGN_EQ, rvalue, children[1]->clone()),
+						      caseNode);
+                        }
+
 			did_something = true;
-			newNode = new AstNode(AST_CASE, shift_expr);
 			for (int i = 1 - result_width; i < wire_width; i++) {
 				// Out of range indexes are handled in genrtlil.cc
 				int start_bit = wire_offset + i;
@@ -2949,8 +2965,8 @@ bool AstNode::simplify(bool const_fold, int stage, int width_hint, bool sign_hin
 					lvalue->set_attribute(ID::wiretype, member_node->clone());
 				lvalue->children.push_back(new AstNode(AST_RANGE,
 						mkconst_int(end_bit, true), mkconst_int(start_bit, true)));
-				cond->children.push_back(new AstNode(AST_BLOCK, new AstNode(type, lvalue, children[1]->clone())));
-				newNode->children.push_back(cond);
+				cond->children.push_back(new AstNode(AST_BLOCK, new AstNode(type, lvalue, rvalue->clone())));
+				caseNode->children.push_back(cond);
 			}
 		} else {
 			// mask and shift operations