Skip to content

Commit

Permalink
Add support for 3D and 2D grouped conolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Nov 18, 2024
1 parent 32fb07d commit 0bdd955
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 49 deletions.
50 changes: 39 additions & 11 deletions convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
CONV_Q = r"""%c0_i32 = arith.constant 0 : i32
%11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>} ins (%arg0, %arg1: tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS}
%cst = arith.constant {ZERO} : {OUT_ELEM_TYPE}
%9 = tensor.empty() : tensor<{OUT_TYPE}>
Expand All @@ -33,30 +35,36 @@ class ConvConfig:
Q: int
F: int
S: int
is_grouped_conv: bool
G: int # group count
is_3D_conv: bool
D: int # input depth
R: int # filter depth
S_D: int # stride along depth
OP: str
input_dtype: str
output_dtype: str

def get_name(self) -> str:
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S)
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) + "_groupcount" + str(self.G)

def get_img_shape(self) -> str:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
if "nhwc" in self.OP:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype
if "nchw" in self.OP:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype

if "ngchw" in operation:
return str(self.N) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype

def get_kernel_shape(self) -> str:
if "nhwc" in self.OP:
return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
if "nchw" in self.OP:
return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype

if "ngchw" in operation:
return str(self.F) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype

def get_byte_count(self) -> int:
dtype_bits_map = {
Expand All @@ -73,15 +81,16 @@ def get_byte_count(self) -> int:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
input_channels = self.C
group_count = self.G
output_channels = self.F
output_width = self.W
output_height = self.H
k_width = self.Q
k_height = self.P
byte_count = (
(batch * input_channels * in_w * in_h * bytes_per_input)
+ (batch * output_channels * output_width * output_height * bytes_per_output)
+ (k_width * k_height * input_channels * output_channels * bytes_per_input)
(batch * group_count * input_channels * in_w * in_h * bytes_per_input)
+ (batch * group_count * output_channels * output_width * output_height * bytes_per_output)
+ (group_count * k_width * k_height * input_channels * output_channels * bytes_per_input)
)
return byte_count

Expand All @@ -90,14 +99,15 @@ def get_flops(self) -> int:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
input_channels = self.C
group_count = self.G
output_channels = self.F
output_width = self.W
output_height = self.H
k_width = self.Q
k_height = self.P
operation_per_pixel = k_width * k_height * input_channels * 2
output_pixels_per_batch = output_width * output_height * output_channels
flops = operation_per_pixel * output_pixels_per_batch * batch
flops = operation_per_pixel * output_pixels_per_batch * group_count * batch
return flops

def generate_mlir(config: ConvConfig):
Expand All @@ -109,11 +119,16 @@ def generate_mlir(config: ConvConfig):
q = config.Q
f = config.F
stride = config.S
g = config.G
d = config.D
r = config.R
s_d = config.S_D
operation = config.OP
dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}"
elem_types = dtypes.split("x")
in_h = str(int(h) * int(stride) + int(p) - 1)
in_w = str(int(w) * int(stride) + int(q) - 1)
in_d = str(int(d) * int(s_d) + int(r) - 1)
if "nhwc" in operation:
conv_type = "nhwc_hwcf"
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
Expand All @@ -124,6 +139,17 @@ def generate_mlir(config: ConvConfig):
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ngchw" in operation:
conv_type = "ngchw_fgchw"
lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ncdhw" in operation:
conv_type = "ncdhw_fcdhw"
lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])

one = "1"
zero = "0"
if (elem_types[0][0] == "f"):
Expand All @@ -132,6 +158,8 @@ def generate_mlir(config: ConvConfig):
conv_template = CONV
if "q" in operation:
conv_template = CONV_Q
if config.is_3D_conv:
conv_template = CONV_3D
operation = conv_template.format(
INPUT_TYPE=lhs,
FILTER_TYPE=rhs,
Expand Down
76 changes: 38 additions & 38 deletions convbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,49 @@
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
return configs

def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
return configs

def get_conv_configs() -> list[tuple[str, ConvConfig]]:
Expand Down

0 comments on commit 0bdd955

Please sign in to comment.