Poor Speculative Decoding Performance on M2 Ultra #1281

Open
mattjcly opened this issue Feb 12, 2025 · 2 comments

Comments

@mattjcly
Contributor

Speculative decoding does not seem to improve generation speed as expected on an M2 Ultra Mac Studio with 128 GB of RAM.

Main model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit
Draft model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit or https://huggingface.co/mlx-community/Qwen2.5-0.5B-Instruct-4bit

Prompt: "Write a quicksort algorithm"
Without spec decoding: 29.803 tokens-per-sec
With spec decoding: 29.051 tokens-per-sec
Qwen2.5-Coder-0.5B-Instruct-MLX-4bit (draft model) alone: 284.647 tokens-per-sec

In the same setup on an M3 Pro with 32 GB of RAM, we see a tremendous speedup (~7 tok/sec -> ~16 tok/sec).
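For reference, the throughputs above translate into the speedup ratios below. This is plain arithmetic on the numbers reported in this issue. The M3 Pro values are the approximate ~7 and ~16 tok/sec figures, and the helper name `speedup` is only for illustration.

```python
# Generation throughputs reported in this issue, in tokens per second.
m2_ultra_baseline = 29.803     # 32B main model alone
m2_ultra_speculative = 29.051  # 32B main model + 0.5B draft model
draft_alone = 284.647          # 0.5B draft model alone
m3_pro_baseline = 7.0          # approximate ("~7tok/sec")
m3_pro_speculative = 16.0      # approximate ("~16tok/sec")


def speedup(baseline: float, speculative: float) -> float:
    """Throughput ratio; values above 1.0 mean speculative decoding helped."""
    return speculative / baseline


# How much faster the draft model generates than the main model on M2 Ultra.
draft_vs_main = draft_alone / m2_ultra_baseline

print(f"M2 Ultra speedup: {speedup(m2_ultra_baseline, m2_ultra_speculative):.2f}x")
print(f"M3 Pro speedup:   {speedup(m3_pro_baseline, m3_pro_speculative):.2f}x")
print(f"Draft vs. main throughput on M2 Ultra: {draft_vs_main:.1f}x")
```

On the M2 Ultra, the draft model runs roughly an order of magnitude faster than the main model. That suggests drafting alone is unlikely to account for the missing speedup, although the standalone draft run produced a different (shorter) output.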

Full logs:


(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-4bit -m 1000 --temp 0

==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]  # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot]  # Elements less than the pivot
        middle = [x for x in arr if x == pivot]  # Elements equal to the pivot
        right = [x for x in arr if x > pivot]  # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)

Explanation:

  1. Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
  2. Pivot Selection: We choose the middle element of the array as the pivot.
  3. Partitioning: We create three lists:
    • left for elements less than the pivot.
    • middle for elements equal to the pivot.
    • right for elements greater than the pivot.
  4. Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.

This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:

def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)  # Partitioning index
        quicksort_inplace(arr, low, pi - 1)  # Sort left part
        quicksort_inplace(arr, pi + 1, high)  # Sort right part

def partition(arr, low, high):
    pivot = arr[high]  # Choose the last element as the pivot
    i = low - 1  # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]  # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1]  # Swap pivot element
    return i + 1

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)

In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.

==========
Prompt: 34 tokens, 71.386 tokens-per-sec
Generation: 709 tokens, 29.051 tokens-per-sec
Peak memory: 18.932 GB
(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]  # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot]  # Elements less than the pivot
        middle = [x for x in arr if x == pivot]  # Elements equal to the pivot
        right = [x for x in arr if x > pivot]  # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)

Explanation:

  1. Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
  2. Pivot Selection: We choose the middle element of the array as the pivot.
  3. Partitioning: We create three lists:
    • left for elements less than the pivot.
    • middle for elements equal to the pivot.
    • right for elements greater than the pivot.
  4. Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.

This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:

def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)  # Partitioning index
        quicksort_inplace(arr, low, pi - 1)  # Sort left part
        quicksort_inplace(arr, pi + 1, high)  # Sort right part

def partition(arr, low, high):
    pivot = arr[high]  # Choose the last element as the pivot
    i = low - 1  # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]  # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1]  # Swap pivot element
    return i + 1

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)

In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.

==========
Prompt: 34 tokens, 75.790 tokens-per-sec
Generation: 709 tokens, 29.803 tokens-per-sec
Peak memory: 18.643 GB
(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

==========
Sure, here's a simple implementation of the quicksort algorithm in Python:

def quicksort(arr):
    # Base case: if the array is empty or has one element, it's already sorted
    if len(arr) <= 1:
        return arr
    
    # Choose a pivot element
    pivot = arr[len(arr) // 2]
    
    # Partition the array into two sub-arrays: elements less than or equal to the pivot and elements greater than or equal to the pivot
    less_than_pivot = [x for x in arr if x <= pivot]
    greater_than_pivot = [x for x in arr if x > pivot]
    
    # Recursively sort the two sub-arrays
    quicksort(less_than_pivot)
    quicksort(greater_than_pivot)
    
    # Merge the sorted sub-arrays
    return less_than_pivot + [pivot] + greater_than_pivot

This function takes an array as input and returns a new array sorted in ascending order. It uses a simple partitioning strategy: it selects a pivot element and partitions the array into two sub-arrays: all elements less than or equal to the pivot and all elements greater than or equal to the pivot. The function then recursively sorts the two sub-arrays and merges them to form the sorted array.

==========
Prompt: 34 tokens, 683.032 tokens-per-sec
Generation: 276 tokens, 284.647 tokens-per-sec
Peak memory: 0.299 GB

@awni
Member

awni commented Feb 12, 2025

I ran a couple of benchmarks on M3 Max and M2 Ultra. As expected, the big model's runtime scales much better with sequence length on M3 Max than on M2 Ultra. This probably explains why we see little to no performance improvement on M2 Ultra.

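The figure below shows time as the sequence length increases. For the best possible speedup with speculative generation, the line should be as flat as possible. The main model verifies all draft tokens in a single forward pass, so that pass should cost about the same as generating a single token.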

[Figure: time vs. sequence length for the main model on M3 Max and M2 Ultra]
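A rough back-of-the-envelope model makes the flatness argument concrete. This sketch follows the simple analysis in which each draft token is accepted independently with probability `alpha`. It adds a hypothetical `verify_slope` term for how much the main model's verify pass slows down per extra token. The values `alpha=0.8` and `num_draft=3` and the slope values are made-up illustration parameters, not measurements, and none of this is mlx_lm code.

```python
def expected_tokens_per_step(alpha: float, num_draft: int) -> float:
    """Expected tokens emitted per draft-then-verify step.

    Assumes each draft token is accepted independently with probability alpha.
    The main model always contributes one token (a correction or a bonus token).
    """
    if alpha >= 1.0:
        return float(num_draft + 1)
    return (1.0 - alpha ** (num_draft + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, num_draft: int, draft_cost: float,
                      verify_slope: float) -> float:
    """Speedup over plain decoding, with cost measured in main-model single-token passes.

    draft_cost:   time of one draft-model token / time of one main-model token.
    verify_slope: hypothetical extra cost (in single-token passes) per additional
                  token in the verify pass; 0.0 means a perfectly flat curve.
    """
    # The verify pass covers num_draft + 1 tokens: num_draft more than a normal step.
    verify_cost = 1.0 + verify_slope * num_draft
    step_cost = num_draft * draft_cost + verify_cost
    return expected_tokens_per_step(alpha, num_draft) / step_cost


if __name__ == "__main__":
    draft_cost = 29.803 / 284.647  # rough ratio from the throughputs in this issue
    for slope in (0.0, 0.25, 0.5, 0.75):
        s = estimated_speedup(alpha=0.8, num_draft=3,
                              draft_cost=draft_cost, verify_slope=slope)
        print(f"verify_slope={slope:.2f}: ~{s:.2f}x")
```

With a flat curve (`verify_slope=0.0`), the estimate lands well above 1x. A steep enough slope pushes it below 1x, which would be consistent with the near-1x result seen on M2 Ultra.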

@awni
Member

awni commented Feb 12, 2025

On the optimistic side, from conversations with @angeloskath and @barronalex, there is likely room to improve small-batch qmm (quantized matrix multiplication), which should help this use case considerably.
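Here is what "small batch qmm" means in practice. During verification, the main model multiplies a handful of token rows (the draft tokens plus one) by quantized weights. The sketch below is a naive pure-Python reference of that arithmetic. It assumes MLX-style affine group-wise quantization (`w ≈ scale * q + bias` per group). The function names, the tiny `group_size=4`, and the loop structure are illustrative assumptions, not how MLX's Metal kernels are written.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_rows(w: Matrix, bits: int = 4,
                  group_size: int = 4) -> Tuple[List[List[int]], Matrix, Matrix]:
    """Affine group-wise quantization along each weight row: w ~= scale * q + bias."""
    levels = (1 << bits) - 1
    q_rows, scale_rows, bias_rows = [], [], []
    for row in w:
        q, scales, biases = [], [], []
        for start in range(0, len(row), group_size):
            group = row[start:start + group_size]
            lo, hi = min(group), max(group)
            scale = (hi - lo) / levels if hi > lo else 1.0
            scales.append(scale)
            biases.append(lo)
            q.extend(round((v - lo) / scale) for v in group)
        q_rows.append(q)
        scale_rows.append(scales)
        bias_rows.append(biases)
    return q_rows, scale_rows, bias_rows


def small_batch_qmm(x: Matrix, q: List[List[int]], scales: Matrix,
                    biases: Matrix, group_size: int = 4) -> Matrix:
    """Compute x @ W.T, where W is stored quantized and x has only a few rows."""
    out = [[0.0] * len(q) for _ in x]
    for j, (qr, sr, br) in enumerate(zip(q, scales, biases)):
        # Dequantize this weight row once...
        w_row = [sr[k // group_size] * qv + br[k // group_size]
                 for k, qv in enumerate(qr)]
        # ...then reuse it for every row of the small batch (e.g. the verified tokens).
        for i, xi in enumerate(x):
            out[i][j] = sum(a * b for a, b in zip(xi, w_row))
    return out


if __name__ == "__main__":
    q, s, b = quantize_rows([[0.0, 1.0, 2.0, 15.0], [-1.0, 0.0, 14.0, 3.0]])
    x = [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]]
    print(small_batch_qmm(x, q, s, b))
```

Each weight row is dequantized once and reused for every row of `x`. The weight reads, which typically dominate at small batch sizes, are therefore shared across the batch. A kernel that does this well keeps the cost of a few-row product close to that of a one-row product, which is the flat curve that speculative decoding needs.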
