diff --git a/src/sort/heap_sort.cpp b/src/sort/heap_sort.cpp index 0baacc8..143ffdd 100644 --- a/src/sort/heap_sort.cpp +++ b/src/sort/heap_sort.cpp @@ -3,8 +3,40 @@ #include #include +template void __shiftDown(T arr[], int n, int k) { + while (2 * k + 1 < n) { + int j = 2 * k + 1; // 在此轮循环中,arr[k]和arr[j]交换位置 + if (j + 1 < n && arr[j + 1] > arr[j]) { + j += 1; + } + if (arr[k] >= arr[j]) { + break; + } + swap(arr[k], arr[j]); + k = j; + } +} + +template void heapSort(T arr[], int n) { + // heapify + // 从(最后一个元素的索引-1)/2开始 + // 最后一个元素的索引 = n-1 + for (int i = (n - 1 - 1) / 2; i >= 0; i--) { + __shiftDown(arr, n, i); + } + + for (int i = n - 1; i > 0; i--) { + swap(arr[0], arr[i]); + __shiftDown(arr, i, 0); + } +} + class Solution { public: + vector heapSort(vector &nums) { + ::heapSort(&nums[0], nums.size()); + return nums; + } /** * @brief 堆排序 * nlogn @@ -32,7 +64,7 @@ class Solution { * @param nums * @return vector */ - vector heapSort(vector &nums) { + vector heapSort2(vector &nums) { int n = nums.size(); int arr[n]; for (int i = 0; i < n; i++) { diff --git a/test/lib/lib_test.cpp b/test/lib/lib_test.cpp index d5b2a62..47f80f9 100644 --- a/test/lib/lib_test.cpp +++ b/test/lib/lib_test.cpp @@ -1,4 +1,4 @@ -// 执行编译时间:2023-09-17 14:49:28 +// 执行编译时间:2023-09-17 15:52:26 #include #include diff --git a/test/sort/heap_sort_test.cpp b/test/sort/heap_sort_test.cpp index 86ca888..c9e03f3 100644 --- a/test/sort/heap_sort_test.cpp +++ b/test/sort/heap_sort_test.cpp @@ -11,5 +11,12 @@ TEST(堆排序, heapSort1) { Solution solution; vector nums = {2, 1, 3, 5, 4}; vector actionVec = {1, 2, 3, 4, 5}; - EXPECT_THAT(solution.heapSort(nums), ::testing::ContainerEq(actionVec)); + EXPECT_THAT(solution.heapSort1(nums), ::testing::ContainerEq(actionVec)); +} + +TEST(堆排序, heapSort2) { + Solution solution; + vector nums = {2, 1, 3, 5, 4}; + vector actionVec = {1, 2, 3, 4, 5}; + EXPECT_THAT(solution.heapSort2(nums), ::testing::ContainerEq(actionVec)); } \ No newline at end of file