diff --git a/benchmarks/OpOptimization/Pooling/IREE/main.py b/benchmarks/OpOptimization/Pooling/IREE/main.py new file mode 100644 index 00000000..1cb3f965 --- /dev/null +++ b/benchmarks/OpOptimization/Pooling/IREE/main.py @@ -0,0 +1,147 @@ +# ===- main.py ----------------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# This file implements the IREE optimization entry for Pooling. +# You can choose to run on CPU or GPU by setting iree_backend to "llvm-cpu" or "cuda" in pooling_iree.py. 
# See the IREE license at: https://github.com/openxla/iree/blob/main/LICENSE
#
# ===---------------------------------------------------------------------------

import numpy
import numpy as np
import time
from pooling_iree import *

# ------------------------------------------------------------------------------
# User Configurable Variables
# ------------------------------------------------------------------------------
dtype = "float32"


# ------------------------------------------------------------------------------
# Helper Function
# ------------------------------------------------------------------------------
def _average_runtime(run, num):
    """Return the mean wall-clock duration, in seconds, of `num` calls to `run`.

    Args:
      run: Zero-argument callable to be timed.
      num: Number of timed repetitions; must be positive.
    Raises:
      ValueError: If `num` is not positive.
    """
    if num <= 0:
        raise ValueError("num must be a positive integer")
    # One untimed call first, so one-off setup cost is not attributed to the
    # measured repetitions.
    run()
    total = 0.0
    for _ in range(num):
        start = time.perf_counter()
        run()
        total += time.perf_counter() - start
    return total / num


def iree_evaluator(s, inputs, num):
    """Return the mean time in seconds of `num` calls to `s.forward(inputs)`.

    Args:
      s: Object exposing a `forward` method (the IREE invoker in `main`).
      inputs: Argument passed unchanged to `s.forward`.
      num: Number of timed repetitions; must be positive.
    """
    return _average_runtime(lambda: s.forward(inputs), num)


def numpy_evaluator(a_tensor, b_tensor, num):
    """Time a per-sample `np.dot` over two batched tensors.

    NOTE(review): nothing in this file calls this helper; it looks like a
    batched-matmul reference rather than a pooling benchmark - confirm before
    relying on it.

    Args:
      a_tensor: Object exposing `.numpy()` and `.shape`; indexed along axis 0.
      b_tensor: Object exposing `.numpy()` and `.shape`; indexed along axis 0.
      num: Number of passes over the whole batch; must be positive.
    Returns:
      Time in seconds spent inside `np.dot` for one pass over the batch,
      averaged over `num` passes.
    Raises:
      ValueError: If `num` is not positive.
    """
    if num <= 0:
        raise ValueError("num must be a positive integer")
    a_tensor_np = a_tensor.numpy()
    b_tensor_np = b_tensor.numpy()
    batch_size = a_tensor.shape[0]
    # Output buffer; every slot is overwritten before it is read, so it does
    # not need to be initialised.
    result = np.empty((batch_size, a_tensor.shape[1], b_tensor.shape[2]))
    total = 0.0
    for _ in range(num):
        for j in range(batch_size):
            start = time.perf_counter()
            result[j] = np.dot(a_tensor_np[j], b_tensor_np[j])
            total += time.perf_counter() - start
    return total / num


def evaluator(s, inputs, num):
    """Return the mean time in seconds of `num` calls to `s(inputs)`.

    Args:
      s: Callable to be timed (the torch module in `main`).
      inputs: Argument passed unchanged to `s`.
      num: Number of timed repetitions; must be positive.
    """
    return _average_runtime(lambda: s(inputs), num)


def evaluate_operation(s, inputs, optimization, log):
    """Time one benchmark entry and record it in the log.

    No correctness check is performed here; only the runtime is measured.

    Args:
      s: The object to time. It is invoked as `s.forward(inputs)` when
        `optimization` is "IREE" and as `s(inputs)` otherwise.
      inputs: The input passed to `s`.
      optimization: The name of the optimization, used as the row label.
      log: The log list; a `(optimization, mean_time_in_seconds)` tuple is
        appended to it.
    """
    timer = iree_evaluator if optimization == "IREE" else evaluator
    log.append((optimization, timer(s, inputs, 10)))


def report_performance(log):
    """Print the log as a performance table.

    The last log entry is the baseline: the SpeedUp column is the baseline time
    divided by each entry's time. Times are stored in seconds and printed in
    milliseconds.

    Args:
      log: The log list of `(name, mean_time_in_seconds)` entries. Nothing is
        printed when it is empty.
    """
    if not log:
        return
    baseline = log[-1][1]
    header = (
        "Benchmark".ljust(20) + "\t" + "Time".rjust(10) + "\t" + "SpeedUp".rjust(10)
    )
    split_line = "-" * 50
    print(split_line)
    print(header)
    print(split_line)
    for result in log:
        name, seconds = result[0], result[1]
        # The timers return seconds; convert so the value matches the " ms" label.
        formatted_time = "{:.2f}".format(seconds * 1000.0)
        # A sample below the clock resolution can be 0.0; avoid dividing by it.
        speedup = baseline / seconds if seconds > 0 else float("inf")
        formatted_performance = "{:.2f}".format(speedup)
        print(
            "\033[32m%s\033[0m\t\033[33m%s\033[0m\t\033[34m%s\033[0m"
            % (
                name.ljust(20),
                str(formatted_time + " ms").rjust(10),
                str(formatted_performance).rjust(10),
            )
        )


def main():
    """Benchmark the IREE-compiled pooling model against the eager torch model."""
    # ----------------------------------------------------------------------------
    # Initialization and Baseline
    # ----------------------------------------------------------------------------
    # Initialize the log list.
    log = []
    # Generate random tensor for testing.
    # size = (channels, input width/height, kernel width/height).
    size = (512, 64, 3)
    c, n, k = size
    # Padding and stride.
    p, s = 1, 1
    data, _ = get_pool_data_torch(c, n, k, p, s)
    model = torch_pooling(k, p, s)
    # Run the eager model once before handing it to the compiler.
    model(data)
    invoker = iree_pooling(model, data)
    # ----------------------------------------------------------------------------
    # Register Benchmarks and Dump Report
    # ----------------------------------------------------------------------------
    # Register default schedule. The torch entry is registered last because
    # report_performance uses the last entry as the baseline.
    evaluate_operation(invoker, inputs=data, optimization="IREE", log=log)
    evaluate_operation(model, inputs=data, optimization="torch_cpu", log=log)
    report_performance(log)


if __name__ == "__main__":
    main()


# ------------------------------------------------------------------------------
# Patch header of the next file in this change set (kept verbatim):
#   diff --git a/benchmarks/OpOptimization/Pooling/IREE/pooling_iree.py b/benchmarks/OpOptimization/Pooling/IREE/pooling_iree.py
#   new file mode 100644
#   index 00000000..5ecf5858
#   --- /dev/null
#   +++ b/benchmarks/OpOptimization/Pooling/IREE/pooling_iree.py
#   @@ -0,0 +1,88 @@
# ------------------------------------------------------------------------------
# ===- pooling_iree.py ---------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This file implements the IREE optimization for Pooling.
# You can choose to run on CPU or GPU by setting iree_backend to "llvm-cpu" or "cuda" in pooling_iree.py.
# See the IREE license at: https://github.com/openxla/iree/blob/main/LICENSE
#
# ===---------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch_mlir
import iree_torch
import io
import numpy as np


def conv_out_size(n, k, p, s):
    """Compute the output size by given input size n (width or height),
    kernel size k, padding p, and stride s
    Return output size (width or height)
    """
    return (n - k + 2 * p) // s + 1


def get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None):
    """Return a random 3-D data tensor, a random 4-D kernel tensor and an
    uninitialized 3-D output tensor with the shapes specified by the arguments.

    The global NumPy random generator is re-seeded with 0 on every call, so the
    returned data and kernel are reproducible.

    oc, ic : output and input channels
    n : input width and height
    k : kernel width and height
    p : padding size, default 0
    s : stride, default 1
    constructor : user-defined tensor constructor, applied to each of the
        three arrays when given
    """
    np.random.seed(0)
    data = np.random.normal(size=(ic, n, n)).astype("float32")
    weight = np.random.normal(size=(oc, ic, k, k)).astype("float32")
    on = conv_out_size(n, k, p, s)
    out = np.empty((oc, on, on), dtype="float32")
    if constructor:
        data, weight, out = (constructor(x) for x in [data, weight, out])
    return data, weight, out


def get_pool_data_torch(c, n, k, p, s):
    """Return the pooling input and output buffers as torch tensors.

    Both tensors come from `get_conv_data` with `c` input and output channels
    and get a leading dimension of size 1 prepended. The kernel tensor that
    `get_conv_data` also produces is discarded.

    c : number of channels
    n : input width and height
    k : kernel width and height
    p : padding size
    s : stride
    """
    data, _, out = get_conv_data(c, c, n, k, p, s, torch.from_numpy)
    return data.unsqueeze(0), out.unsqueeze(0)


class pooling_model(nn.Module):
    """Module wrapping a single `nn.MaxPool2d` layer.

    k : kernel size
    p : padding size
    s : stride
    """

    def __init__(self, k, p, s):
        super().__init__()
        # Keywords make the (kernel, stride, padding) order of MaxPool2d explicit.
        self.pool = nn.MaxPool2d(kernel_size=k, stride=s, padding=p)

    def forward(self, x):
        """Apply the max-pooling layer to x and return the result."""
        return self.pool(x)


def torch_pooling(k, p, s):
    """Build a `pooling_model` with kernel size k, padding p and stride s."""
    # pooling_model takes (k, p, s); passing (k, s, p) would silently swap the
    # stride and the padding.
    return pooling_model(k, p, s)


def iree_pooling(model, example_input, iree_backend="llvm-cpu"):
    """Compile `model` through torch-mlir and IREE and return the loaded invoker.

    model : the torch module to compile
    example_input : example input passed to `torch_mlir.compile`
    iree_backend : IREE target backend, "llvm-cpu" (default) or "cuda"
    """
    linalg_on_tensors_mlir = torch_mlir.compile(
        model, example_input, output_type="linalg-on-tensors", use_tracing=False
    )
    iree_vmfb = iree_torch.compile_to_vmfb(linalg_on_tensors_mlir, iree_backend)
    invoker = iree_torch.load_vmfb(iree_vmfb, iree_backend)
    return invoker