Skip to content

Commit

Permalink
simplify macc: remove num_bits field from CONFIG, all addends are mul…
Browse files Browse the repository at this point in the history
…, all factors have non-zero length
  • Loading branch information
Emil Tywoniak authored and Emil Tywoniak committed Apr 10, 2024
1 parent c5912f4 commit b440a5e
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 89 deletions.
38 changes: 11 additions & 27 deletions kernel/macc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@

YOSYS_NAMESPACE_BEGIN


struct Macc
{
// One term of the multiply-accumulate: the product in_a * in_b.
struct port_t {
	// The two factors that are multiplied together to form this term.
	RTLIL::SigSpec in_a;
	RTLIL::SigSpec in_b;
	// When set, both factors are handled as signed values.
	bool is_signed;
	// When set, the product is subtracted from the result instead of added.
	bool do_subtract;
};

const int num_bits = 16;
std::vector<port_t> ports;
RTLIL::SigSpec bit_ports;

Expand All @@ -48,11 +50,6 @@ struct Macc
if (GetSize(port.in_a) < GetSize(port.in_b))
std::swap(port.in_a, port.in_b);

if (GetSize(port.in_a) == 1 && GetSize(port.in_b) == 0 && !port.is_signed && !port.do_subtract) {
bit_ports.append(port.in_a);
continue;
}

if (port.in_a.is_fully_const() && port.in_b.is_fully_const()) {
RTLIL::Const v = port.in_a.as_const();
if (GetSize(port.in_b))
Expand Down Expand Up @@ -110,12 +107,6 @@ struct Macc
int config_width = cell->getParam(ID::CONFIG_WIDTH).as_int();
log_assert(GetSize(config_bits) >= config_width);

int num_bits = 0;
if (config_bits[config_cursor++] == State::S1) num_bits |= 1;
if (config_bits[config_cursor++] == State::S1) num_bits |= 2;
if (config_bits[config_cursor++] == State::S1) num_bits |= 4;
if (config_bits[config_cursor++] == State::S1) num_bits |= 8;

int port_a_cursor = 0;
while (port_a_cursor < GetSize(port_a))
{
Expand All @@ -130,6 +121,7 @@ struct Macc
if (config_bits[config_cursor++] == State::S1)
size_a |= 1 << i;

log_assert(size_a);
this_port.in_a = port_a.extract(port_a_cursor, size_a);
port_a_cursor += size_a;

Expand All @@ -138,11 +130,11 @@ struct Macc
if (config_bits[config_cursor++] == State::S1)
size_b |= 1 << i;

log_assert(size_b);
this_port.in_b = port_a.extract(port_a_cursor, size_b);
port_a_cursor += size_b;

if (size_a || size_b)
ports.push_back(this_port);
ports.push_back(this_port);
}

log_assert(config_cursor == config_width);
Expand All @@ -153,26 +145,22 @@ struct Macc
{
RTLIL::SigSpec port_a;
std::vector<RTLIL::State> config_bits;
int max_size = 0, num_bits = 0;
int max_size = 0;
const int num_bits = 16;

for (auto &port : ports) {
max_size = max(max_size, GetSize(port.in_a));
max_size = max(max_size, GetSize(port.in_b));
}

while (max_size)
num_bits++, max_size /= 2;

log_assert(num_bits < 16);
config_bits.push_back(num_bits & 1 ? State::S1 : State::S0);
config_bits.push_back(num_bits & 2 ? State::S1 : State::S0);
config_bits.push_back(num_bits & 4 ? State::S1 : State::S0);
config_bits.push_back(num_bits & 8 ? State::S1 : State::S0);
log_assert(max_size <= 16);

for (auto &port : ports)
{
if (GetSize(port.in_a) == 0)
continue;
if (GetSize(port.in_b) == 0)
continue;

config_bits.push_back(port.is_signed ? State::S1 : State::S0);
config_bits.push_back(port.do_subtract ? State::S1 : State::S0);
Expand Down Expand Up @@ -207,11 +195,7 @@ struct Macc
if (!port.in_a.is_fully_const() || !port.in_b.is_fully_const())
return false;

RTLIL::Const summand;
if (GetSize(port.in_b) == 0)
summand = const_pos(port.in_a.as_const(), port.in_b.as_const(), port.is_signed, port.is_signed, GetSize(result));
else
summand = const_mul(port.in_a.as_const(), port.in_b.as_const(), port.is_signed, port.is_signed, GetSize(result));
RTLIL::Const summand = const_mul(port.in_a.as_const(), port.in_b.as_const(), port.is_signed, port.is_signed, GetSize(result));

if (port.do_subtract)
result = const_sub(result, summand, port.is_signed, port.is_signed, GetSize(result));
Expand Down
30 changes: 10 additions & 20 deletions kernel/satgen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -760,28 +760,18 @@ bool SatGen::importCell(RTLIL::Cell *cell, int timestep)
in_a.push_back(port.is_signed && !in_a.empty() ? in_a.back() : ez->CONST_FALSE);
in_a.resize(GetSize(y));

if (GetSize(in_b))
{
while (GetSize(in_b) < GetSize(y))
in_b.push_back(port.is_signed && !in_b.empty() ? in_b.back() : ez->CONST_FALSE);
in_b.resize(GetSize(y));

for (int i = 0; i < GetSize(in_b); i++) {
std::vector<int> shifted_a(in_a.size(), ez->CONST_FALSE);
for (int j = i; j < int(in_a.size()); j++)
shifted_a.at(j) = in_a.at(j-i);
if (port.do_subtract)
tmp = ez->vec_ite(in_b.at(i), ez->vec_sub(tmp, shifted_a), tmp);
else
tmp = ez->vec_ite(in_b.at(i), ez->vec_add(tmp, shifted_a), tmp);
}
}
else
{
while (GetSize(in_b) < GetSize(y))
in_b.push_back(port.is_signed && !in_b.empty() ? in_b.back() : ez->CONST_FALSE);
in_b.resize(GetSize(y));

for (int i = 0; i < GetSize(in_b); i++) {
std::vector<int> shifted_a(in_a.size(), ez->CONST_FALSE);
for (int j = i; j < int(in_a.size()); j++)
shifted_a.at(j) = in_a.at(j-i);
if (port.do_subtract)
tmp = ez->vec_sub(tmp, in_a);
tmp = ez->vec_ite(in_b.at(i), ez->vec_sub(tmp, shifted_a), tmp);
else
tmp = ez->vec_add(tmp, in_a);
tmp = ez->vec_ite(in_b.at(i), ez->vec_add(tmp, shifted_a), tmp);
}
}

Expand Down
2 changes: 1 addition & 1 deletion passes/cmds/clean_zerowidth.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct CleanZeroWidthPass : public Pass {
if (cell->getParam(ID::Y_WIDTH).as_int() == 0) {
module->remove(cell);
} else if (cell->type == ID($macc)) {
// TODO: fixing zero-width A and B not supported.
// fixing zero-width A and B not supported.
} else {
if (cell->getParam(ID::A_WIDTH).as_int() == 0) {
cell->setPort(ID::A, State::S0);
Expand Down
37 changes: 15 additions & 22 deletions passes/opt/share.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ struct ShareWorker

static int bits_macc_port(const Macc::port_t &p, int width)
{
if (GetSize(p.in_a) == 0 || GetSize(p.in_b) == 0)
return min(max(GetSize(p.in_a), GetSize(p.in_b)), width);
log_assert(GetSize(p.in_a) && GetSize(p.in_b));
return min(GetSize(p.in_a), width) * min(GetSize(p.in_b), width) / 2;
}

Expand All @@ -132,14 +131,11 @@ struct ShareWorker

static bool cmp_macc_ports(const Macc::port_t &p1, const Macc::port_t &p2)
{
bool mul1 = GetSize(p1.in_a) && GetSize(p1.in_b);
bool mul2 = GetSize(p2.in_a) && GetSize(p2.in_b);
log_assert(GetSize(p1.in_a) && GetSize(p1.in_b));
log_assert(GetSize(p2.in_a) && GetSize(p2.in_b));

int w1 = mul1 ? GetSize(p1.in_a) * GetSize(p1.in_b) : GetSize(p1.in_a) + GetSize(p1.in_b);
int w2 = mul2 ? GetSize(p2.in_a) * GetSize(p2.in_b) : GetSize(p2.in_a) + GetSize(p2.in_b);

if (mul1 != mul2)
return mul1;
int w1 = GetSize(p1.in_a) * GetSize(p1.in_b);
int w2 = GetSize(p2.in_a) * GetSize(p2.in_b);

if (w1 != w2)
return w1 > w2;
Expand All @@ -165,22 +161,19 @@ struct ShareWorker
if (p1.do_subtract != p2.do_subtract)
return -1;

bool mul1 = GetSize(p1.in_a) && GetSize(p1.in_b);
bool mul2 = GetSize(p2.in_a) && GetSize(p2.in_b);

if (mul1 != mul2)
return -1;
log_assert(GetSize(p1.in_a) && GetSize(p1.in_b));
log_assert(GetSize(p2.in_a) && GetSize(p2.in_b));

bool force_signed = false, force_not_signed = false;

if ((GetSize(p1.in_a) && GetSize(p1.in_a) < w1) || (GetSize(p1.in_b) && GetSize(p1.in_b) < w1)) {
if ((GetSize(p1.in_a) < w1) || (GetSize(p1.in_b) < w1)) {
if (p1.is_signed)
force_signed = true;
else
force_not_signed = true;
}

if ((GetSize(p2.in_a) && GetSize(p2.in_a) < w2) || (GetSize(p2.in_b) && GetSize(p2.in_b) < w2)) {
if ((GetSize(p2.in_a) < w2) || (GetSize(p2.in_b) < w2)) {
if (p2.is_signed)
force_signed = true;
else
Expand Down Expand Up @@ -281,12 +274,12 @@ struct ShareWorker
RTLIL::SigSpec sig_a = m1.ports[i].in_a;
RTLIL::SigSpec sig_b = m1.ports[i].in_b;

if (supercell_aux && GetSize(sig_a)) {
log_assert(GetSize(sig_a) && GetSize(sig_b));

if (supercell_aux) {
sig_a = module->addWire(NEW_ID, GetSize(sig_a));
supercell_aux->insert(module->addMux(NEW_ID, RTLIL::SigSpec(0, GetSize(sig_a)), m1.ports[i].in_a, act, sig_a));
}

if (supercell_aux && GetSize(sig_b)) {
sig_b = module->addWire(NEW_ID, GetSize(sig_b));
supercell_aux->insert(module->addMux(NEW_ID, RTLIL::SigSpec(0, GetSize(sig_b)), m1.ports[i].in_b, act, sig_b));
}
Expand All @@ -304,12 +297,12 @@ struct ShareWorker
RTLIL::SigSpec sig_a = m2.ports[i].in_a;
RTLIL::SigSpec sig_b = m2.ports[i].in_b;

if (supercell_aux && GetSize(sig_a)) {
log_assert(GetSize(sig_a) && GetSize(sig_b));

if (supercell_aux) {
sig_a = module->addWire(NEW_ID, GetSize(sig_a));
supercell_aux->insert(module->addMux(NEW_ID, m2.ports[i].in_a, RTLIL::SigSpec(0, GetSize(sig_a)), act, sig_a));
}

if (supercell_aux && GetSize(sig_b)) {
sig_b = module->addWire(NEW_ID, GetSize(sig_b));
supercell_aux->insert(module->addMux(NEW_ID, m2.ports[i].in_b, RTLIL::SigSpec(0, GetSize(sig_b)), act, sig_b));
}
Expand Down
4 changes: 2 additions & 2 deletions passes/tests/test_cell.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ static void create_gold_module(RTLIL::Design *design, RTLIL::IdString cell_type,
for (int i = 0; i < depth; i++)
{
int size_a = xorshift32(width) + 1;
int size_b = depth > 4 ? 0 : xorshift32(width) + 1;
int size_b = depth > 4 ? 1 : xorshift32(width) + 1;

if (mulbits_a + size_a*size_b <= 96 && mulbits_b + size_a + size_b <= 16 && xorshift32(2) == 1) {
mulbits_a += size_a * size_b;
mulbits_b += size_a + size_b;
} else
size_b = 0;
size_b = 1;

Macc::port_t this_port;

Expand Down
32 changes: 15 additions & 17 deletions techlibs/common/simlib.v
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,8 @@ parameter A_WIDTH = 0;
parameter B_WIDTH = 0;
parameter Y_WIDTH = 0;
// CONFIG determines the layout of A, as explained below
parameter CONFIG = 4'b0000;
parameter CONFIG_WIDTH = 4;
parameter CONFIG = 1'b0;
parameter CONFIG_WIDTH = 0;

// In the terminology used for this cell, the term "port" has more than one meaning. To disambiguate:
// A cell port is for example the A input (it is constructed in C++ as cell->setPort(ID::A, ...))
Expand All @@ -946,21 +946,20 @@ function integer my_clog2;
endfunction

// Number of bits in each factor's length field in CONFIG (one length field per factor in cell port A)
localparam integer num_bits = CONFIG[3:0] > 0 ? CONFIG[3:0] : 1;
localparam integer num_bits = 16;
// Number of multiplier ports
localparam integer num_ports = (CONFIG_WIDTH-4) / (2 + 2*num_bits);
localparam integer num_ports = (CONFIG_WIDTH) / (2 + 2*num_bits);
// Minimum bit width of an induction variable to iterate over all bits of cell port A
localparam integer num_abits = my_clog2(A_WIDTH) > 0 ? my_clog2(A_WIDTH) : 1;

// In this pseudocode, u(foo) means an unsigned int that's foo bits long.
// The CONFIG parameter carries the following information:
// struct CONFIG {
// u4 num_bits;
// struct port_field {
// bool is_signed;
// bool is_subtract;
// u(num_bits) factor1_len;
// u(num_bits) factor2_len;
// u16 factor1_len;
// u16 factor2_len;
// }[num_ports];
// };

Expand Down Expand Up @@ -989,19 +988,19 @@ function [2*num_ports*num_abits-1:0] get_port_offsets;
get_port_offsets = 0;
for (i = 0; i < num_ports; i = i+1) begin
get_port_offsets[(2*i + 0)*num_abits +: num_abits] = cursor;
cursor = cursor + cfg[4 + i*(2 + 2*num_bits) + 2 +: num_bits];
cursor = cursor + cfg[i*(2 + 2*num_bits) + 2 +: num_bits];
get_port_offsets[(2*i + 1)*num_abits +: num_abits] = cursor;
cursor = cursor + cfg[4 + i*(2 + 2*num_bits) + 2 + num_bits +: num_bits];
cursor = cursor + cfg[i*(2 + 2*num_bits) + 2 + num_bits +: num_bits];
end
end
endfunction

localparam [2*num_ports*num_abits-1:0] port_offsets = get_port_offsets(CONFIG);

`define PORT_IS_SIGNED (0 + CONFIG[4 + i*(2 + 2*num_bits)])
`define PORT_DO_SUBTRACT (0 + CONFIG[4 + i*(2 + 2*num_bits) + 1])
`define PORT_SIZE_A (0 + CONFIG[4 + i*(2 + 2*num_bits) + 2 +: num_bits])
`define PORT_SIZE_B (0 + CONFIG[4 + i*(2 + 2*num_bits) + 2 + num_bits +: num_bits])
`define PORT_IS_SIGNED (0 + CONFIG[i*(2 + 2*num_bits)])
`define PORT_DO_SUBTRACT (0 + CONFIG[i*(2 + 2*num_bits) + 1])
`define PORT_SIZE_A (0 + CONFIG[i*(2 + 2*num_bits) + 2 +: num_bits])
`define PORT_SIZE_B (0 + CONFIG[i*(2 + 2*num_bits) + 2 + num_bits +: num_bits])
`define PORT_OFFSET_A (0 + port_offsets[2*i*num_abits +: num_abits])
`define PORT_OFFSET_B (0 + port_offsets[2*i*num_abits + num_abits +: num_abits])

Expand All @@ -1018,19 +1017,18 @@ always @* begin
for (j = 0; j < `PORT_SIZE_A; j = j+1)
tmp_a[j] = A[`PORT_OFFSET_A + j];

if (`PORT_IS_SIGNED && `PORT_SIZE_A > 0)
if (`PORT_IS_SIGNED)
for (j = `PORT_SIZE_A; j < Y_WIDTH; j = j+1)
tmp_a[j] = tmp_a[`PORT_SIZE_A-1];

for (j = 0; j < `PORT_SIZE_B; j = j+1)
tmp_b[j] = A[`PORT_OFFSET_B + j];

if (`PORT_IS_SIGNED && `PORT_SIZE_B > 0)
if (`PORT_IS_SIGNED)
for (j = `PORT_SIZE_B; j < Y_WIDTH; j = j+1)
tmp_b[j] = tmp_b[`PORT_SIZE_B-1];

if (`PORT_SIZE_B > 0)
tmp_a = tmp_a * tmp_b;
tmp_a = tmp_a * tmp_b;

if (`PORT_DO_SUBTRACT)
Y = Y - tmp_a;
Expand Down

0 comments on commit b440a5e

Please sign in to comment.