diff --git a/programming_examples/ml/resnet/ptq_conv2x/aie2.py b/programming_examples/ml/resnet/ptq_conv2x/aie2.py
index 27710c2c71..6992f3da81 100755
--- a/programming_examples/ml/resnet/ptq_conv2x/aie2.py
+++ b/programming_examples/ml/resnet/ptq_conv2x/aie2.py
@@ -580,7 +580,7 @@ def core_body():
             @core(cores[i][1], "conv2dk3.o")
             def core_body():
-                scale = 11
+                scale = 9
                 for _ in for_(sys.maxsize):

                     # acquire weights and rtps once
@@ -697,7 +697,7 @@ def core_body():
             @core(cores[i][3], "conv2dk3.o")
             def core_body():
-                scale = 11
+                scale = 9
                 for _ in for_(sys.maxsize):

                     # acquire weights and rtps once
@@ -930,30 +930,28 @@ def sequence(inputFromL3, weightsFromL3, outputToL3):
         # for c, col in enumerate(rtp_name):
         #     for r, row in enumerate(col):
         #         NpuWriteRTPOp(row, col=c, row=r + 2, index=0, value=1)  # scale
-
-        # NpuWriteRTPOp("rtpComputeTile05", col=0, row=4, index=1, value=0)
-        # NpuWriteRTPOp("rtpComputeTile05", col=0, row=4, index=2, value=1)
-
-        # NpuWriteRTPOp("rtpComputeTile13", col=1, row=3, index=1, value=0)
-
-        # NpuWriteRTPOp("rtpComputeTile24", col=2, row=4, index=1, value=0)
-
-        # # # write RTP parameters
-        # npuWriteRTPOp(
-        #     "rtpComputeTile02", col=0, row=2, index=0, value=1
-        # )  # scale
-        # npuWriteRTPOp(
-        #     "rtpComputeTile03", col=0, row=3, index=0, value=1
-        # )  # scale
-        # npuWriteRTPOp(
-        #     "rtpComputeTile05", col=0, row=5, index=0, value=1
-        # )  # scale
-        # npuWriteRTPOp(
-        #     "rtpComputeTile04", col=0, row=4, index=0, value=1
-        # )  # scale: conv1x1 with the same scale as the input so we match the scaling factor of output after conv1x1 and the initial input
-        # npuWriteRTPOp(
-        #     "rtpComputeTile04", col=0, row=4, index=1, value=0
-        # )  # skip_scale
+        NpuWriteRTPOp("rtpComputeTile02", col=0, row=2, index=0, value=8)
+        NpuWriteRTPOp("rtpComputeTile03", col=0, row=3, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile04", col=0, row=5, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile05", col=0, row=4, index=0, value=11)
+        NpuWriteRTPOp("rtpComputeTile05", col=0, row=4, index=1, value=0)
+        NpuWriteRTPOp("rtpComputeTile05", col=0, row=4, index=2, value=7)
+
+        NpuWriteRTPOp("rtpComputeTile15", col=1, row=5, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile14", col=1, row=4, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile12", col=1, row=2, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile13", col=1, row=3, index=0, value=12)
+        NpuWriteRTPOp("rtpComputeTile13", col=1, row=3, index=1, value=0)
+
+        NpuWriteRTPOp("rtpComputeTile22", col=2, row=2, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile23", col=2, row=3, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile25", col=2, row=5, index=0, value=9)
+        NpuWriteRTPOp("rtpComputeTile24", col=2, row=4, index=0, value=12)
+        NpuWriteRTPOp("rtpComputeTile24", col=2, row=4, index=1, value=0)
+
+        rtp_1 = [7, 10, 13, -2, 10]
+        rtp_2 = [8, 10, 12]
+        rtp_3 = [9, 9, 12]

         npu_dma_memcpy_nd(
             metadata="act1_00_02_01",
diff --git a/programming_examples/ml/resnet/ptq_conv2x/data/cifar10_label_map.txt b/programming_examples/ml/resnet/ptq_conv2x/data/cifar10_label_map.txt
new file mode 100644
index 0000000000..1fc508024c
--- /dev/null
+++ b/programming_examples/ml/resnet/ptq_conv2x/data/cifar10_label_map.txt
@@ -0,0 +1 @@
+{"0": "airplane", "1": "automobile", "2": "bird", "3": "cat", "4": "deer", "5": "dog", "6": "frog", "7": "horse", "8": "ship", "9": "truck"}
\ No newline at end of file
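The RTP values programmed above are per-tile power-of-two shift amounts consumed by the conv2dk1/conv2dk3 kernels when they requantize their int32 accumulators. As a minimal sketch of that arithmetic (assuming a rounding arithmetic right shift, which is what a power-of-two scale reduces to; the kernels' exact rounding may differ):

    import numpy as np

    def requantize(acc: np.ndarray, shift: int) -> np.ndarray:
        # Round-to-nearest right shift, then clip to a uint8 activation range.
        rounded = (acc + (1 << (shift - 1))) >> shift
        return np.clip(rounded, 0, 255).astype(np.uint8)

    acc = np.array([12345, -7, 99999], dtype=np.int32)
    print(requantize(acc, 9))  # 9 is a scale value written via NpuWriteRTPOp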
diff --git a/programming_examples/ml/resnet/ptq_conv2x/model.py b/programming_examples/ml/resnet/ptq_conv2x/model.py
new file mode 100644
index 0000000000..68c6feaa8b
--- /dev/null
+++ b/programming_examples/ml/resnet/ptq_conv2x/model.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CombinedModel(nn.Module):
+    """ResNet-50 split into a host prologue, the AIE-offloaded conv2_x stage,
+    and a host epilogue."""
+
+    def __init__(self, first, aie, post):
+        super(CombinedModel, self).__init__()
+        self.first = first
+        self.aie = aie
+        self.post = post
+
+    def forward(self, x):
+        x = self.first(x)
+        x = self.aie(x)
+        x = self.post(x)
+        return x
+
+
+class PreAIELayers(nn.Module):
+    def __init__(self):
+        super(PreAIELayers, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = F.relu(out)
+        return out
+
+
+class AIEConv2xOffload(nn.Module):
+    def __init__(self, block, num_blocks):
+        super(AIEConv2xOffload, self).__init__()
+        self.in_planes = 64
+        self.layer1 = block(in_planes=64, planes=64)
+        self.layer2 = block(in_planes=256, planes=64)
+        self.layer3 = block(in_planes=256, planes=64)
+
+    def forward(self, x):
+        out = self.layer1(x)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        return out
+
+
+class PostAIELayers(nn.Module):
+    def __init__(self, block, num_blocks, num_classes):
+        super(PostAIELayers, self).__init__()
+
+        self.in_planes = 256
+        self.layer2 = self._make_layer(block, 128, num_blocks[0], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[1], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[2], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = self.layer2(x)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        # The blocks below ignore their stride argument, so the feature map
+        # is still 32x32 here; a 32-wide pool reduces it to 1x1.
+        out = F.avg_pool2d(out, 32)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+class Bottleneck_projected(nn.Module):
+    expansion = 4
+
+    # Note: stride and option are accepted for interface parity with standard
+    # ResNet blocks but are not applied; every block preserves spatial size.
+    def __init__(self, in_planes, planes, stride=1, option="A"):
+        super(Bottleneck_projected, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, padding=1, padding_mode="zeros", bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(
+            planes, self.expansion * planes, kernel_size=1, bias=False
+        )
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.relu3 = nn.ReLU()
+
+        self.shortcut = nn.Sequential()
+        if in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(
+                    in_planes, self.expansion * planes, kernel_size=1, bias=False
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out = out + self.shortcut(x)
+        out = self.relu3(out)
+        return out
+
+
+class Bottleneck_fused_projected(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, option="A"):
+        super(Bottleneck_fused_projected, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, padding=1, padding_mode="zeros", bias=False
+        )
+
+        self.conv3 = nn.Conv2d(
+            planes, self.expansion * planes, kernel_size=1, bias=False
+        )
+
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.relu3 = nn.ReLU()
+
+        self.shortcut = nn.Sequential()
+        if in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, bias=False)
+            )
+
+    def forward(self, x):
+        out = self.relu1(self.conv1(x))
+        out = self.relu2(self.conv2(out))
+        out = self.conv3(out)
+        out += self.shortcut(x)
+        out = self.relu3(out)
+        return out
+
+
+def Resnet50_conv2x_offload(num_classes):
+    return CombinedModel(
+        PreAIELayers(),
+        AIEConv2xOffload(Bottleneck_fused_projected, [1]),
+        PostAIELayers(Bottleneck_projected, [4, 6, 3], num_classes),
+    )
\ No newline at end of file
diff --git a/programming_examples/ml/resnet/ptq_conv2x/requirements.txt b/programming_examples/ml/resnet/ptq_conv2x/requirements.txt
new file mode 100644
index 0000000000..47a9883564
--- /dev/null
+++ b/programming_examples/ml/resnet/ptq_conv2x/requirements.txt
@@ -0,0 +1,4 @@
+brevitas
+torchvision
+tqdm
+opencv-python
\ No newline at end of file
diff --git a/programming_examples/ml/resnet/ptq_conv2x/run_makefile.lit b/programming_examples/ml/resnet/ptq_conv2x/run_makefile.lit
new file mode 100644
index 0000000000..6097345491
--- /dev/null
+++ b/programming_examples/ml/resnet/ptq_conv2x/run_makefile.lit
@@ -0,0 +1,9 @@
+// (c) Copyright 2024 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// REQUIRES: ryzen_ai, chess, torch
+//
+// RUN: make -f %S/Makefile clean
+// RUN: make -f %S/Makefile
+// RUN: %run_on_npu make -f %S/Makefile run_py | FileCheck %s
+// CHECK: PASS!
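model.py above defines the three-stage split that test.py exercises; a quick shape check (a sketch only: random weights, CPU, no quantization) shows what crosses the AIE boundary:

    import torch
    from model import Resnet50_conv2x_offload

    m = Resnet50_conv2x_offload(num_classes=10)
    m.eval()
    x = torch.randn(1, 3, 32, 32)   # one CIFAR-10-sized image
    with torch.no_grad():
        pre = m.first(x)            # (1, 64, 32, 32): input to the AIE offload
        mid = m.aie(pre)            # (1, 256, 32, 32): conv2_x output
        out = m(x)                  # (1, 10): class logits
    print(pre.shape, mid.shape, out.shape)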
diff --git a/programming_examples/ml/resnet/ptq_conv2x/test.py b/programming_examples/ml/resnet/ptq_conv2x/test.py
index 06989d55fa..8bf5857bd6 100755
--- a/programming_examples/ml/resnet/ptq_conv2x/test.py
+++ b/programming_examples/ml/resnet/ptq_conv2x/test.py
@@ -13,12 +13,31 @@
 import time
 import os
 import numpy as np
+import model as res
+
 from aie.utils.xrt import setup_aie, extract_trace, write_out_trace, execute
 import aie.utils.test as test_utils
 
 torch.use_deterministic_algorithms(True)
 torch.manual_seed(0)
-
+from utils import unpickle, load_class_label
+import torchvision
+from torchvision import transforms
+from PIL import Image
+from brevitas.nn import QuantConv2d, QuantIdentity, QuantReLU
+from brevitas.quant.fixed_point import (
+    Int8ActPerTensorFixedPoint,
+    Int8WeightPerTensorFixedPoint,
+    Uint8ActPerTensorFixedPoint,
+)
+from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
+from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
+import torch.utils.data as data_utils
+from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
+from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate_bn
+from brevitas_examples.imagenet_classification.utils import generate_dataloader
+from brevitas_examples.imagenet_classification.utils import SEED
+from brevitas_examples.imagenet_classification.utils import validate
 
 def main(opts):
     design = "resnet_conv2_x_int8"
@@ -48,108 +67,235 @@ def main(opts):
     shape_out = (32, 32, 32, 8)
 
     # ------------------------------------------------------
-    # Initialize activation, weights, scaling factor for int8 model
+    # Post-training quantization to get int8 weights and activations for AIE
     # ------------------------------------------------------
-    int_inp = torch.randint(1, 10, (1, 64, 32, 32)).type(torch.FloatTensor)
-    block_0_int_weight_1 = torch.randint(10, 20, (64, 64, 1, 1)).type(torch.FloatTensor)
-    block_0_int_weight_2 = torch.randint(10, 20, (64, 64, 3, 3)).type(torch.FloatTensor)
-    block_0_int_weight_3 = torch.randint(10, 20, (256, 64, 1, 1)).type(
-        torch.FloatTensor
-    )
-    block_0_int_weight_skip = torch.randint(10, 20, (256, 64, 1, 1)).type(
-        torch.FloatTensor
-    )
-
-    block_1_int_weight_1 = torch.randint(20, 30, (64, 256, 1, 1)).type(
-        torch.FloatTensor
-    )
-    block_1_int_weight_2 = torch.randint(20, 30, (64, 64, 3, 3)).type(torch.FloatTensor)
-    block_1_int_weight_3 = torch.randint(20, 30, (256, 64, 1, 1)).type(
-        torch.FloatTensor
-    )
-
-    block_2_int_weight_1 = torch.randint(30, 40, (64, 256, 1, 1)).type(
-        torch.FloatTensor
-    )
-    block_2_int_weight_2 = torch.randint(30, 40, (64, 64, 3, 3)).type(torch.FloatTensor)
-    block_2_int_weight_3 = torch.randint(30, 40, (256, 64, 1, 1)).type(
-        torch.FloatTensor
-    )
+    num_classes = 10
+    model = res.Resnet50_conv2x_offload(num_classes)
+    weights = "trained_resnet50/weight.tar"  # trained FP32 model
+    saved_model_dict = torch.load(weights, map_location=torch.device("cpu"))
+    model.load_state_dict(saved_model_dict)
+
+    data_dir = "data"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    transform = transforms.Compose(
+        [
+            transforms.Pad(4),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(32),
+            transforms.ToTensor(),
+        ]
+    )
+    transform_train = transforms.Compose(
+        [
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(32, padding=4),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+    )
+    transform_test = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+    )
+    # CIFAR-10 dataset
+    train_dataset = torchvision.datasets.CIFAR10(
+        root=data_dir, train=True, transform=transform_train, download=True
+    )
+    test_dataset = torchvision.datasets.CIFAR10(
+        root=data_dir, train=False, transform=transform_test, download=True
+    )
-    init_scale = 0.5
-    block_0_relu_1 = 0.5
-    block_0_relu_2 = 0.5
-    block_0_relu_3 = 0.5
-
-    block_0_weight_scale1 = 0.5
-    block_0_weight_scale2 = 0.5
-    block_0_weight_scale3 = 0.5
-    block_0_weight_scale_skip = 0.5
-
-    block_1_relu_1 = 0.5
-    block_1_relu_2 = 0.5
-    block_1_relu_3 = 0.5
-
-    block_1_weight_scale1 = 0.5
-    block_1_weight_scale2 = 0.5
-    block_1_weight_scale3 = 0.5
-    block_1_quant_add_1 = 0.5
-
-    block_2_relu_1 = 0.5
-    block_2_relu_2 = 0.5
-    block_2_relu_3 = 0.5
-
-    block_2_weight_scale1 = 0.5
-    block_2_weight_scale2 = 0.5
-    block_2_weight_scale3 = 0.5
-    block_2_quant_add_1 = 0.5
-
-    block_0_combined_scale1 = -math.log2(
-        init_scale * block_0_weight_scale1 / block_0_relu_1
-    )  # RHS after first conv1x1 | clip 0-->255
-    block_0_combined_scale2 = -math.log2(
-        block_0_relu_1 * block_0_weight_scale2 / block_0_relu_2
-    )  # RHS after second conv3x3 | clip 0-->255
-    block_0_combined_scale3 = -math.log2(
-        block_0_relu_2 * block_0_weight_scale3 / init_scale
-    )  # RHS after third conv1x1 | clip -128-->+127
-    block_0_combined_scale_skip = -math.log2(
-        init_scale * block_0_weight_scale_skip / init_scale
-    )  # LHS after conv1x1 | clip -128-->+127
-    block_0_combined_scale4 = -math.log2(
-        init_scale / block_0_relu_3
-    )  # After addition | clip 0-->255
-
-    block_1_combined_scale1 = -math.log2(
-        block_0_relu_3 * block_1_weight_scale1 / block_1_relu_1
-    )  # RHS after first conv1x1 | clip 0-->255
-    block_1_combined_scale2 = -math.log2(
-        block_1_relu_1 * block_1_weight_scale2 / block_1_relu_2
-    )  # RHS after second conv3x3 | clip 0-->255
-    block_1_combined_scale3 = -math.log2(
-        block_1_relu_2 * block_1_weight_scale3 / block_1_quant_add_1
-    )  # RHS after third conv1x1 | clip -128-->+127
-    block_1_combined_scale4 = -math.log2(
-        block_1_quant_add_1 / block_1_relu_3
-    )  # After addition | clip 0-->255
+    # Data loader: a 256-image subset is enough for calibration and validation
+    indices = torch.arange(256)
+    tr_sub = data_utils.Subset(train_dataset, indices)
+    val_sub = data_utils.Subset(test_dataset, indices)
+    calib_loader = torch.utils.data.DataLoader(dataset=tr_sub, batch_size=64, shuffle=True)
+    val_loader = torch.utils.data.DataLoader(dataset=val_sub, batch_size=64, shuffle=False)
+    img_shape = 32
+    model_aie = preprocess_for_flexml_quantize(
+        model.aie,
+        torch.ones(1, 64, img_shape, img_shape),
+        equalize_iters=1000,
+        equalize_merge_bias=True,
+        merge_bn=True,
+    )
+
+    quant_model = quantize_model(
+        model_aie,
+        backend="flexml",
+        scale_factor_type="po2_scale",
+        bias_bit_width=32,
+        weight_bit_width=8,
+        weight_narrow_range=False,
+        weight_param_method="stats",
+        weight_quant_granularity="per_tensor",
+        weight_quant_type="sym",
+        layerwise_first_last_bit_width=8,
+        act_bit_width=8,
+        act_param_method="stats",
+        act_quant_percentile=99.999,
+        act_quant_type="sym",
+        quant_format="int",
+        layerwise_first_last_mantissa_bit_width=4,
+        layerwise_first_last_exponent_bit_width=3,
+        weight_mantissa_bit_width=4,
+        weight_exponent_bit_width=3,
+        act_mantissa_bit_width=4,
+        act_exponent_bit_width=3,
+    )
+
+    model.aie = quant_model
+    model.eval()
+    print("Starting post-training quantization:")
+    calibrate(calib_loader, model)
+    model.eval()
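+    # Illustration (comment only): with scale_factor_type="po2_scale" every
+    # scale is a power of two. Per-tensor symmetric int8 quantization of a
+    # weight tensor w then reduces to roughly:
+    #     scale = 2.0 ** math.ceil(math.log2(w.abs().max() / 127))
+    #     w_q   = torch.clamp(torch.round(w / scale), -128, 127)
+    # which is why every scale ratio used further below is itself 2**-k.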
+    device, dtype = (
+        next(model.parameters()).device,
+        next(model.parameters()).dtype,
+    )
+
+    # ------------------------------------------------------
+    # Export the integer weights and the power-of-two scales
+    # ------------------------------------------------------
+    from numpy import load
+
+    params = {}
+    weights = {}
+    for name, module in model.named_modules():
+        if isinstance(module, QuantConv2d):
+            weights[name + ".int_weight"] = module.quant_weight().int(
+                float_datatype=False
+            )
+            params[name + "_scale"] = module.quant_weight().scale.detach().numpy()
+        if isinstance(module, QuantIdentity):
+            params[name + "_scale"] = module.quant_act_scale()
+        if isinstance(module, QuantReLU):
+            params[name + "_scale"] = module.quant_act_scale()
+    np.savez(os.path.join(os.getcwd(), "int_weights.npz"), **weights)
+    np.savez(os.path.join(os.getcwd(), "int_conv_scale.npz"), **params)
+    int_wts_data = load("int_weights.npz", allow_pickle=True)
+    int_scale_data = load("int_conv_scale.npz", allow_pickle=True)
+
+    int_wts_data_lst = int_wts_data.files
+    block_0_int_weight_1 = torch.from_numpy(int_wts_data["aie.layer1.conv1.int_weight"])
+    block_0_int_weight_2 = torch.from_numpy(int_wts_data["aie.layer1.conv2.int_weight"])
+    block_0_int_weight_3 = torch.from_numpy(int_wts_data["aie.layer1.conv3.int_weight"])
+    block_0_int_weight_skip = torch.from_numpy(
+        int_wts_data["aie.layer1.shortcut.0.int_weight"]
+    )
+
+    block_1_int_weight_1 = torch.from_numpy(int_wts_data["aie.layer2.conv1.int_weight"])
+    block_1_int_weight_2 = torch.from_numpy(int_wts_data["aie.layer2.conv2.int_weight"])
+    block_1_int_weight_3 = torch.from_numpy(int_wts_data["aie.layer2.conv3.int_weight"])
+
+    block_2_int_weight_1 = torch.from_numpy(int_wts_data["aie.layer3.conv1.int_weight"])
+    block_2_int_weight_2 = torch.from_numpy(int_wts_data["aie.layer3.conv2.int_weight"])
+    block_2_int_weight_3 = torch.from_numpy(int_wts_data["aie.layer3.conv3.int_weight"])
+
+    int_scale_data_lst = int_scale_data.files
+
+    init_scale = int_scale_data["aie.x_quant_scale"]
+    block_0_relu_1 = int_scale_data["aie.layer1.relu1_scale"]
+    block_0_relu_2 = int_scale_data["aie.layer1.relu2_scale"]
+    block_0_relu_3 = int_scale_data["aie.layer1.relu3_scale"]
+    block_0_add_scale = int_scale_data["aie.add_quant_scale"]
+
+    block_0_weight_scale_1 = int_scale_data["aie.layer1.conv1_scale"]
+    block_0_weight_scale_2 = int_scale_data["aie.layer1.conv2_scale"]
+    block_0_weight_scale_3 = int_scale_data["aie.layer1.conv3_scale"]
+    block_0_weight_scale_skip = int_scale_data["aie.layer1.shortcut.0_scale"]
+
+    block_1_relu_1 = int_scale_data["aie.layer2.relu1_scale"]
+    block_1_relu_2 = int_scale_data["aie.layer2.relu2_scale"]
+    block_1_relu_3 = int_scale_data["aie.layer2.relu3_scale"]
+    block_1_add_scale = int_scale_data["aie.add_1_quant_scale"]
+
+    block_1_weight_scale_1 = int_scale_data["aie.layer2.conv1_scale"]
+    block_1_weight_scale_2 = int_scale_data["aie.layer2.conv2_scale"]
+    block_1_weight_scale_3 = int_scale_data["aie.layer2.conv3_scale"]
+
+    block_2_relu_1 = int_scale_data["aie.layer3.relu1_scale"]
+    block_2_relu_2 = int_scale_data["aie.layer3.relu2_scale"]
+    block_2_relu_3 = int_scale_data["aie.layer3.relu3_scale"]
+    block_2_add_scale = int_scale_data["aie.add_2_quant_scale"]
+
+    block_2_weight_scale_1 = int_scale_data["aie.layer3.conv1_scale"]
+    block_2_weight_scale_2 = int_scale_data["aie.layer3.conv2_scale"]
+    block_2_weight_scale_3 = int_scale_data["aie.layer3.conv3_scale"]
+
+    # Zero out all bias terms; the offloaded conv kernels do not apply a bias.
+    for name, param in model.named_parameters():
+        if name.endswith(".bias"):
+            param.data.fill_(0)
+
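+    # Sanity check (illustration; uncomment to verify the export above):
+    #     conv = model.aie.layer1.conv1
+    #     w_int = conv.quant_weight().int(float_datatype=True)
+    #     torch.testing.assert_close(
+    #         w_int * conv.quant_weight().scale, conv.quant_weight().value
+    #     )
+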
+    block_0_combined_scale1 = -math.log(
+        init_scale * block_0_weight_scale_1 / block_0_relu_1, 2
+    )  # RHS after first conv1x1 | clip 0-->255
+    block_0_combined_scale2 = -math.log(
+        block_0_relu_1 * block_0_weight_scale_2 / block_0_relu_2, 2
+    )  # RHS after second conv3x3 | clip 0-->255
+    block_0_combined_scale3 = -math.log(
+        block_0_relu_2 * block_0_weight_scale_3 / block_0_add_scale, 2
+    )  # RHS after third conv1x1 | clip -128-->+127
+    block_0_combined_scale4 = -math.log(
+        block_0_add_scale / block_0_relu_3, 2
+    )  # after the skip addition | clip 0-->255
+    block_0_combined_scale_skip = -math.log(
+        init_scale * block_0_weight_scale_skip / block_0_add_scale, 2
+    )  # LHS (shortcut) conv1x1 | clip -128-->+127
+
+    block_1_combined_scale1 = -math.log(
+        block_0_relu_3 * block_1_weight_scale_1 / block_1_relu_1, 2
+    )  # RHS after first conv1x1 | clip 0-->255
+    block_1_combined_scale2 = -math.log(
+        block_1_relu_1 * block_1_weight_scale_2 / block_1_relu_2, 2
+    )  # RHS after second conv3x3 | clip 0-->255
+    block_1_combined_scale3 = -math.log(
+        block_1_relu_2 * block_1_weight_scale_3 / block_1_add_scale, 2
+    )  # RHS after third conv1x1 | clip -128-->+127
+    block_1_combined_scale4 = -math.log(
+        block_1_add_scale / block_1_relu_3, 2
+    )  # after the skip addition | clip 0-->255
+
-    block_2_combined_scale1 = -math.log2(
-        block_1_relu_3 * block_2_weight_scale1 / block_2_relu_1
-    )  # RHS after first conv1x1 | clip 0-->255
-    block_2_combined_scale2 = -math.log2(
-        block_2_relu_1 * block_2_weight_scale2 / block_2_relu_2
-    )  # RHS after second conv3x3 | clip 0-->255
-    block_2_combined_scale3 = -math.log2(
-        block_2_relu_2 * block_2_weight_scale3 / block_2_quant_add_1
-    )  # RHS after third conv1x1 | clip -128-->+127
-    block_2_combined_scale4 = -math.log2(
-        block_2_quant_add_1 / block_2_relu_3
-    )  # After addition | clip 0-->255
+    block_2_combined_scale1 = -math.log(
+        block_1_relu_3 * block_2_weight_scale_1 / block_2_relu_1, 2
+    )  # RHS after first conv1x1 | clip 0-->255
+    block_2_combined_scale2 = -math.log(
+        block_2_relu_1 * block_2_weight_scale_2 / block_2_relu_2, 2
+    )  # RHS after second conv3x3 | clip 0-->255
+    block_2_combined_scale3 = -math.log(
+        block_2_relu_2 * block_2_weight_scale_3 / block_2_add_scale, 2
+    )  # RHS after third conv1x1 | clip -128-->+127
+    block_2_combined_scale4 = -math.log(
+        block_2_add_scale / block_2_relu_3, 2
+    )  # after the skip addition | clip 0-->255
-    min = 0
-    max = 255
-
-    # ------------------------------------------------------
+
+    print("--------------------------------------------------------------")
+    print("Block0 combined_scale after first conv1x1:", block_0_combined_scale1)
+    print("Block0 combined_scale after second conv3x3:", block_0_combined_scale2)
+    print("Block0 combined_scale after third conv1x1:", block_0_combined_scale3)
+    print("Block0 combined_scale after adding skip connection:", block_0_combined_scale4)
+    print("Block0 combined_scale after skip conv1x1:", block_0_combined_scale_skip)
+    print("--------------------------------------------------------------")
+    print("Block1 combined_scale after first conv1x1:", block_1_combined_scale1)
+    print("Block1 combined_scale after second conv3x3:", block_1_combined_scale2)
+    print("Block1 combined_scale after third conv1x1:", block_1_combined_scale3)
+    print("Block1 combined_scale after adding skip connection:", block_1_combined_scale4)
+    print("--------------------------------------------------------------")
+    print("Block2 combined_scale after first conv1x1:", block_2_combined_scale1)
+    print("Block2 combined_scale after second conv3x3:", block_2_combined_scale2)
+    print("Block2 combined_scale after third conv1x1:", block_2_combined_scale3)
+    print("Block2 combined_scale after adding skip connection:", block_2_combined_scale4)
+    print("------------------------------------------------------------------")
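+    # Worked example (comment only): because every scale is a power of two,
+    # each combined scale is an exact integer shift. E.g. if
+    # init_scale = 2**-5, weight_scale = 2**-7 and relu_scale = 2**-4, then
+    # -math.log(2**-5 * 2**-7 / 2**-4, 2) == 8.0, and integers like that are
+    # what the RTP writes in aie2.py program into the kernels.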
+
+    # ------------------------------------------------------
     # Get device, load the xclbin & kernel and register them
     # ------------------------------------------------------
     app = setup_aie(
@@ -164,231 +310,10 @@ def main(opts):
         enable_trace=enable_trace,
         trace_size=trace_size,
     )
-
-    # ------------------------------------------------------
-    # Define your golden reference
-    # ------------------------------------------------------
-    class resnet_conv2_x_int8(nn.Module):
-        expansion = 4
-
-        def __init__(self, in_planes=64, planes=64):
-            super(resnet_conv2_x_int8, self).__init__()
-
-            self.shortcut = nn.Conv2d(
-                in_planes, self.expansion * planes, kernel_size=1, bias=False
-            )
-            # Bottleneck 0
-            self.block_0_conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
-            self.block_0_conv2 = nn.Conv2d(
-                planes,
-                planes,
-                kernel_size=3,
-                padding=1,
-                padding_mode="zeros",
-                bias=False,
-            )
-            self.block_0_conv3 = nn.Conv2d(
-                planes, self.expansion * planes, kernel_size=1, bias=False
-            )
-
-            self.block_0_relu1 = nn.ReLU()
-            self.block_0_relu2 = nn.ReLU()
-            self.block_0_relu3 = nn.ReLU()
-
-            # Bottleneck 1
-            self.block_1_conv1 = nn.Conv2d(
-                self.expansion * planes, planes, kernel_size=1, bias=False
-            )
-            self.block_1_conv2 = nn.Conv2d(
-                planes,
-                planes,
-                kernel_size=3,
-                padding=1,
-                padding_mode="zeros",
-                bias=False,
-            )
-            self.block_1_conv3 = nn.Conv2d(
-                planes, self.expansion * planes, kernel_size=1, bias=False
-            )
-
-            self.block_1_relu1 = nn.ReLU()
-            self.block_1_relu2 = nn.ReLU()
-            self.block_1_relu3 = nn.ReLU()
-
-            # Bottleneck 2
-            self.block_2_conv1 = nn.Conv2d(
-                self.expansion * planes, planes, kernel_size=1, bias=False
-            )
-            self.block_2_conv2 = nn.Conv2d(
-                planes,
-                planes,
-                kernel_size=3,
-                padding=1,
-                padding_mode="zeros",
-                bias=False,
-            )
-            self.block_2_conv3 = nn.Conv2d(
-                planes, self.expansion * planes, kernel_size=1, bias=False
-            )
-
-            self.block_2_relu1 = nn.ReLU()
-            self.block_2_relu2 = nn.ReLU()
-            self.block_2_relu3 = nn.ReLU()
-
-        def forward(self, x):
-            # **************** Bottleneck 0 ****************
-            block_0_conv1_out = (
-                self.block_0_conv1(x) * init_scale * block_0_weight_scale1
-            )
-            block_0_relu1_out = torch.clamp(
-                torch.round(self.block_0_relu1(block_0_conv1_out) / block_0_relu_1),
-                min,
-                max,
-            )  # convert to int and apply relu
-            block_0_conv2_out = (
-                self.block_0_conv2(block_0_relu1_out)
-                * block_0_relu_1
-                * block_0_weight_scale2
-            )
-            block_0_relu2_out = torch.clamp(
-                torch.round(self.block_0_relu2(block_0_conv2_out) / block_0_relu_2),
-                min,
-                max,
-            )
-            block_0_conv3_out = (
-                self.block_0_conv3(block_0_relu2_out)
-                * block_0_relu_2
-                * block_0_weight_scale3
-            )
-            block_0_rhf_same_scale = torch.clamp(
-                torch.round(block_0_conv3_out / init_scale), -128, 127
-            )
-
-            block_0_lhs_conv = self.shortcut(x) * init_scale * block_0_weight_scale_skip
-            block_0_lhs_same_scale = torch.clamp(
-                torch.round(block_0_lhs_conv / init_scale), -128, 127
-            )
-
-            block_0_skip_add = init_scale * (
-                block_0_rhf_same_scale + block_0_lhs_same_scale
-            )
-            block_0_final_out = torch.clamp(
-                torch.round(self.block_0_relu3(block_0_skip_add) / block_0_relu_3),
-                min,
-                max,
-            )
-            # **************** Bottleneck 1 ****************
-            block_1_conv1_out = (
-                self.block_1_conv1(block_0_final_out)
-                * block_0_relu_3
-                * block_1_weight_scale1
-            )
-            block_1_relu1_out = torch.clamp(
-                torch.round(self.block_1_relu1(block_1_conv1_out) / block_1_relu_1),
-                min,
-                max,
-            )  # convert to int and apply relu
-            block_1_conv2_out = (
-                self.block_1_conv2(block_1_relu1_out)
-                * block_1_relu_1
-                * block_1_weight_scale2
-            )
-            block_1_relu2_out = torch.clamp(
-                torch.round(self.block_1_relu2(block_1_conv2_out) / block_1_relu_2),
-                min,
-                max,
-            )
-            block_1_conv3_out = (
-                self.block_1_conv3(block_1_relu2_out)
-                * block_1_relu_2
-                * block_1_weight_scale3
-            )
-            block_1_rhf_same_scale = torch.clamp(
-                torch.round(block_1_conv3_out / block_0_relu_3), -128, 127
-            )
-
-            block_1_skip_add = block_0_relu_3 * (
-                block_1_rhf_same_scale + block_0_final_out
-            )
-            block_1_final_out = torch.clamp(
-                torch.round(self.block_1_relu3(block_1_skip_add) / block_1_relu_3),
-                min,
-                max,
-            )
-
-            # **************** Bottleneck 2 ****************
-            block_2_conv1_out = (
-                self.block_2_conv1(block_1_final_out)
-                * block_1_relu_3
-                * block_2_weight_scale1
-            )
-            block_2_relu1_out = torch.clamp(
-                torch.round(self.block_2_relu1(block_2_conv1_out) / block_2_relu_1),
-                min,
-                max,
-            )  # convert to int and apply relu
-            block_2_conv2_out = (
-                self.block_2_conv2(block_2_relu1_out)
-                * block_2_relu_1
-                * block_2_weight_scale2
-            )
-            block_2_relu2_out = torch.clamp(
-                torch.round(self.block_2_relu2(block_2_conv2_out) / block_2_relu_2),
-                min,
-                max,
-            )
-            block_2_conv3_out = (
-                self.block_2_conv3(block_2_relu2_out)
-                * block_2_relu_2
-                * block_2_weight_scale3
-            )
-            block_2_rhf_same_scale = torch.clamp(
-                torch.round(block_2_conv3_out / block_1_relu_3), -128, 127
-            )
-
-            block_2_skip_add = block_1_relu_3 * (
-                block_2_rhf_same_scale + block_1_final_out
-            )
-            block_2_final_out = block_2_relu_3 * (
-                torch.clamp(
-                    torch.round(self.block_2_relu3(block_2_skip_add) / block_2_relu_3),
-                    min,
-                    max,
-                )
-            )
-            return block_2_final_out
-
-    # ------------------------------------------------------
-    # Pytorch baseline
-    # ------------------------------------------------------
-    model = resnet_conv2_x_int8()
-    model.eval()
-    model.block_0_conv1.weight.data.copy_(block_0_int_weight_1)
-    model.block_0_conv2.weight.data.copy_(block_0_int_weight_2)
-    model.block_0_conv3.weight.data.copy_(block_0_int_weight_3)
-    model.shortcut.weight.data.copy_(block_0_int_weight_skip)
-
-    model.block_1_conv1.weight.data.copy_(block_1_int_weight_1)
-    model.block_1_conv2.weight.data.copy_(block_1_int_weight_2)
-    model.block_1_conv3.weight.data.copy_(block_1_int_weight_3)
-
-    model.block_2_conv1.weight.data.copy_(block_2_int_weight_1)
-    model.block_2_conv2.weight.data.copy_(block_2_int_weight_2)
-    model.block_2_conv3.weight.data.copy_(block_2_int_weight_3)
-
-    golden_output = model(int_inp)
-
     # ------------------------------------------------------
     # Reorder input data-layout
     # ------------------------------------------------------
     ds = DataShaper()
-    before_input = int_inp.squeeze().data.numpy().astype(dtype_in)
-    before_input.tofile(
-        log_folder + "/before_ifm_mem_fmt_1x1.txt", sep=",", format="%d"
-    )
-    ifm_mem_fmt = ds.reorder_mat(before_input, "YCXC8", "CYX")
-    ifm_mem_fmt.tofile(log_folder + "/after_ifm_mem_fmt_1x1.txt", sep=",", format="%d")
 
     block0_wts1 = ds.reorder_mat(
         block_0_int_weight_1.data.numpy().astype(dtype_wts), "OIYXI8O8", "OIYX"
     )
@@ -437,51 +362,105 @@ def main(opts):
 
     total_wts3.tofile(log_folder + "/weights_mem_fmt_final.txt", sep=",", format="%d")
 
-    # ------------------------------------------------------
-    # Main run loop
-    # ------------------------------------------------------
-    for i in range(num_iter):
-        start = time.time_ns()
-        aie_output = execute(app, ifm_mem_fmt, total_wts3) * block_2_relu_3
-        stop = time.time_ns()
-
-        if enable_trace:
-            aie_output, trace = extract_trace(
-                aie_output, shape_out, dtype_out, trace_size
-            )
-            write_out_trace(trace, trace_file)
-
-        npu_time = stop - start
-        npu_time_total = npu_time_total + npu_time
-
+    import cv2
+
+    predicted_label = [None] * 64
+    cpu_predicted_label = [None] * 64
+    aie_time = [None] * 64
+    metafile = r"./data/cifar-10-batches-py/batches.meta"
+    datafile = r"./data/cifar-10-batches-py/test_batch"
+    data_batch_1 = unpickle(datafile)
+    metadata = unpickle(metafile)
+    images = data_batch_1["data"]
+    labels = data_batch_1["labels"]
+    images = np.reshape(images, (10000, 3, 32, 32))
+    dirname = "cifar_images"
+    if not os.path.exists(dirname):
+        os.mkdir(dirname)
+
+    # Extract and dump the first 100 test images as PNGs
+    for i in range(0, 100):
+        im = images[i]
+        im = im.transpose(1, 2, 0)
+        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+        im_name = f"./cifar_images/image_{i}.png"
+        cv2.imwrite(im_name, im)
+
+    label_path = "data/cifar10_label_map.txt"
+    model_num_classes = 10
+    class_label_map = load_class_label(label_path, model_num_classes)
+    quant_id_1 = QuantIdentity(
+        act_quant=Uint8ActPerTensorFixedPoint, bit_width=8, return_quant_tensor=True
+    )
+    quant_id_1.eval()
 
     # ------------------------------------------------------
-    # Reorder output data-layout
+    # Main run loop
     # ------------------------------------------------------
-    temp_out = aie_output.reshape(32, 32, 32, 8)
-    temp_out = ds.reorder_mat(temp_out, "CDYX", "YCXD")
-    ofm_mem_fmt = temp_out.reshape(256, 32, 32)
-    ofm_mem_fmt.tofile(
-        log_folder + "/after_ofm_mem_fmt_final.txt", sep=",", format="%d"
-    )
-    ofm_mem_fmt_out = torch.from_numpy(ofm_mem_fmt).unsqueeze(0)
+    for i in range(0, 64):
+        print(f"____________________________________IMAGE {i}____________________________________________")
+        image_name = f"./cifar_images/image_{i}.png"
+        img = Image.open(image_name)
+        input_tensor = transform_test(img)
+        input_batch = input_tensor.unsqueeze(0)
+        with torch.no_grad():
+            output1 = model.first(input_batch)
+
+            # AIE offload: quantize the activation, reorder, and run on the NPU
+            qnt_inp = model.aie.x_quant(output1)
+            int_inp = model.aie.x_quant(output1).int(float_datatype=True)
+            before_input = int_inp.squeeze().data.numpy().astype(dtype_in)
+            ifm_mem_fmt = ds.reorder_mat(before_input, "YCXC8", "CYX")
+            start = time.time_ns()
+            aie_output = execute(app, ifm_mem_fmt, total_wts3) * block_2_relu_3
+            stop = time.time_ns()
+            temp_out = aie_output.reshape(32, 32, 32, 8)
+            temp2_out = ds.reorder_mat(temp_out, "CDYX", "YCXD")
+            ofm_mem_fmt = temp2_out.reshape(256, 32, 32)
+            ofm_mem_fmt = torch.from_numpy(ofm_mem_fmt).unsqueeze(0)
+            final_output_aie = model.post(ofm_mem_fmt)
+
+            # CPU baseline output for functional correctness
+            output_golden = model.aie(output1)
+            max_error = torch.max(torch.abs(ofm_mem_fmt - output_golden))
+            final_output_base = model.post(output_golden)
+            predicted_class = np.argmax(final_output_aie)
+            predicted_label[i] = metadata["label_names"][predicted_class]
+            cpu_predicted_class = np.argmax(final_output_base)
+            cpu_predicted_label[i] = metadata["label_names"][cpu_predicted_class]
+            label = metadata["label_names"][labels[i]]
+            print(
+                f" Predicted AIE: {predicted_label[i]}, Predicted CPU: {cpu_predicted_label[i]}"
+            )
+
+            # The five categories with the highest classification probability
+            prediction_class_index = (
+                torch.topk(final_output_aie, k=5, sorted=True).indices.squeeze(0).tolist()
+            )
+            golden_prediction_class_index = (
+                torch.topk(final_output_base, k=5, sorted=True).indices.squeeze(0).tolist()
+            )
+            npu_time = stop - start
+            npu_time_total = npu_time_total + npu_time
 
     # ------------------------------------------------------
     # Compare the AIE output and the golden reference
     # ------------------------------------------------------
-    print("\nAvg NPU time: {}us.".format(int((npu_time_total / num_iter) / 1000)))
-
-    if np.allclose(
-        ofm_mem_fmt_out.detach().numpy(),
-        golden_output.detach().numpy(),
-        rtol=0,
-        atol=block_2_relu_3,
-    ):
-        print("\nPASS!\n")
-        exit(0)
-    else:
-        print("\nFailed.\n")
-        exit(-1)
+    print("\nAvg NPU time: {}us.".format(int((npu_time_total / 64) / 1000)))
+    for x, y in zip(predicted_label, cpu_predicted_label):
+        if x != y:
+            print("\nFailed.\n")
+            exit(-1)
+    print("\nPASS!\n")
+    exit(0)
 
 
 if __name__ == "__main__":
     p = test_utils.create_default_argparser()
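The pass/fail loop above exits on the first AIE/CPU label mismatch. A small helper of the following shape (hypothetical, not part of the example) would summarize agreement over the 64-image subset instead:

    def agreement(aie_labels, cpu_labels):
        # Fraction of images where AIE and CPU predict the same class.
        matches = sum(a == c for a, c in zip(aie_labels, cpu_labels))
        return matches / len(aie_labels)

    # agreement(predicted_label, cpu_predicted_label) -> 1.0 on a passing run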
diff --git a/programming_examples/ml/resnet/ptq_conv2x/utils.py b/programming_examples/ml/resnet/ptq_conv2x/utils.py
new file mode 100644
index 0000000000..21a12f45c7
--- /dev/null
+++ b/programming_examples/ml/resnet/ptq_conv2x/utils.py
@@ -0,0 +1,40 @@
+import json
+import os
+import pickle
+
+import cv2
+import numpy as np
+
+
+def unpickle(file):
+    """Load one pickled CIFAR-10 batch file."""
+    with open(file, "rb") as fo:
+        batch = pickle.load(fo, encoding="latin1")
+    return batch
+
+
+def load_class_label(class_label_file: str, num_classes: int) -> list:
+    """Read the JSON index-to-name map and return the names in index order."""
+    class_label = json.load(open(class_label_file))
+    class_label_list = [class_label[str(i)] for i in range(num_classes)]
+    return class_label_list
+
+
+def extract_cifar(metafile, datafile, count=100):
+    """Dump the first `count` images of a CIFAR-10 batch as PNG files."""
+    data_batch_1 = unpickle(datafile)
+    metadata = unpickle(metafile)
+
+    images = data_batch_1["data"]
+    labels = data_batch_1["labels"]
+    images = np.reshape(images, (10000, 3, 32, 32))
+
+    dirname = "cifar_images"
+    if not os.path.exists(dirname):
+        os.mkdir(dirname)
+
+    for i in range(0, count):
+        im = images[i]
+        im = im.transpose(1, 2, 0)
+        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+        im_name = f"./cifar_images/image_{i}.png"
+        cv2.imwrite(im_name, im)
\ No newline at end of file
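For reference, a minimal use of the helpers above (a sketch; the paths assume torchvision has already downloaded CIFAR-10 into ./data, as test.py arranges):

    from utils import extract_cifar, load_class_label, unpickle

    class_names = load_class_label("data/cifar10_label_map.txt", 10)
    batch = unpickle("data/cifar-10-batches-py/test_batch")
    print(class_names[batch["labels"][0]])  # ground-truth name of test image 0
    extract_cifar(
        "data/cifar-10-batches-py/batches.meta",
        "data/cifar-10-batches-py/test_batch",
    )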