diff --git a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp index 62729256dac23..cbec5d89bbac7 100644 --- a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp +++ b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp @@ -83,6 +83,8 @@ Return Value: #endif + int countb = 0; + do { float BElements00; @@ -116,6 +118,7 @@ Return Value: // const float* a = A; + const float* b = B; size_t k = CountK; while (k >= 2) { @@ -128,10 +131,10 @@ Return Value: Row1AElements1 = a[lda + 1]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -144,10 +147,10 @@ Return Value: Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; + BElements00 = b[16]; + BElements01 = b[17]; + BElements02 = b[18]; + BElements03 = b[19]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; @@ -161,7 +164,7 @@ Return Value: } a += 2; - B += 8; + b += 32; k -= 2; } @@ -173,10 +176,10 @@ Return Value: Row1AElements0 = a[lda]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -188,8 +191,6 @@ Return Value: Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - - B += 4; } // @@ -295,9 +296,14 @@ Return Value: break; } + B += 4; C += 4; CountN -= 4; + countb = (countb + 1) % 4; + if (countb == 0) { + B += CountK * 16 - 16; + } } while (CountN > 0); return ProcessTwoRows ? 2 : 1;